def train():
    """Train the network chosen by ``cfg.select_model`` one epoch at a time.

    Steps:
      1. Create ``cfg.save_weights_path`` if it does not exist.
      2. Compile the training model with ``cfg.optimizer``, ``cfg.loss``,
         ``cfg.loss_weighting`` and ``cfg.metrics``, and print its summary.
      3. Build a training generator. When ``cfg.validate`` is true, also build
         a validation generator.
      4. Run ``cfg.epochs`` single-epoch ``fit_generator`` calls.
      5. After every ``cfg.epochs_save``-th epoch, write the weights to
         ``<cfg.save_weights_path>/model.<ep>``.

    All settings are read from the module-level ``cfg`` object. Nothing is
    returned.
    """
    os.makedirs(cfg.save_weights_path, exist_ok=True)

    # Select the network. cfg.select_model unpacks into two models; only the
    # first (training) model is used here.
    train_model, _eval_model = cfg.select_model

    # Compile and print the model.
    train_model.compile(optimizer=cfg.optimizer,
                        loss=cfg.loss,
                        loss_weights=cfg.loss_weighting,
                        metrics=cfg.metrics)
    print_summary(model=train_model)

    # Training data generator G1.
    G1 = imageSegmentationGenerator(cfg.train_images, cfg.train_label,
                                    cfg.train_batch_size, cfg.n_classes,
                                    cfg.input_shape[0], cfg.input_shape[1],
                                    cfg.output_shape[0], cfg.output_shape[1])

    # Arguments shared by both training modes (with / without validation).
    # class_weight is deliberately not passed: Keras expects a dict mapping
    # class index -> weight, and the string 'auto' is not a valid value.
    fit_kwargs = dict(steps_per_epoch=cfg.train_steps_per_epoch,
                      workers=cfg.workers,
                      epochs=1,
                      verbose=1,
                      use_multiprocessing=cfg.use_multiprocessing)

    if cfg.validate:
        # Validation data generator G2. Validation images must be paired with
        # validation labels. Fall back to cfg.train_label only when the config
        # does not define a separate val_label.
        val_label = getattr(cfg, 'val_label', cfg.train_label)
        G2 = imageSegmentationGenerator(cfg.val_images, val_label,
                                        cfg.val_batch_size, cfg.n_classes,
                                        cfg.input_shape[0], cfg.input_shape[1],
                                        cfg.output_shape[0], cfg.output_shape[1])
        fit_kwargs.update(validation_data=G2,
                          validation_steps=cfg.validate_steps_per_epoch)

    epochs_save = cfg.epochs_save
    # Training loop: one epoch per fit_generator call, so that weights can be
    # checkpointed between epochs.
    for ep in range(cfg.epochs):
        train_model.fit_generator(G1, **fit_kwargs)

        # Save after epochs epochs_save-1, 2*epochs_save-1, ...
        # A non-positive epochs_save disables checkpointing instead of
        # raising ZeroDivisionError.
        if epochs_save > 0 and (ep + 1) % epochs_save == 0:
            print('saving model.{}.......'.format(ep))
            save_weights_name = 'model.{}'.format(ep)
            save_weights_path = os.path.join(cfg.save_weights_path, save_weights_name)
            train_model.save_weights(save_weights_path)
