verify_and_split_data.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. """
  2. 验证图片尺寸和分离测试集(5%)和训练集(95%)
  3. 初始化的时候使用,有新的图片后,可以把图片放在new目录里面使用。
  4. """
  5. import json
  6. from PIL import Image
  7. import random
  8. import os
  9. import shutil
  10. from log_ware import LogWare
# Module-wide logger obtained from the project's LogWare helper; verify() and split() log through it.
logger = LogWare().get_logger()
  12. def verify(origin_dir, real_width, real_height, image_suffix):
  13. """
  14. 校验图片大小
  15. :return:
  16. """
  17. if not os.path.exists(origin_dir):
  18. logger.debug("【警告】找不到目录%s,即将创建", origin_dir)
  19. os.makedirs(origin_dir)
  20. logger.debug("开始校验原始图片集")
  21. # 图片真实尺寸
  22. real_size = (real_width, real_height)
  23. # 图片名称列表和数量
  24. img_list = os.listdir(origin_dir)
  25. total_count = len(img_list)
  26. logger.debug("原始集共有图片: %d张", total_count)
  27. # 无效图片列表
  28. bad_img = []
  29. # 遍历所有图片进行验证
  30. for index, img_name in enumerate(img_list):
  31. file_path = os.path.join(origin_dir, img_name)
  32. # 过滤图片不正确的后缀
  33. if not img_name.endswith(image_suffix):
  34. bad_img.append((index, img_name, "文件后缀不正确"))
  35. continue
  36. # logger.debug("图片命名下划线位置:%s", img_name.find("_"))
  37. if (img_name.find("_") == -1):
  38. bad_img.append((index, img_name, "图片命名没有下划线,图片名称:" + img_name))
  39. logger.debug("图片命名没有下划线,图片名称:%s", img_name)
  40. # 过滤图片标签不标准的情况
  41. prefix, posfix = img_name.split("_")
  42. if prefix == "" or posfix == "":
  43. bad_img.append((index, img_name, "图片标签异常"))
  44. continue
  45. # 图片无法正常打开
  46. try:
  47. img = Image.open(file_path)
  48. except OSError:
  49. bad_img.append((index, img_name, "图片无法正常打开"))
  50. continue
  51. # 图片尺寸有异常
  52. if real_size == img.size:
  53. logger.debug("%d pass", index)
  54. else:
  55. bad_img.append((index, img_name, "图片尺寸异常为:{}".format(img.size)))
  56. logger.debug("====以下%d张图片有异常====", len(bad_img))
  57. if bad_img:
  58. for b in bad_img:
  59. logger.debug("[第%d张图片] [%d] [%d]", b[0], b[1], b[2])
  60. else:
  61. logger.debug("未发现异常(共 %d 张图片)", len(img_list))
  62. logger.debug("========end")
  63. return bad_img
  64. def split(origin_dir, train_dir, test_dir, bad_imgs):
  65. """
  66. 分离训练集和测试集
  67. :return:
  68. """
  69. if not os.path.exists(origin_dir):
  70. logger.debug("【警告】找不到目录%s,即将创建", origin_dir)
  71. os.makedirs(origin_dir)
  72. logger.debug("开始分离原始图片集为:测试集(5%)和训练集(95%)")
  73. # 图片名称列表和数量
  74. img_list = os.listdir(origin_dir)
  75. for img in bad_imgs:
  76. img_list.remove(img)
  77. total_count = len(img_list)
  78. logger.debug("共分配%d张图片到训练集和测试集,其中%d张为异常留在原始目录", total_count, len(bad_imgs))
  79. # 创建文件夹
  80. if not os.path.exists(train_dir):
  81. os.mkdir(train_dir)
  82. if not os.path.exists(test_dir):
  83. os.mkdir(test_dir)
  84. # 测试集
  85. test_count = int(total_count * 0.05)
  86. test_set = set()
  87. for i in range(test_count):
  88. while True:
  89. file_name = random.choice(img_list)
  90. if file_name in test_set:
  91. pass
  92. else:
  93. test_set.add(file_name)
  94. img_list.remove(file_name)
  95. break
  96. test_list = list(test_set)
  97. logger.debug("测试集数量为:%s", len(test_list))
  98. for file_name in test_list:
  99. src = os.path.join(origin_dir, file_name)
  100. dst = os.path.join(test_dir, file_name)
  101. shutil.move(src, dst)
  102. # 训练集
  103. train_list = img_list
  104. logger.debug("训练集数量为:%s", len(train_list))
  105. for file_name in train_list:
  106. src = os.path.join(origin_dir, file_name)
  107. dst = os.path.join(train_dir, file_name)
  108. shutil.move(src, dst)
  109. if os.listdir(origin_dir) == 0:
  110. logger.debug("migration done")
  111. def main():
  112. with open("conf/sample_config.json", "r") as f:
  113. sample_conf = json.load(f)
  114. # 图片路径
  115. origin_dir = sample_conf["origin_image_dir"]
  116. new_dir = sample_conf["new_image_dir"]
  117. train_dir = sample_conf["train_image_dir"]
  118. test_dir = sample_conf["test_image_dir"]
  119. # 图片尺寸
  120. real_width = sample_conf["image_width"]
  121. real_height = sample_conf["image_height"]
  122. # 图片后缀
  123. image_suffix = sample_conf["image_suffix"]
  124. for image_dir in [origin_dir, new_dir]:
  125. logger.debug(">>> 开始校验目录:[%s]", image_dir)
  126. bad_images_info = verify(image_dir, real_width, real_height, image_suffix)
  127. bad_imgs = []
  128. for info in bad_images_info:
  129. bad_imgs.append(info[1])
  130. split(image_dir, train_dir, test_dir, bad_imgs)
  131. if __name__ == '__main__':
  132. main()