Example #1
import keras
from keras import callbacks
# TrainingConfig and OutputProgress are project-specific helpers defined elsewhere.

def train(model: keras.models.Model, config: TrainingConfig):
    training_generator, validation_generator = config.get_generators()

    callback_list = []

    if config.use_tensorboard:
        print("using tensorboard")
        tb_callback = callbacks.TensorBoard(log_dir=config.tensorboard_log_dir,
                                            write_graph=False,
                                            update_freq=5000
                                            )
        callback_list.append(tb_callback)

    if config.reduce_lr_on_plateau:
        print("reducing learning rate on plateau")
        lr_callback = callbacks.ReduceLROnPlateau(
            factor=config.reduce_lr_on_plateau_factor,
            patience=config.reduce_lr_on_plateau_patience,
            cooldown=config.reduce_lr_on_plateau_cooldown,
            min_delta=config.reduce_lr_on_plateau_delta
        )
        callback_list.append(lr_callback)

    if config.save_colored_image_progress:
        print("saving progression every {} epochs".format(config.image_progression_period))
        op_callback = OutputProgress(config.image_paths_to_save,
                                     config.dim_in,
                                     config.image_progression_log_dir,
                                     every_n_epochs=config.image_progression_period)
        callback_list.append(op_callback)

    if config.periodically_save_model:
        print("saving model every {} epcohs".format(config.periodically_save_model_period))
        p_save_callback = callbacks.ModelCheckpoint(config.periodically_save_model_path,
                                                    period=config.periodically_save_model_period)
        callback_list.append(p_save_callback)

    if config.save_best_model:
        print("saving best model")
        best_save_callback = callbacks.ModelCheckpoint(config.save_best_model_path,
                                                       save_best_only=True)
        callback_list.append(best_save_callback)

    model.fit_generator(generator=training_generator,
                        validation_data=validation_generator,
                        use_multiprocessing=True,
                        workers=config.n_workers,
                        max_queue_size=config.queue_size,
                        verbose=1,
                        epochs=config.n_epochs,
                        callbacks=callback_list)
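
A minimal driver sketch for the function above. TrainingConfig is project-specific; its constructor, its get_generators() method, and every field set below are assumptions inferred from the attribute reads inside train(), not confirmed by the original project:

# Hypothetical usage, not part of the original project.
config = TrainingConfig()  # assumed default-constructible
config.use_tensorboard = False
config.reduce_lr_on_plateau = True
config.reduce_lr_on_plateau_factor = 0.5
config.reduce_lr_on_plateau_patience = 3
config.reduce_lr_on_plateau_cooldown = 0
config.reduce_lr_on_plateau_delta = 1e-4
config.save_colored_image_progress = False
config.periodically_save_model = False
config.save_best_model = True
config.save_best_model_path = "best_model.h5"
config.n_workers = 4
config.queue_size = 10
config.n_epochs = 20

model = keras.models.load_model("colorizer.h5")  # any compiled Keras model
train(model, config)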
Example #2
import os
import warnings
from glob import glob

import keras
import numpy as np
from keras.preprocessing.image import ImageDataGenerator as IDG  # alias inferred from usage
# utils, top3_acc, create_xml_description, save_results, and SEED are
# project-specific helpers assumed to be importable from the surrounding code.

def train(model: keras.models.Model,
          optimizer: dict,
          save_path: str,
          train_dir: str,
          valid_dir: str,
          batch_size: int = 32,
          epochs: int = 10,
          samples_per_epoch=1000,
          pretrained=None,
          augment: bool = True,
          weight_mode=None,
          verbose=0,
          **kwargs):
    """ Trains the model with the given configurations. """
    shape = model.input_shape[1:3]
    optimizer_cpy = optimizer.copy()
    shared_gen_args = {
        'rescale': 1. / 255,  # normalize pixel values to [0, 1]
    }
    train_gen_args = {}
    if augment:
        train_gen_args = {
            "fill_mode": 'reflect',
            'horizontal_flip': True,
            'vertical_flip': True,
            'width_shift_range': .15,
            'height_shift_range': .15,
            'shear_range': .5,
            'rotation_range': 45,
            'zoom_range': .2,
        }
    gen = IDG(**{**shared_gen_args, **train_gen_args})
    gen = gen.flow_from_directory(train_dir,
                                  target_size=shape,
                                  batch_size=batch_size,
                                  seed=SEED)

    val_count = len(
        glob(os.path.join(valid_dir, '**', '*.jpg'), recursive=True))
    valid_gen = IDG(**shared_gen_args)

    # optimizer is expected to carry 'name', 'schedule', and the keyword
    # arguments of the chosen Keras optimizer class.
    optim = getattr(keras.optimizers, optimizer['name'])
    if optimizer.pop('name') != 'sgd':
        # 'nesterov' only applies to SGD; drop it harmlessly if absent.
        optimizer.pop('nesterov', None)
    schedule = optimizer.pop('schedule')
    if schedule == 'decay' and 'lr' in optimizer:
        initial_lr = optimizer.pop('lr')
    else:
        initial_lr = 0.01
    optim = optim(**optimizer)

    callbacks = [
        utils.checkpoint(save_path),
        utils.csv_logger(save_path),
    ]

    if pretrained is not None:
        if not os.path.exists(pretrained):
            raise FileNotFoundError(pretrained)

        model.load_weights(pretrained, by_name=False)
        if verbose == 1:
            print("Loaded weights from {}".format(pretrained))

    if optimizer_cpy['name'] == 'sgd':
        if schedule == 'decay':
            callbacks.append(utils.step_decay(epochs, initial_lr=initial_lr))
        elif schedule == 'big_drop':
            callbacks.append(utils.constant_schedule())

    model.compile(optim,
                  loss='categorical_crossentropy',
                  metrics=['accuracy', top3_acc])

    create_xml_description(save=os.path.join(save_path, 'model_config.xml'),
                           title=model.name,
                           epochs=epochs,
                           batch_size=batch_size,
                           samples_per_epoch=samples_per_epoch,
                           augmentations=augment,
                           schedule=schedule,
                           optimizer=optimizer_cpy,
                           **kwargs)

    if weight_mode:
        class_weights = [[key, value] for key, value in weight_mode.items()]
        filen = os.path.join(save_path, 'class_weights.npy')
        np.save(filen, class_weights)

    h = None  # has to be initialized here, so we can reference it later
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            h = model.fit_generator(
                gen,
                steps_per_epoch=samples_per_epoch // batch_size,  # fit_generator expects an integer step count
                epochs=epochs,
                validation_data=valid_gen.flow_from_directory(
                    valid_dir,
                    target_size=shape,
                    batch_size=batch_size,
                    seed=SEED),
                validation_steps=(val_count + batch_size - 1) // batch_size,  # ceil, so every validation image is seen
                callbacks=callbacks,
                class_weight=weight_mode,
                verbose=2)
    except KeyboardInterrupt:
        save_results(verbose=1, save_path=save_path, model=model, hist=h)
        return

    save_results(verbose=1, save_path=save_path, model=model, hist=h)
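
A minimal driver sketch for this second variant. The optimizer dict layout is inferred from how train() pops its keys; the model, paths, and hyperparameters below are assumptions, not values from the original project:

# Hypothetical usage; all paths and values below are illustrative.
model = keras.applications.MobileNet(weights=None, classes=5)

optimizer = {
    'name': 'sgd',        # resolved via getattr(keras.optimizers, name)
    'lr': 0.01,           # becomes initial_lr when schedule == 'decay'
    'momentum': 0.9,
    'nesterov': True,     # stripped for non-SGD optimizers
    'schedule': 'decay',  # 'decay' or 'big_drop' selects an LR callback
}

train(model,
      optimizer=optimizer,
      save_path='runs/mobilenet',
      train_dir='data/train',  # one subdirectory per class
      valid_dir='data/valid',
      batch_size=32,
      epochs=10,
      samples_per_epoch=1000,
      verbose=1)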