Example #1
0
# Snapshot this run's configuration next to the checkpoints so the exact
# hyper-parameters can be recovered later; the filename carries the epoch
# the run was (re)started from.
config_path = save_dir / ("config_%d.json" % config.initial_epoch)
with open(str(config_path), "w") as fp:
    json.dump(vars(config), fp, sort_keys=True, indent=4)

# Set up training callbacks
# Checkpoint the model each epoch — save_best_only is left at its default,
# so presumably every epoch is written; verify against the Keras version used.
# NOTE(review): newer Keras spells this metric "val_accuracy" — confirm the
# installed version still reports "val_acc".
checkpoint = ModelCheckpoint(filepath=str(file_path),
                             monitor="val_acc",
                             verbose=1)
# TensorBoard logging every 2500 samples; histograms, gradients and image
# summaries are all disabled.
tensor_board = TensorBoard(log_dir=save_dir,
                           histogram_freq=0,
                           batch_size=config.batch_size,
                           write_graph=True,
                           write_grads=False,
                           write_images=False,
                           update_freq=2500)
# When resuming from epoch `initial_epoch`, fast-forward TensorBoard's
# internal sample counters so the x-axis continues instead of restarting
# at zero.
tensor_board.samples_seen = config.initial_epoch * len(x_train)
tensor_board.samples_seen_at_last_write = config.initial_epoch * len(x_train)
# Shrink the learning rate by `lr_reduce_factor` (down to `lr_reduce_min`)
# after `lr_reduce_patience` epochs without val_loss improvement.
lr_reducer = ReduceLROnPlateau(monitor="val_loss",
                               factor=config.lr_reduce_factor,
                               cooldown=0,
                               patience=config.lr_reduce_patience,
                               min_lr=config.lr_reduce_min,
                               verbose=1)
# Abort training after `early_stop_patience` epochs without val_loss
# improvement.
early_stopping = EarlyStopping(monitor="val_loss",
                               patience=config.early_stop_patience,
                               verbose=1)

# Checkpointing and TensorBoard are always active; LR reduction and early
# stopping are opt-in via the corresponding config flags.
callbacks = [checkpoint, tensor_board]
if (config.lr_reduce):
    callbacks.append(lr_reducer)
# NOTE(review): the body of this branch is cut off in this chunk —
# presumably `callbacks.append(early_stopping)`; confirm against the
# full source.
if (config.early_stop):
Example #2
0
# Fixed batch size for this run.
batch_size = 256

# Continue counting epochs from where the loaded checkpoint left off.
epochs = epoch + args.epochs

# Training data, rotated past the batches already consumed in earlier
# epochs so a resumed run does not replay the exact same positions.
train = PositionSequence(
    training_database,
    batch_size=batch_size,
    rotate=epoch * config["batches_per_epoch"],
)
val = PositionSequence(validation_database, batch_size=batch_size)
# One additional evaluation sequence per named auxiliary database.
extra: Dict[str, PositionSequence] = {
    name: PositionSequence(database, batch_size=batch_size)
    for name, database in extra_databases.items()
}

# TensorBoard writer flushing four times per epoch's worth of samples.
tensor_board = TensorBoard(
    tensorboard_dir,
    update_freq=batch_size * config["batches_per_epoch"] / 4,
)
# Fast-forward the sample counters so a resumed run's curves continue
# from the checkpointed position instead of restarting at zero.
samples_already_seen = epoch * config["batches_per_epoch"] * batch_size
tensor_board.samples_seen = samples_already_seen
tensor_board.samples_seen_at_last_write = samples_already_seen

callbacks = [
    LambdaCallback(on_epoch_end=save_model),
    LambdaCallback(on_epoch_end=validate_extra_datasets),
    tensor_board,
]

# Launch (or resume) training. The resume offset is already baked into
# `epochs` (= epoch + args.epochs) rather than passed separately here.
# NOTE(review): this call is truncated in this chunk — remaining keyword
# arguments and the closing paren are outside the visible range.
network.training_model.fit_generator(
    train,
    epochs=epochs,
    # Interactive terminal gets the live progress bar (1); otherwise one
    # line per epoch (2) to keep captured logs compact.
    verbose=2 if not is_atty else 1,
    callbacks=callbacks,
    validation_data=val,
    shuffle=False,
    # NOTE(review): workers=1 with use_multiprocessing=True — presumably
    # chosen to keep the rotated sequence order stable; confirm intent.
    workers=1,
    use_multiprocessing=True,