test_batch.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # -*- coding: utf-8 -*-
  2. import json
  3. import tensorflow as tf
  4. import numpy as np
  5. import time
  6. from PIL import Image
  7. import random
  8. import os
  9. from cnnlib.network import CNN
  10. from log_ware import LogWare
# Module-wide logger from the project's LogWare helper; the test progress and
# results in this file are all emitted through logger.debug().
logger = LogWare().get_logger()
  12. class TestError(Exception):
  13. pass
  14. class TestBatch(CNN):
  15. def __init__(self, img_path, char_set, model_save_dir, total):
  16. # 模型路径
  17. self.model_save_dir = model_save_dir
  18. # 打乱文件顺序
  19. self.img_path = img_path
  20. self.img_list = os.listdir(img_path)
  21. random.seed(time.time())
  22. random.shuffle(self.img_list)
  23. # 获得图片宽高和字符长度基本信息
  24. label, captcha_array = self.gen_captcha_text_image()
  25. captcha_shape = captcha_array.shape
  26. captcha_shape_len = len(captcha_shape)
  27. if captcha_shape_len == 3:
  28. image_height, image_width, channel = captcha_shape
  29. self.channel = channel
  30. elif captcha_shape_len == 2:
  31. image_height, image_width = captcha_shape
  32. else:
  33. raise TestError("图片转换为矩阵时出错,请检查图片格式")
  34. # 初始化变量
  35. super(TestBatch, self).__init__(image_height, image_width, len(label), char_set, model_save_dir)
  36. self.total = total
  37. # 相关信息打印
  38. logger.debug("-->图片尺寸: %s X %s", image_height, image_width)
  39. logger.debug("-->验证码长度: %s", self.max_captcha)
  40. logger.debug("-->验证码共%d类 %s", self.char_set_len, char_set)
  41. logger.debug("-->使用测试集为 %s", img_path)
  42. def gen_captcha_text_image(self):
  43. """
  44. 返回一个验证码的array形式和对应的字符串标签
  45. :return:tuple (str, numpy.array)
  46. """
  47. img_name = random.choice(self.img_list)
  48. # 标签
  49. label = img_name.split("_")[0]
  50. # 文件
  51. img_file = os.path.join(self.img_path, img_name)
  52. captcha_image = Image.open(img_file)
  53. captcha_array = np.array(captcha_image) # 向量化
  54. return label, captcha_array
  55. def test_batch(self):
  56. y_predict = self.model()
  57. total = self.total
  58. right = 0
  59. saver = tf.train.Saver()
  60. with tf.Session() as sess:
  61. saver.restore(sess, self.model_save_dir)
  62. s = time.time()
  63. for i in range(total):
  64. # test_text, test_image = gen_special_num_image(i)
  65. test_text, test_image = self.gen_captcha_text_image() # 随机
  66. test_image = self.convert2gray(test_image)
  67. test_image = test_image.flatten() / 255
  68. predict = tf.argmax(tf.reshape(y_predict, [-1, self.max_captcha, self.char_set_len]), 2)
  69. text_list = sess.run(predict, feed_dict={self.X: [test_image], self.keep_prob: 1.})
  70. predict_text = text_list[0].tolist()
  71. p_text = ""
  72. p_info = ""
  73. for p in predict_text:
  74. p_text += str(self.char_set[p])
  75. if test_text == p_text:
  76. p_info = "匹配"
  77. right += 1
  78. else:
  79. p_info = "不匹配"
  80. pass
  81. logger.debug("origin: %s predict: %s %s", test_text, p_text, p_info)
  82. e = time.time()
  83. rate = str(right / total * 100) + "%"
  84. logger.debug("测试结果: %d/%d", right, total)
  85. logger.debug("%d个样本识别耗时%f秒,准确率%s", total, e - s, rate)
  86. def main():
  87. with open("conf/sample_config.json", "r") as f:
  88. sample_conf = json.load(f)
  89. test_image_dir = sample_conf["test_image_dir"]
  90. model_save_dir = sample_conf["model_save_dir"]
  91. use_labels_json_file = sample_conf['use_labels_json_file']
  92. if use_labels_json_file:
  93. with open("tools/labels.json", "r") as f:
  94. char_set = f.read().strip()
  95. else:
  96. char_set = sample_conf["char_set"]
  97. total = 15
  98. tb = TestBatch(test_image_dir, char_set, model_save_dir, total)
  99. tb.test_batch()
  100. if __name__ == '__main__':
  101. main()