recognition_object.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # -*- coding: utf-8 -*-
  2. """
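Image-recognition class. To run many recognitions quickly, create one instance and reuse it:
    R = Recognizer(image_height, image_width, max_captcha, char_set, model_save_dir)
    for i in range(10):
        r_img = Image.open(str(i) + ".jpg")
        t = R.rec_image(r_img)
For simple images, each recognition typically takes on the order of milliseconds.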
  9. """
  10. import tensorflow as tf
  11. import numpy as np
  12. from PIL import Image
  13. from cnnlib.network import CNN
  14. import json
  15. from log_ware import LogWare
  16. logger = LogWare().get_logger()
  17. class Recognizer(CNN):
  18. def __init__(self, image_height, image_width, max_captcha, char_set, model_save_dir):
  19. # 初始化变量
  20. super(Recognizer, self).__init__(image_height, image_width, max_captcha, char_set, model_save_dir)
  21. # 新建图和会话
  22. self.g = tf.Graph()
  23. self.sess = tf.Session(graph=self.g)
  24. # 使用指定的图和会话
  25. with self.g.as_default():
  26. # 迭代循环前,写出所有用到的张量的计算表达式,如果写在循环中,会发生内存泄漏,拖慢识别的速度
  27. # tf初始化占位符
  28. self.X = tf.placeholder(tf.float32, [None, self.image_height * self.image_width]) # 特征向量
  29. self.Y = tf.placeholder(tf.float32, [None, self.max_captcha * self.char_set_len]) # 标签
  30. self.keep_prob = tf.placeholder(tf.float32) # dropout值
  31. # 加载网络和模型参数
  32. self.y_predict = self.model()
  33. self.predict = tf.argmax(tf.reshape(self.y_predict, [-1, self.max_captcha, self.char_set_len]), 2)
  34. saver = tf.train.Saver()
  35. with self.sess.as_default() as sess:
  36. saver.restore(sess, self.model_save_dir)
  37. # def __del__(self):
  38. # self.sess.close()
  39. # logger.debug("session close")
  40. def rec_image(self, img):
  41. # 读取图片
  42. img_array = np.array(img)
  43. test_image = self.convert2gray(img_array)
  44. test_image = test_image.flatten() / 255
  45. # 使用指定的图和会话
  46. with self.g.as_default():
  47. with self.sess.as_default() as sess:
  48. text_list = sess.run(self.predict, feed_dict={self.X: [test_image], self.keep_prob: 1.})
  49. # 获取结果
  50. predict_text = text_list[0].tolist()
  51. p_text = ""
  52. for p in predict_text:
  53. p_text += str(self.char_set[p])
  54. # 返回识别结果
  55. return p_text
  56. def main():
  57. with open("conf/sample_config.json", "r", encoding="utf-8") as f:
  58. sample_conf = json.load(f)
  59. image_height = sample_conf["image_height"]
  60. image_width = sample_conf["image_width"]
  61. max_captcha = sample_conf["max_captcha"]
  62. char_set = sample_conf["char_set"]
  63. model_save_dir = sample_conf["model_save_dir"]
  64. R = Recognizer(image_height, image_width, max_captcha, char_set, model_save_dir)
  65. r_img = Image.open("./sample/test/2b3n_6915e26c67a52bc0e4e13d216eb62b37.jpg")
  66. t = R.rec_image(r_img)
  67. logger.debug(t)
  68. if __name__ == '__main__':
  69. main()