def train():
    """Train a ``UNet`` built from ``cfg.input_shape`` one epoch at a time.

    Steps:
      1. Compile the model with ``cfg.optimizer``, ``cfg.loss`` and
         ``cfg.metrics``, and print its summary.
      2. Build a training generator from ``cfg.train_images`` /
         ``cfg.train_annotations``. When ``cfg.validate`` is true, also build a
         validation generator from ``cfg.val_images`` /
         ``cfg.val_annotations``.
      3. Run ``cfg.epochs`` single-epoch ``fit_generator`` calls.
      4. After every ``cfg.epochs_save``-th epoch, write the weights to
         ``<cfg.save_weights_path>/model.<ep>``. The directory is created if
         missing.

    All settings are read from the module-level ``cfg`` object. Nothing is
    returned.
    """
    model = UNet(cfg.input_shape)

    # Compile and print the model.
    model.compile(optimizer=cfg.optimizer, loss=cfg.loss, metrics=cfg.metrics)
    print_summary(model=model)

    # Training data generator G1.
    G1 = imageSegmentationGenerator(cfg.train_images, cfg.train_annotations,
                                    cfg.train_batch_size, cfg.n_classes,
                                    cfg.input_shape[0], cfg.input_shape[1],
                                    cfg.output_shape[0], cfg.output_shape[1])

    # Arguments shared by both training modes (with / without validation).
    fit_kwargs = dict(steps_per_epoch=cfg.train_steps_per_epoch,
                      workers=cfg.workers,
                      epochs=1,
                      verbose=1,
                      use_multiprocessing=cfg.use_multiprocessing)

    if cfg.validate:
        # Validation data generator G2.
        G2 = imageSegmentationGenerator(cfg.val_images, cfg.val_annotations,
                                        cfg.val_batch_size, cfg.n_classes,
                                        cfg.input_shape[0], cfg.input_shape[1],
                                        cfg.output_shape[0], cfg.output_shape[1])
        fit_kwargs.update(validation_data=G2,
                          validation_steps=cfg.validate_steps_per_epoch)

    # save_weights does not create missing parent directories.
    os.makedirs(cfg.save_weights_path, exist_ok=True)

    epochs_save = cfg.epochs_save
    # Training loop: one epoch per fit_generator call, so that weights can be
    # checkpointed between epochs.
    for ep in range(cfg.epochs):
        model.fit_generator(G1, **fit_kwargs)

        # Save every epochs_save epochs (after epochs epochs_save-1,
        # 2*epochs_save-1, ...). The previous counter reset to 1 and was then
        # incremented, which shortened every period after the first by one
        # epoch and stopped saving entirely when epochs_save == 1.
        # A non-positive epochs_save disables checkpointing.
        if epochs_save > 0 and (ep + 1) % epochs_save == 0:
            save_weights_name = 'model.{}'.format(ep)
            save_weights_path = os.path.join(cfg.save_weights_path, save_weights_name)
            model.save_weights(save_weights_path)
def use_generator_to_show(images_path, label_path, batch_size, n_classes, input_height, input_width, output_height, output_width):
    """Visually inspect what ``imageSegmentationGenerator`` yields.

    Opens a matplotlib figure and draws one random colour per class. Then, for
    every ``(image_batch, label_batch)`` pair the generator yields, it:
      * prints both batch shapes,
      * splits each batch into a per-sample list with
        ``split_batch_to_pic_list`` and keeps sample 0,
      * renders the label with ``seg_vec_to_pic``,
      * shows image and rendered label side by side with
        ``plt_imshow_two_pics``,
      * sleeps one second.

    The loop runs for as long as the generator keeps yielding.

    When the flag ``use_binary_label`` is truthy, labels are rendered with a
    single class. That flag is read from a name defined outside this function.

    Args:
        images_path, label_path, batch_size, n_classes, input_height,
        input_width, output_height, output_width: forwarded unchanged to
        ``imageSegmentationGenerator``.
    """
    sample_idx = 0  # which sample of each batch to display
    plt.figure()

    # One random (r, g, b) tuple per class. randint's upper bound is
    # exclusive, so each channel lies in [0, 254].
    colors = [tuple(np.random.randint(0, 255) for _channel in range(3))
              for _cls in range(n_classes)]

    # The generator yields (image_batch, label_batch) pairs.
    batches = imageSegmentationGenerator(images_path, label_path, batch_size, n_classes,
                                         input_height, input_width, output_height, output_width)
    for img_batch, label_batch in batches:
        # 1. Original image.
        print('return img shape: ', img_batch.shape)
        image = split_batch_to_pic_list(img_batch)[sample_idx]

        # 2. Label image.
        if use_binary_label:
            n_classes = 1
        print('return label shape: ', label_batch.shape)
        label = split_batch_to_pic_list(label_batch)[sample_idx]
        label_img = seg_vec_to_pic(label, image.shape, colors, n_classes)

        # 3. Display image and label together, then pause.
        plt_imshow_two_pics(image, label_img)
        time.sleep(1)