# -*- coding: utf-8 -*-
- import json
- import tensorflow as tf
- import numpy as np
- import matplotlib.pyplot as plt
- import time
- from PIL import Image
- import random
- import os
- from cnnlib.network import CNN
- from log_ware import LogWare
# Module-wide logging setup: a single LogWare instance supplies the logger used
# for all debug output in this file; the same instance is queried later (in
# TrainModel.train_cnn) for the directory that receives the TensorBoard graph.
log_inst = LogWare()
logger = log_inst.get_logger()
class TrainError(Exception):
    """Raised when the training data or its configuration is unusable."""
class TrainModel(CNN):
    """Train the captcha CNN from directories of labelled image files.

    Every image file name is expected to look like ``<label>_<rest>``: the text
    before the first underscore is taken as the captcha label (see
    ``gen_captcha_text_image``).

    NOTE(review): ``image_height``, ``image_width``, ``max_captcha``,
    ``char_set``, ``char_set_len``, ``model_save_dir``, the placeholders ``X``,
    ``Y`` and ``keep_prob``, and the helpers ``model``, ``convert2gray`` and
    ``text2vec`` are not defined in this class; they are presumably provided by
    the ``CNN`` base class -- confirm against ``cnnlib.network``.
    """

    def __init__(self, train_img_path, verify_img_path, char_set, model_save_dir, cycle_stop, acc_stop, cycle_save,
                 image_suffix, train_batch_size, test_batch_size, verify=False):
        """Scan the image directories, derive the image geometry and set up the base class.

        :param train_img_path: directory containing the training images
        :param verify_img_path: directory containing the validation images
        :param char_set: iterable of characters a captcha may contain; every
            item is converted to ``str``
        :param model_save_dir: path handed to the saver for restore/save
        :param cycle_stop: maximum number of training iterations
        :param acc_stop: training stops once whole-image accuracy on a
            validation batch exceeds this value
        :param cycle_save: the model is saved every ``cycle_save`` iterations
        :param image_suffix: suffix every training file name must end with
            (only checked when ``verify`` is true)
        :param train_batch_size: number of images per training batch
        :param test_batch_size: number of images per validation batch
        :param verify: when true, check all training file names against
            ``image_suffix`` before doing anything else
        :raises TrainError: if the training directory is empty, a file name has
            the wrong suffix, the sample image is neither 2-D nor 3-D, or the
            training set holds fewer than 10 images (smoke-test batch below)
        """
        # Training parameters.
        self.cycle_stop = cycle_stop
        self.acc_stop = acc_stop
        self.cycle_save = cycle_save
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.image_suffix = image_suffix
        char_set = [str(i) for i in char_set]

        # Training files: optionally validate the suffixes, then shuffle.
        self.train_img_path = train_img_path
        self.train_images_list = os.listdir(train_img_path)
        if not self.train_images_list:
            # Fail with a domain error instead of an IndexError further down.
            raise TrainError("no images found in training directory: {}".format(train_img_path))
        if verify:
            self.confirm_image_suffix()
        random.seed(time.time())
        random.shuffle(self.train_images_list)

        # Validation files.
        self.verify_img_path = verify_img_path
        self.verify_images_list = os.listdir(verify_img_path)

        # Use one sample image to learn the image size and the label length.
        label, captcha_array = self.gen_captcha_text_image(train_img_path, self.train_images_list[0])
        captcha_shape = captcha_array.shape
        if len(captcha_shape) == 3:
            image_height, image_width, self.channel = captcha_shape
        elif len(captcha_shape) == 2:
            image_height, image_width = captcha_shape
            # A 2-D array has a single channel; set it so the attribute always exists.
            self.channel = 1
        else:
            raise TrainError("图片转换为矩阵时出错,请检查图片格式")

        # Initialise the base class with the derived geometry.
        super(TrainModel, self).__init__(image_height, image_width, len(label), char_set, model_save_dir)

        # Report the basic configuration.
        logger.debug("-->图片尺寸: %s X %s", image_height, image_width)
        logger.debug("-->验证码长度: %s", self.max_captcha)
        logger.debug("-->验证码共%s类 %s", self.char_set_len, char_set)
        logger.debug("-->使用测试集为 %s", train_img_path)
        logger.debug("-->使验证集为 %s", verify_img_path)

        # Smoke test: build one small batch so data problems surface before training.
        logger.debug(">>> Start model test")
        batch_x, batch_y = self.get_batch(0, size=10)
        logger.debug(">>> input batch images shape: %s", batch_x.shape)
        logger.debug(">>> input batch labels shape: %s", batch_y.shape)

    @staticmethod
    def gen_captcha_text_image(img_path, img_name):
        """Load one captcha image and derive its label from the file name.

        :param img_path: directory containing the image
        :param img_name: file name; the part before the first ``_`` is the label
        :return: tuple ``(label, array)`` -- the label string and the image as
            a ``numpy`` array
        """
        label = img_name.split("_")[0]
        img_file = os.path.join(img_path, img_name)
        # The context manager closes the underlying file once the pixel data
        # has been copied into the array; batches open many files.
        with Image.open(img_file) as captcha_image:
            captcha_array = np.array(captcha_image)
        return label, captcha_array

    def _load_batch(self, img_dir, img_names):
        """Turn the named image files in ``img_dir`` into (inputs, labels) arrays.

        Row ``i`` of the inputs is the grayscale image flattened and divided by
        255; row ``i`` of the labels is ``self.text2vec(label)``.
        """
        batch_x = np.zeros([len(img_names), self.image_height * self.image_width])
        batch_y = np.zeros([len(img_names), self.max_captcha * self.char_set_len])
        for row, img_name in enumerate(img_names):
            label, image_array = self.gen_captcha_text_image(img_dir, img_name)
            image_array = self.convert2gray(image_array)
            batch_x[row, :] = image_array.flatten() / 255
            batch_y[row, :] = self.text2vec(label)
        return batch_x, batch_y

    def get_batch(self, n, size=128):
        """Return the ``n``-th training batch of ``size`` images.

        Only full batches are used: ``n`` wraps around modulo the number of
        full batches, so images left over after the last full batch are never
        returned.

        :param n: zero-based batch index
        :param size: number of images in the batch
        :return: tuple ``(batch_x, batch_y)`` of ``numpy`` arrays with ``size`` rows
        :raises TrainError: if the training set holds fewer than ``size`` images
        """
        max_batch = len(self.train_images_list) // size
        if max_batch < 1:
            raise TrainError("训练集图片数量需要大于每批次训练的图片数量")
        n = n % max_batch
        start = n * size
        return self._load_batch(self.train_img_path, self.train_images_list[start:start + size])

    def get_verify_batch(self, size=100):
        """Return a batch of ``size`` validation images drawn at random with replacement.

        :param size: number of images in the batch
        :return: tuple ``(batch_x, batch_y)`` of ``numpy`` arrays with ``size`` rows
        :raises TrainError: if the validation directory contained no files
        """
        if not self.verify_images_list:
            raise TrainError("no images found in validation directory: {}".format(self.verify_img_path))
        verify_images = [random.choice(self.verify_images_list) for _ in range(size)]
        return self._load_batch(self.verify_img_path, verify_images)

    def confirm_image_suffix(self):
        """Check that every training file name ends with ``self.image_suffix``.

        :raises TrainError: on the first file name with a different suffix
        """
        print("开始校验所有图片后缀")
        for index, img_name in enumerate(self.train_images_list):
            if not img_name.endswith(self.image_suffix):
                raise TrainError('confirm images suffix:you request [.{}] file but get file [{}]'
                                 .format(self.image_suffix, img_name))
            # Report progress only after the file actually passed the check.
            print("{} image pass".format(index))
        logger.debug("所有图片格式校验通过")

    def train_cnn(self):
        """Build the training graph and run the training loop.

        The loss is the mean sigmoid cross-entropy between the network output
        and the label vector, minimised with Adam (learning rate 0.0001). Every
        10 iterations the character-level and whole-image accuracies are
        printed for the current training batch and for a random validation
        batch; training stops early once the validation image accuracy exceeds
        ``self.acc_stop``. The model is saved every ``self.cycle_save``
        iterations and once more when the loop ends.
        """
        y_predict = self.model()
        logger.debug(">>> input batch predict shape: %s", y_predict.shape)
        logger.debug(">>> End model test")

        # Loss.
        with tf.name_scope('cost'):
            cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_predict, labels=self.Y))
        # Optimiser.
        with tf.name_scope('train'):
            optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost)

        # Per-position predicted and expected character indices.
        predict = tf.reshape(y_predict, [-1, self.max_captcha, self.char_set_len])
        max_idx_p = tf.argmax(predict, 2)
        max_idx_l = tf.argmax(tf.reshape(self.Y, [-1, self.max_captcha, self.char_set_len]), 2)
        correct_pred = tf.equal(max_idx_p, max_idx_l)
        with tf.name_scope('char_acc'):
            # Fraction of individual characters predicted correctly.
            accuracy_char_count = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        with tf.name_scope('image_acc'):
            # An image counts only if all of its characters are correct
            # (the minimum over the character axis is 1 only in that case).
            accuracy_image_count = tf.reduce_mean(tf.reduce_min(tf.cast(correct_pred, tf.float32), axis=1))

        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # Resume from an existing checkpoint when possible.
            if os.path.exists(self.model_save_dir):
                try:
                    saver.restore(sess, self.model_save_dir)
                except ValueError:
                    # Raised when the model directory holds no checkpoint files.
                    logger.debug("model文件夹为空,将创建新模型")

            # Write the graph for TensorBoard, creating the directory if needed.
            sess_log_dir = os.path.join(log_inst.get_log_dir(), 'train_sess')
            if not os.path.exists(sess_log_dir):
                os.makedirs(sess_log_dir)
            graph_writer = tf.summary.FileWriter(sess_log_dir, sess.graph)
            # Nothing else is written, so flush and release the event file now.
            graph_writer.close()

            for step in range(1, self.cycle_stop + 1):
                batch_x, batch_y = self.get_batch(step - 1, size=self.train_batch_size)
                # One optimisation step (dropout active).
                _, cost_ = sess.run([optimizer, cost],
                                    feed_dict={self.X: batch_x, self.Y: batch_y, self.keep_prob: 0.75})

                if step % 10 == 0:
                    # Accuracy on the batch just trained on; the arrays are
                    # reused instead of being reloaded from disk, and both
                    # metrics come from a single run.
                    acc_char, acc_image = sess.run(
                        [accuracy_char_count, accuracy_image_count],
                        feed_dict={self.X: batch_x, self.Y: batch_y, self.keep_prob: 1.})
                    print("第{}次训练 >>> ".format(step))
                    print("[训练集] 字符准确率为 {:.5f} 图片准确率为 {:.5f} >>> loss {:.10f}".format(acc_char, acc_image, cost_))

                    # Accuracy and loss on a random validation batch, so the
                    # printed loss really belongs to the validation data.
                    batch_x_verify, batch_y_verify = self.get_verify_batch(size=self.test_batch_size)
                    acc_char, acc_image, verify_cost = sess.run(
                        [accuracy_char_count, accuracy_image_count, cost],
                        feed_dict={self.X: batch_x_verify, self.Y: batch_y_verify, self.keep_prob: 1.})
                    print("[验证集] 字符准确率为 {:.5f} 图片准确率为 {:.5f} >>> loss {:.10f}".format(acc_char, acc_image, verify_cost))

                    # Stop as soon as the validation accuracy target is exceeded.
                    if acc_image > self.acc_stop:
                        saver.save(sess, self.model_save_dir)
                        logger.debug("验证集准确率达到%s,保存模型成功", str(self.acc_stop * 100) + "%")
                        break

                # Periodic checkpoint, counted in completed steps; a falsy
                # cycle_save disables it instead of dividing by zero.
                if self.cycle_save and step % self.cycle_save == 0:
                    saver.save(sess, self.model_save_dir)
                    print("定时保存模型成功")

            # Final save, whether the loop ran to completion or stopped early.
            saver.save(sess, self.model_save_dir)

    def recognize_captcha(self):
        """Predict one randomly chosen training image and show it with the result.

        Restores the saved model, runs the prediction, logs the expected label
        and the predicted text, and displays the image with matplotlib.
        """
        label, captcha_array = self.gen_captcha_text_image(self.train_img_path, random.choice(self.train_images_list))

        f = plt.figure()
        ax = f.add_subplot(111)
        ax.text(0.1, 0.9, "origin:" + label, ha='center', va='center', transform=ax.transAxes)
        plt.imshow(captcha_array)

        # Same preprocessing as for the training batches.
        image = self.convert2gray(captcha_array)
        image = image.flatten() / 255

        y_predict = self.model()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, self.model_save_dir)
            predict = tf.argmax(tf.reshape(y_predict, [-1, self.max_captcha, self.char_set_len]), 2)
            text_list = sess.run(predict, feed_dict={self.X: [image], self.keep_prob: 1.})
            predict_text = text_list[0].tolist()

        logger.debug("正确: %s 预测: %s", label, predict_text)
        # Map the predicted indices back to characters.
        p_text = "".join(str(self.char_set[p]) for p in predict_text)
        logger.debug("p_text: %s", p_text)
        plt.text(20, 1, 'predict:{}'.format(p_text))
        plt.show()
