def train(gpu_num=None, with_generator=False, show_info=True):
    """Build a UNet++ network, train it, and save its weights and learning curve.

    Args:
        gpu_num: When an ``int``, train the model returned by
            ``network.get_parallel_model(gpu_num, ...)``; otherwise train the
            model returned by ``network.get_model(...)``.
        with_generator: When true, stream batches from
            ``train_generator.generator(...)`` through ``fit_generator``;
            otherwise materialize the data with ``generate_data()`` and call
            ``fit``.
        show_info: When true, write a model plot to ``../model_plot.png`` and
            print the model summary.

    Side effects:
        Registers TensorBoard, ``HistoryCheckpoint`` and ``ModelCheckpoint``
        callbacks (the latter two with ``period=10``), writes the final
        weights to ``os.path.join(DIR_MODEL, File_MODEL)``, and passes the
        training history to ``save_learning_curve``.
    """
    print('network creating ... ', end='', flush=True)
    network = UNetPP(INPUT_IMAGE_SHAPE,
                     start_filter=START_FILTER,
                     depth=DEPTH,
                     class_num=CLASS_NUM)
    print('... created')

    if show_info:
        network.plot_model_summary('../model_plot.png')
        network.show_model_summary()

    # NOTE(review): `with_comple` looks like a typo of `with_compile`, but it
    # must match UNetPP's method signature, which is not visible here - confirm.
    if isinstance(gpu_num, int):
        model = network.get_parallel_model(gpu_num, with_comple=True)
    else:
        model = network.get_model(with_comple=True)

    # NOTE(review): `File_MODEL` (not `FILE_MODEL`) must match the module-level
    # constant defined elsewhere in this project - confirm.
    model_filename = os.path.join(DIR_MODEL, File_MODEL)
    # No validation data is passed to training here, so `save_best_only` is not
    # used (ModelCheckpoint monitors `val_loss` by default).
    callbacks = [
        KC.TensorBoard(),
        HistoryCheckpoint(filepath='LearningCurve_{history}.png',
                          verbose=1,
                          period=10),
        KC.ModelCheckpoint(filepath=model_filename,
                           verbose=1,
                           save_weights_only=True,
                           period=10),
    ]

    print('data generator creating ... ', end='', flush=True)
    train_generator = DataGenerator(DIR_INPUTS, DIR_TEACHERS, INPUT_IMAGE_SHAPE)
    print('... created')

    if with_generator:
        train_data_num = train_generator.data_size()
        # Bind the result to `history` (it was previously `his`) so the
        # `save_learning_curve(history)` call below works for this branch too.
        history = model.fit_generator(
            train_generator.generator(batch_size=BATCH_SIZE),
            steps_per_epoch=math.ceil(train_data_num / BATCH_SIZE),
            epochs=EPOCHS,
            verbose=1,
            use_multiprocessing=True,
            callbacks=callbacks)
    else:
        print('data generateing ... ')
        inputs, teachers = train_generator.generate_data()
        print('... generated')
        history = model.fit(inputs, teachers,
                            batch_size=BATCH_SIZE,
                            epochs=EPOCHS,
                            shuffle=True,
                            verbose=1,
                            callbacks=callbacks)

    print('model saveing ... ', end='', flush=True)
    model.save_weights(model_filename)
    print('... saved')

    print('learning_curve saveing ... ', end='', flush=True)
    save_learning_curve(history)
    print('... saved')
def train(gpu_num=None, with_generator=False, load_model=False, show_info=True):
    """Build an M2Det network, train it against a validation set, and save results.

    Args:
        gpu_num: When an ``int``, train the model returned by
            ``network.get_parallel_model(gpu_num, ...)``; otherwise the one
            returned by ``network.get_model(...)``.
        with_generator: Stream batches via ``fit_generator`` when true;
            otherwise materialize train/validation data and call ``fit``.
        load_model: When true, load existing weights from the model file
            before training.
        show_info: When true, plot the model to ``../model_plot.png`` and
            print its summary.

    Side effects:
        Registers TensorBoard, ``HistoryCheckpoint`` and best-only
        ``ModelCheckpoint`` callbacks, writes the final weights to
        ``os.path.join(DIR_MODEL, FILE_MODEL)``, and passes the training
        history to ``save_learning_curve``.
    """
    print('network creating ... ', end='', flush=True)
    network = M2Det(INPUT_IMAGE_SHAPE, BATCH_SIZE, class_num=CLASS_NUM)
    print('... created')

    if show_info:
        network.plot_model_summary('../model_plot.png')
        network.show_model_summary()

    model = (network.get_parallel_model(gpu_num, with_compile=True)
             if isinstance(gpu_num, int)
             else network.get_model(with_compile=True))

    weights_path = os.path.join(DIR_MODEL, FILE_MODEL)
    tensorboard_cb = KC.TensorBoard()
    curve_cb = HistoryCheckpoint(filepath='LearningCurve_{history}.png',
                                 verbose=1, period=10)
    weights_cb = KC.ModelCheckpoint(filepath=weights_path, verbose=1,
                                    save_weights_only=True,
                                    save_best_only=True, period=10)
    callbacks = [tensorboard_cb, curve_cb, weights_cb]

    if load_model:
        print('loading weghts ... ', end='', flush=True)
        model.load_weights(weights_path)
        print('... loaded')

    print('data generating ...', end='', flush=True)
    bbox_util = BBoxUtility(CLASS_NUM, network.get_prior_boxes())

    def _make_generator(input_dir, teacher_dir):
        # Both splits share the same box encoder, input shape and normalization flag.
        return DataGenerator(input_dir, teacher_dir, bbox_util,
                             INPUT_IMAGE_SHAPE, with_norm=WITH_NORM)

    train_generator = _make_generator(DIR_TRAIN_INPUTS, DIR_TRAIN_TEACHERS)
    valid_generator = _make_generator(DIR_VALID_INPUTS, DIR_VALID_TEACHERS)
    print('... created')

    shared_fit_kwargs = {'epochs': EPOCHS, 'verbose': 1, 'callbacks': callbacks}

    if not with_generator:
        print('data generateing ... ')
        train_inputs, train_teachers = train_generator.generate_data(
            batch_size=BATCH_SIZE)
        valid_data = valid_generator.generate_data(batch_size=BATCH_SIZE)
        print('... generated')
        history = model.fit(train_inputs, train_teachers,
                            batch_size=BATCH_SIZE,
                            validation_data=valid_data,
                            shuffle=True,
                            **shared_fit_kwargs)
    else:
        n_train = train_generator.data_size()
        n_valid = valid_generator.data_size()
        # NOTE(review): steps per epoch is twice the number of training batches;
        # presumably intentional (e.g. randomized/augmented batches) - confirm.
        history = model.fit_generator(
            train_generator.generator(batch_size=BATCH_SIZE),
            steps_per_epoch=math.ceil((n_train / BATCH_SIZE) * 2),
            use_multiprocessing=True,
            validation_data=valid_generator.generator(batch_size=BATCH_SIZE),
            validation_steps=math.ceil(n_valid / BATCH_SIZE),
            **shared_fit_kwargs)

    print('model saveing ... ', end='', flush=True)
    model.save_weights(weights_path)
    print('... saved')

    print('learning_curve saveing ... ', end='', flush=True)
    save_learning_curve(history)
    print('... saved')
