import os
import warnings
from glob import glob

import numpy as np
import keras
from keras import callbacks
from keras.preprocessing.image import ImageDataGenerator as IDG

# TrainingConfig, OutputProgress, utils, top3_acc, create_xml_description,
# save_results, and SEED are project-local names defined elsewhere in the repo.


def train(model: keras.models.Model, config: TrainingConfig):
    """Train `model` with the generators and callbacks described by `config`."""
    training_generator, validation_generator = config.get_generators()

    # Assemble the callback list from the feature flags in the config.
    callback_list = []
    if config.use_tensorboard:
        print("using tensorboard")
        tb_callback = callbacks.TensorBoard(log_dir=config.tensorboard_log_dir,
                                            write_graph=False,
                                            update_freq=5000)
        callback_list.append(tb_callback)

    if config.reduce_lr_on_plateau:
        print("reducing learning rate on plateau")
        lr_callback = callbacks.ReduceLROnPlateau(
            factor=config.reduce_lr_on_plateau_factor,
            patience=config.reduce_lr_on_plateau_patience,
            cooldown=config.reduce_lr_on_plateau_cooldown,
            min_delta=config.reduce_lr_on_plateau_delta)
        callback_list.append(lr_callback)

    if config.save_colored_image_progress:
        print("saving progression every {} epochs".format(
            config.image_progression_period))
        op_callback = OutputProgress(config.image_paths_to_save,
                                     config.dim_in,
                                     config.image_progression_log_dir,
                                     every_n_epochs=config.image_progression_period)
        callback_list.append(op_callback)

    if config.periodically_save_model:
        print("saving model every {} epochs".format(
            config.periodically_save_model_period))
        p_save_callback = callbacks.ModelCheckpoint(
            config.periodically_save_model_path,
            period=config.periodically_save_model_period)
        callback_list.append(p_save_callback)

    if config.save_best_model:
        print("saving best model")
        best_save_callback = callbacks.ModelCheckpoint(config.save_best_model_path,
                                                       save_best_only=True)
        callback_list.append(best_save_callback)

    model.fit_generator(generator=training_generator,
                        validation_data=validation_generator,
                        use_multiprocessing=True,
                        workers=config.n_workers,
                        max_queue_size=config.queue_size,
                        verbose=1,
                        epochs=config.n_epochs,
                        callbacks=callback_list)
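# Example usage (a minimal sketch, not part of the original module; it assumes
# TrainingConfig exposes the attributes accessed above as plain settable fields):
#
#     config = TrainingConfig()
#     config.use_tensorboard = True
#     config.tensorboard_log_dir = "logs/"
#     config.save_best_model = True
#     config.save_best_model_path = "best_model.h5"
#     train(model, config)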
def train(model: keras.models.Model, optimizer: dict, save_path: str,
          train_dir: str, valid_dir: str, batch_size: int = 32,
          epochs: int = 10, samples_per_epoch=1000, pretrained=None,
          augment: bool = True, weight_mode=None, verbose=0, **kwargs):
    """Trains the model with the given configuration."""
    shape = model.input_shape[1:3]
    optimizer_cpy = optimizer.copy()

    shared_gen_args = {
        'rescale': 1. / 255,  # scale RGB values into [0, 1]
    }
    train_gen_args = {}
    if augment:
        train_gen_args = {
            'fill_mode': 'reflect',
            'horizontal_flip': True,
            'vertical_flip': True,
            'width_shift_range': .15,
            'height_shift_range': .15,
            'shear_range': .5,
            'rotation_range': 45,
            'zoom_range': .2,
        }

    gen = IDG(**{**shared_gen_args, **train_gen_args})
    gen = gen.flow_from_directory(train_dir, target_size=shape,
                                  batch_size=batch_size, seed=SEED)

    val_count = len(glob(os.path.join(valid_dir, '**', '*.jpg'), recursive=True))
    valid_gen = IDG(**shared_gen_args)

    # Resolve the optimizer class by name; every remaining key in the dict is
    # forwarded to its constructor. Only SGD understands `nesterov`.
    optim = getattr(keras.optimizers, optimizer['name'])
    if optimizer.pop('name') != 'sgd':
        optimizer.pop('nesterov', None)

    schedule = optimizer.pop('schedule')
    if schedule == 'decay' and 'lr' in optimizer:
        initial_lr = optimizer.pop('lr')
    else:
        initial_lr = 0.01
    optim = optim(**optimizer)

    callbacks = [
        utils.checkpoint(save_path),
        utils.csv_logger(save_path),
    ]

    if pretrained is not None:
        if not os.path.exists(pretrained):
            raise FileNotFoundError(pretrained)
        model.load_weights(pretrained, by_name=False)
        if verbose == 1:
            print("Loaded weights from {}".format(pretrained))

    # Learning-rate schedules are only wired up for SGD.
    if optimizer_cpy['name'] == 'sgd':
        if schedule == 'decay':
            callbacks.append(utils.step_decay(epochs, initial_lr=initial_lr))
        elif schedule == 'big_drop':
            callbacks.append(utils.constant_schedule())

    model.compile(optim, loss='categorical_crossentropy',
                  metrics=['accuracy', top3_acc])

    create_xml_description(save=os.path.join(save_path, 'model_config.xml'),
                           title=model.name, epochs=epochs,
                           batch_size=batch_size,
                           samples_per_epoch=samples_per_epoch,
                           augmentations=augment, schedule=schedule,
                           optimizer=optimizer_cpy, **kwargs)

    if weight_mode:
        class_weights = [[key, value] for key, value in weight_mode.items()]
        filen = os.path.join(save_path, 'class_weights.npy')
        np.save(filen, class_weights)

    h = None  # has to be initialized here, so we can reference it after the try block
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            h = model.fit_generator(
                gen,
                steps_per_epoch=samples_per_epoch // batch_size,
                epochs=epochs,
                validation_data=valid_gen.flow_from_directory(
                    valid_dir, target_size=shape,
                    batch_size=batch_size, seed=SEED),
                validation_steps=val_count // batch_size,
                callbacks=callbacks,
                class_weight=weight_mode,
                verbose=2)
    except KeyboardInterrupt:
        # Persist whatever was trained before the interrupt.
        save_results(verbose=1, save_path=save_path, model=model, hist=h)
        return
    save_results(verbose=1, save_path=save_path, model=model, hist=h)
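# Example usage (a minimal sketch; the optimizer dict layout mirrors the keys
# the function pops above: 'name', 'schedule', 'nesterov', and optionally 'lr'.
# All concrete values and paths below are hypothetical):
#
#     opt = {'name': 'sgd', 'lr': 0.01, 'momentum': 0.9,
#            'nesterov': True, 'schedule': 'decay'}
#     train(model, optimizer=opt, save_path='runs/exp1',
#           train_dir='data/train', valid_dir='data/valid',
#           batch_size=32, epochs=50, samples_per_epoch=8000)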