def main():
    """Read the sample configuration, build a TrainModel and start training."""
    with open("conf/sample_config.json", "r") as conf_file:
        conf = json.load(conf_file)

    # Look up every required setting up front, in a fixed order, so a missing
    # key fails before any other work is done.
    wanted = (
        "train_image_dir",
        "test_image_dir",
        "model_save_dir",
        "cycle_stop",
        "acc_stop",
        "cycle_save",
        "enable_gpu",
        "image_suffix",
        "use_labels_json_file",
        "train_batch_size",
        "test_batch_size",
    )
    opts = {}
    for key in wanted:
        opts[key] = conf[key]

    # The character set comes either from the labels file (its stripped text)
    # or from the configuration itself.
    if not opts["use_labels_json_file"]:
        char_set = conf["char_set"]
    else:
        with open("tools/labels.json", "r") as labels_file:
            char_set = labels_file.read().strip()

    if not opts["enable_gpu"]:
        # Hide all CUDA devices so that training runs on the CPU.
        os.environ.update({
            "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
            "CUDA_VISIBLE_DEVICES": "-1",
        })

    trainer = TrainModel(
        opts["train_image_dir"],
        opts["test_image_dir"],
        char_set,
        opts["model_save_dir"],
        opts["cycle_stop"],
        opts["acc_stop"],
        opts["cycle_save"],
        opts["image_suffix"],
        opts["train_batch_size"],
        opts["test_batch_size"],
        verify=False,
    )
    trainer.train_cnn()  # start training
    # trainer.recognize_captcha()  # example: recognise a single image


if __name__ == '__main__':
    main()
|