def train(gpu_num=None, with_generator=False, load_model=False, show_info=True):
    """Build and compile a UNet, train it against a validation set, and save results.

    Args:
        gpu_num: When an ``int``, wrap the model with
            ``multi_gpu_model(model, gpus=gpu_num)`` before compiling.
        with_generator: Stream batches via ``fit_generator`` when true;
            otherwise materialize train/validation data and call ``fit``.
        load_model: When true, load existing weights from the model file
            before training.
        show_info: When true, print the Keras model summary.

    Side effects:
        Registers TensorBoard, ``HistoryCheckpoint`` and best-only
        ``ModelCheckpoint`` callbacks, writes the final weights to
        ``os.path.join(DIR_MODEL, File_MODEL)``, and passes the training
        history to ``save_learning_curve``.
    """
    print('network creating ... ')
    network = UNet(INPUT_IMAGE_SHAPE, CLASS_NUM)
    print('... created')

    model = network.model()
    if show_info:
        model.summary()

    if isinstance(gpu_num, int):
        # NOTE(review): Keras recommends saving/loading weights through the
        # template model passed to multi_gpu_model rather than the returned
        # parallel model; load_weights/save_weights/ModelCheckpoint below all
        # use the parallel model - confirm this is intended.
        model = multi_gpu_model(model, gpus=gpu_num)
    model.compile(optimizer='adam',
                  loss=DiceLossByClass(INPUT_IMAGE_SHAPE, CLASS_NUM).dice_coef_loss)

    # NOTE(review): `File_MODEL` (not `FILE_MODEL`) must match the module-level
    # constant defined elsewhere in this project - confirm.
    model_filename = os.path.join(DIR_MODEL, File_MODEL)
    callbacks = [
        KC.TensorBoard(),
        HistoryCheckpoint(filepath='LearningCurve_{history}.png',
                          verbose=1,
                          period=10),
        KC.ModelCheckpoint(filepath=model_filename,
                           verbose=1,
                           save_weights_only=True,
                           save_best_only=True,
                           period=10),
    ]

    if load_model:
        print('loading weghts ... ', end='', flush=True)
        model.load_weights(model_filename)
        print('... loaded')

    print('data generating ...', end='', flush=True)
    padding = (PADDING, PADDING)
    train_generator = DataGenerator(DIR_TRAIN_INPUTS, DIR_TRAIN_TEACHERS,
                                    INPUT_IMAGE_SHAPE, include_padding=padding)
    valid_generator = DataGenerator(DIR_VALID_INPUTS, DIR_VALID_TEACHERS,
                                    INPUT_IMAGE_SHAPE, include_padding=padding)
    print('... created')

    if with_generator:
        train_data_num = train_generator.data_size()
        # Previously undefined in this function (NameError at
        # `validation_steps`); count validation samples like training samples.
        valid_data_num = valid_generator.data_size()
        history = model.fit_generator(
            train_generator.generator(batch_size=BATCH_SIZE),
            steps_per_epoch=math.ceil(train_data_num / BATCH_SIZE),
            epochs=EPOCHS,
            verbose=1,
            use_multiprocessing=True,
            callbacks=callbacks,
            validation_data=valid_generator.generator(batch_size=BATCH_SIZE),
            validation_steps=math.ceil(valid_data_num / BATCH_SIZE))
    else:
        print('data generateing ... ')
        inputs, teachers = train_generator.generate_data()
        valid_data = valid_generator.generate_data()
        print('... generated')
        history = model.fit(inputs, teachers,
                            batch_size=BATCH_SIZE,
                            epochs=EPOCHS,
                            validation_data=valid_data,
                            shuffle=True,
                            verbose=1,
                            callbacks=callbacks)

    print('model saveing ... ', end='', flush=True)
    model.save_weights(model_filename)
    print('... saved')

    print('learning_curve saveing ... ', end='', flush=True)
    save_learning_curve(history)
    print('... saved')