# -*- coding: utf-8 -*-
import json
import os
import random
import time

import numpy as np
import tensorflow as tf
from PIL import Image

from cnnlib.network import CNN
from log_ware import LogWare

# Module-wide logger obtained from the project's LogWare helper.
logger = LogWare().get_logger()
class TestError(Exception):
    """Raised when the batch test cannot be set up from the test images."""
class TestBatch(CNN):
    """Run a trained captcha model against randomly drawn test images.

    Image height/width and the captcha length are inferred from one sample
    picked at random from ``img_path``. The label of an image is the part of
    its file name before the first underscore.
    """

    def __init__(self, img_path, char_set, model_save_dir, total):
        """Scan the test directory and initialise the base network.

        :param img_path: directory containing the test images
        :param char_set: sequence of classes; predicted indices are looked
            up in it when decoding a prediction
        :param model_save_dir: checkpoint path handed to ``Saver.restore``
        :param total: number of random samples evaluated by ``test_batch``
        :raises TestError: if ``img_path`` contains no files, or the sample
            image does not convert to a 2-D or 3-D array
        """
        # Model location
        self.model_save_dir = model_save_dir

        # Collect the test files and shuffle their order
        self.img_path = img_path
        self.img_list = os.listdir(img_path)
        if not self.img_list:
            # Fail early with a clear error instead of an IndexError raised
            # later by random.choice on an empty list.
            raise TestError("测试集目录为空,请检查路径: {}".format(img_path))
        random.seed(time.time())
        random.shuffle(self.img_list)

        # Derive image height/width and label length from one sample
        label, captcha_array = self.gen_captcha_text_image()
        captcha_shape = captcha_array.shape
        captcha_shape_len = len(captcha_shape)
        if captcha_shape_len == 3:
            image_height, image_width, channel = captcha_shape
            self.channel = channel
        elif captcha_shape_len == 2:
            image_height, image_width = captcha_shape
        else:
            raise TestError("图片转换为矩阵时出错,请检查图片格式")

        # Initialise the base class with the inferred dimensions
        super(TestBatch, self).__init__(
            image_height, image_width, len(label), char_set, model_save_dir
        )
        self.total = total

        # Log the basic configuration
        # NOTE(review): max_captcha / char_set_len are presumably set by
        # CNN.__init__ -- confirm in cnnlib.network.
        logger.debug("-->图片尺寸: %s X %s", image_height, image_width)
        logger.debug("-->验证码长度: %s", self.max_captcha)
        logger.debug("-->验证码共%d类 %s", self.char_set_len, char_set)
        logger.debug("-->使用测试集为 %s", img_path)

    def gen_captcha_text_image(self):
        """Pick a random test image and return its label and pixel array.

        :return: tuple (str, numpy.ndarray) -- the label (file name text
            before the first underscore) and the image as an array
        """
        img_name = random.choice(self.img_list)
        # Label
        label = img_name.split("_")[0]
        # File
        img_file = os.path.join(self.img_path, img_name)
        # Convert inside the context manager so the underlying file handle
        # is released as soon as the pixels have been copied out.
        with Image.open(img_file) as captcha_image:
            captcha_array = np.array(captcha_image)
        return label, captcha_array

    def test_batch(self):
        """Predict ``self.total`` random samples and log the accuracy.

        Restores the checkpoint at ``self.model_save_dir``, runs one
        prediction per sample, and logs each result followed by the overall
        hit count, elapsed time and accuracy.
        """
        y_predict = self.model()
        total = self.total
        right = 0
        saver = tf.train.Saver()
        # Build the decoding op once. Creating it inside the loop would add
        # new nodes to the graph for every sample.
        predict = tf.argmax(
            tf.reshape(y_predict, [-1, self.max_captcha, self.char_set_len]), 2
        )
        with tf.Session() as sess:
            saver.restore(sess, self.model_save_dir)
            s = time.time()
            for _ in range(total):
                # Random sample
                test_text, test_image = self.gen_captcha_text_image()
                test_image = self.convert2gray(test_image)
                test_image = test_image.flatten() / 255

                text_list = sess.run(
                    predict,
                    feed_dict={self.X: [test_image], self.keep_prob: 1.},
                )
                predict_text = text_list[0].tolist()
                p_text = "".join(str(self.char_set[p]) for p in predict_text)
                if test_text == p_text:
                    p_info = "匹配"
                    right += 1
                else:
                    p_info = "不匹配"
                logger.debug("origin: %s predict: %s %s", test_text, p_text, p_info)
            e = time.time()
        # Guard against total == 0 so an empty run reports 0.0% instead of
        # raising ZeroDivisionError.
        accuracy = right / total * 100 if total else 0.0
        rate = str(accuracy) + "%"
        logger.debug("测试结果: %d/%d", right, total)
        logger.debug("%d个样本识别耗时%f秒,准确率%s", total, e - s, rate)
def main():
    """Load the sample configuration and run a 15-sample batch test.

    Reads ``conf/sample_config.json`` for the test image directory, the
    checkpoint path and the character set. When ``use_labels_json_file`` is
    truthy, the character set is instead the stripped raw text of
    ``tools/labels.json``.
    """
    # Explicit UTF-8 so decoding does not depend on the platform's default
    # locale encoding.
    with open("conf/sample_config.json", "r", encoding="utf-8") as f:
        sample_conf = json.load(f)

    test_image_dir = sample_conf["test_image_dir"]
    model_save_dir = sample_conf["model_save_dir"]

    use_labels_json_file = sample_conf['use_labels_json_file']
    if use_labels_json_file:
        # The file content is used as-is (not parsed as JSON).
        with open("tools/labels.json", "r", encoding="utf-8") as f:
            char_set = f.read().strip()
    else:
        char_set = sample_conf["char_set"]

    total = 15
    tb = TestBatch(test_image_dir, char_set, model_save_dir, total)
    tb.test_batch()
# Run the batch test only when executed as a script, not on import.
if __name__ == '__main__':
    main()