Example #1
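All snippets below are `main()` functions (or fragments) lifted from larger Keras/Horovod training scripts, so they rely on module-level imports and project-local helpers that are not shown. A minimal sketch of the assumed imports follows; helper names such as parse_args, init_workers, load_config, config_logging, configure_session, get_datasets, get_model, get_optimizer, and TimingCallback are project-local assumptions, not library APIs:

import os
import logging

import numpy as np
import tensorflow as tf
from tensorflow import keras
import horovod.keras as hvd  # or horovod.tensorflow.keras for tf.keras models

# Project-local helpers assumed throughout (definitions not shown here):
# parse_args, init_workers, load_config, config_logging, configure_session,
# get_datasets, get_model, get_optimizer, TimingCallback, ...
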
def main():
    """Main function"""

    # Initialization
    args = parse_args()
    rank, n_ranks = init_workers(args.distributed)

    # Load configuration
    config = load_config(args.config)
    train_config = config['training']
    output_dir = os.path.expandvars(config['output_dir'])
    checkpoint_format = os.path.join(output_dir, 'checkpoints',
                                     'checkpoint-{epoch}.h5')
    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)

    # Logging
    config_logging(verbose=args.verbose)
    logging.info('Initialized rank %i out of %i', rank, n_ranks)
    if args.show_config:
        logging.info('Command line config: %s', args)
    if rank == 0:
        logging.info('Job configuration: %s', config)
        logging.info('Saving job outputs to %s', output_dir)

    # Configure session
    device_config = config.get('device', {})
    configure_session(**device_config)

    # Load the data
    train_gen, valid_gen = get_datasets(batch_size=train_config['batch_size'],
                                        **config['data'])

    # Build the model
    model = get_model(**config['model'])
    # Configure optimizer
    opt = get_optimizer(n_ranks=n_ranks,
                        dist_wrapper=hvd.DistributedOptimizer,
                        **config['optimizer'])
    # Compile the model
    model.compile(loss=train_config['loss'], optimizer=opt,
                  metrics=train_config['metrics'])
    if rank == 0:
        model.summary()

    # Prepare the training callbacks
    callbacks = get_basic_callbacks(args.distributed)

    # Learning rate warmup
    warmup_epochs = train_config.get('lr_warmup_epochs', 0)
    callbacks.append(hvd.callbacks.LearningRateWarmupCallback(
                     warmup_epochs=warmup_epochs, verbose=1))

    # Learning rate decay schedule
    for lr_schedule in train_config.get('lr_schedule', []):
        if rank == 0:
            logging.info('Adding LR schedule: %s', lr_schedule)
        callbacks.append(hvd.callbacks.LearningRateScheduleCallback(**lr_schedule))

    # Checkpoint only from rank 0
    if rank == 0:
        os.makedirs(os.path.dirname(checkpoint_format), exist_ok=True)
        callbacks.append(keras.callbacks.ModelCheckpoint(checkpoint_format))
        
    # Timing callback
    timing_callback = TimingCallback()
    callbacks.append(timing_callback)

    # Train the model
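    # Each worker runs only its 1/n_ranks share of the steps, keeping the
    # global work per epoch at roughly one pass over the data (with a floor
    # of one step per rank).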
    train_steps_per_epoch = max(len(train_gen) // n_ranks, 1)
    valid_steps_per_epoch = max(len(valid_gen) // n_ranks, 1)
    history = model.fit_generator(train_gen,
                                  epochs=train_config['n_epochs'],
                                  steps_per_epoch=train_steps_per_epoch,
                                  validation_data=valid_gen,
                                  validation_steps=valid_steps_per_epoch,
                                  callbacks=callbacks,
                                  workers=4, verbose=2 if rank == 0 else 0)

    # Save training history
    if rank == 0:
        # Print some best-found metrics
        if 'val_acc' in history.history:
            logging.info('Best validation accuracy: %.3f',
                         max(history.history['val_acc']))
        if 'val_top_k_categorical_accuracy' in history.history:
            logging.info('Best top-5 validation accuracy: %.3f',
                         max(history.history['val_top_k_categorical_accuracy']))
        logging.info('Average time per epoch: %.3f s',
                     np.mean(timing_callback.times))
        np.savez(os.path.join(output_dir, 'history'),
                 n_ranks=n_ranks, **history.history)

    # Drop to IPython interactive shell
    if args.interactive and (rank == 0):
        logging.info('Starting IPython interactive session')
        import IPython
        IPython.embed()

    if rank == 0:
        logging.info('All done!')
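
TimingCallback is project-local; Example #1 reads its times attribute, so a minimal sketch consistent with that usage (the real class may differ) is:

import time

class TimingCallback(keras.callbacks.Callback):
    """Record wall-clock seconds per epoch in self.times."""

    def __init__(self):
        super().__init__()
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self._epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self._epoch_start)
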
Example #2
def main():
    """Main function"""

    # Initialization
    args = parse_args()
    rank, local_rank, n_ranks = init_workers(args.distributed)

    # Load configuration
    config = load_config(args.config)

    # Configure logging
    config_logging(verbose=args.verbose)
    logging.info('Initialized rank %i local_rank %i size %i',
                 rank, local_rank, n_ranks)

    # Device configuration
    configure_session(gpu=local_rank, **config.get('device', {}))

    # Load the data
    train_data, valid_data = get_datasets(rank=rank, n_ranks=n_ranks,
                                          **config['data'])
    if rank == 0:
        logging.info(train_data)
        logging.info(valid_data)

    # Construct the model and optimizer
    model = get_model(**config['model'])
    optimizer = get_optimizer(n_ranks=n_ranks, **config['optimizer'])
    train_config = config['train']

    # Custom metrics for pixel accuracy and IoU
    metrics = [PixelAccuracy(), PixelIoU(name='iou', num_classes=3)]

    # Compile the model
    model.compile(loss=train_config['loss'], optimizer=optimizer,
                  metrics=metrics)

    # Print a model summary
    if rank == 0:
        model.summary()

    # Prepare the callbacks
    callbacks = []

    if args.distributed:

        # Broadcast initial variable states from rank 0 to all processes.
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))

        # Average metrics across workers
        callbacks.append(hvd.callbacks.MetricAverageCallback())

        # Learning rate warmup
        warmup_epochs = train_config.get('lr_warmup_epochs', 0)
        callbacks.append(hvd.callbacks.LearningRateWarmupCallback(
            warmup_epochs=warmup_epochs, verbose=1))

    # Timing
    timing_callback = TimingCallback()
    callbacks.append(timing_callback)

    # Checkpointing and CSV logging from rank 0 only
    #if rank == 0:
    #    callbacks.append(tf.keras.callbacks.ModelCheckpoint(checkpoint_format))
    #    callbacks.append(tf.keras.callbacks.CSVLogger(
    #        os.path.join(config['output_dir'], 'history.csv'), append=args.resume))

    if rank == 0:
        logging.debug('Callbacks: %s', callbacks)

    # Train the model
    verbosity = 2 if rank == 0 or args.verbose else 0
    history = model.fit(train_data,
                        validation_data=valid_data,
                        epochs=train_config['n_epochs'],
                        callbacks=callbacks,
                        verbose=verbosity)

    # All done
    if rank == 0:
        logging.info('All done!')
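
PixelAccuracy and PixelIoU are project-local metrics, not part of tf.keras. A minimal sketch of a stateful pixel-accuracy metric, assuming sparse integer labels shaped like the argmax of the per-pixel logits, could look like:

class PixelAccuracy(tf.keras.metrics.Metric):
    """Fraction of pixels whose argmax prediction matches the label."""

    def __init__(self, name='pixel_accuracy', **kwargs):
        super().__init__(name=name, **kwargs)
        self.correct = self.add_weight(name='correct', initializer='zeros')
        self.total = self.add_weight(name='total', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        preds = tf.argmax(y_pred, axis=-1)
        labels = tf.cast(y_true, preds.dtype)  # assumes sparse integer labels
        matches = tf.cast(tf.equal(labels, preds), tf.float32)
        self.correct.assign_add(tf.reduce_sum(matches))
        self.total.assign_add(tf.cast(tf.size(matches), tf.float32))

    def result(self):
        return self.correct / self.total
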
Example #3
def main():
    """Main function"""

    # Initialization
    args = parse_args()
    dist = init_workers(args.distributed)
    config = load_config(args)
    os.makedirs(config['output_dir'], exist_ok=True)
    config_logging(verbose=args.verbose)
    logging.info('Initialized rank %i size %i local_rank %i local_size %i',
                 dist.rank, dist.size, dist.local_rank, dist.local_size)
    if dist.rank == 0:
        logging.info('Configuration: %s', config)

    # Setup MLPerf logging
    if args.mlperf:
        mllogger = configure_mllogger(config['output_dir'])
    if dist.rank == 0 and args.mlperf:
        mllogger.event(key=mllog.constants.CACHE_CLEAR)
        mllogger.start(key=mllog.constants.INIT_START)

    # Initialize Weights & Biases logging
    if args.wandb and dist.rank == 0:
        import wandb
        wandb.init(project='cosmoflow',
                   name=args.run_tag,
                   id=args.run_tag,
                   config=config,
                   resume=args.run_tag)

    # Device and session configuration
    gpu = dist.local_rank if args.rank_gpu else None
    if gpu is not None:
        logging.info('Taking gpu %i', gpu)
    configure_session(gpu=gpu,
                      intra_threads=args.intra_threads,
                      inter_threads=args.inter_threads,
                      kmp_blocktime=args.kmp_blocktime,
                      kmp_affinity=args.kmp_affinity,
                      omp_num_threads=args.omp_num_threads)

    # Mixed precision
    if args.amp:
        logging.info('Enabling mixed float16 precision')

        # Suggested bug workaround from https://github.com/tensorflow/tensorflow/issues/38516
        if tf.__version__.startswith('2.2.'):
            from tensorflow.python.keras.mixed_precision.experimental import device_compatibility_check
            device_compatibility_check.log_device_compatibility_check = lambda policy_name, skip_local: None
        tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
        # TF 2.3
        #tf.keras.mixed_precision.set_global_policy('mixed_float16')

    # Start MLPerf logging
    if dist.rank == 0 and args.mlperf:
        log_submission_info(**config.get('mlperf', {}))
        mllogger.end(key=mllog.constants.INIT_STOP)
        mllogger.start(key=mllog.constants.RUN_START)

    # Load the data
    data_config = config['data']
    if dist.rank == 0:
        logging.info('Loading data')
    datasets = get_datasets(dist=dist, **data_config)
    logging.debug('Datasets: %s', datasets)

    # Construct or reload the model
    if dist.rank == 0:
        logging.info('Building the model')
    train_config = config['train']
    initial_epoch = 0
    checkpoint_format = os.path.join(config['output_dir'],
                                     'checkpoint-{epoch:03d}.h5')
    if args.resume and os.path.exists(checkpoint_format.format(epoch=1)):
        # Reload model from last checkpoint
        initial_epoch, model = reload_last_checkpoint(
            checkpoint_format,
            data_config['n_epochs'],
            distributed=args.distributed)
    else:
        # Build a new model
        model = get_model(**config['model'])
        # Configure the optimizer
        opt = get_optimizer(distributed=args.distributed,
                            **config['optimizer'])
        # Compile the model
        model.compile(optimizer=opt,
                      loss=train_config['loss'],
                      metrics=train_config['metrics'])

    if dist.rank == 0:
        model.summary()

    # Save configuration to output directory
    if dist.rank == 0:
        config['n_ranks'] = dist.size
        save_config(config)

    # Prepare the callbacks
    if dist.rank == 0:
        logging.info('Preparing callbacks')
    callbacks = []
    if args.distributed:

        # Broadcast initial variable states from rank 0 to all processes.
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))

        # Average metrics across workers
        callbacks.append(hvd.callbacks.MetricAverageCallback())

    # Learning rate decay schedule
    if 'lr_schedule' in config:
        global_batch_size = data_config['batch_size'] * dist.size
        callbacks.append(
            tf.keras.callbacks.LearningRateScheduler(
                get_lr_schedule(global_batch_size=global_batch_size,
                                **config['lr_schedule'])))

    # Timing
    timing_callback = TimingCallback()
    callbacks.append(timing_callback)

    # Checkpointing and logging from rank 0 only
    if dist.rank == 0:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(checkpoint_format))
        callbacks.append(
            tf.keras.callbacks.CSVLogger(os.path.join(config['output_dir'],
                                                      'history.csv'),
                                         append=args.resume))
        if args.tensorboard:
            callbacks.append(
                tf.keras.callbacks.TensorBoard(
                    os.path.join(config['output_dir'], 'tensorboard')))
        if args.mlperf:
            callbacks.append(MLPerfLoggingCallback())
        if args.wandb:
            callbacks.append(wandb.keras.WandbCallback())

    # Early stopping
    patience = train_config.get('early_stopping_patience', None)
    if patience is not None:
        callbacks.append(
            tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                             min_delta=1e-5,
                                             patience=patience,
                                             verbose=1))

    # Stopping at specified target
    target_mae = train_config.get('target_mae', None)
    callbacks.append(StopAtTargetCallback(target_max=target_mae))

    if dist.rank == 0:
        logging.debug('Callbacks: %s', callbacks)

    # Train the model
    if dist.rank == 0:
        logging.info('Beginning training')
    fit_verbose = 1 if (args.verbose and dist.rank == 0) else 2
    model.fit(datasets['train_dataset'],
              steps_per_epoch=datasets['n_train_steps'],
              epochs=data_config['n_epochs'],
              validation_data=datasets['valid_dataset'],
              validation_steps=datasets['n_valid_steps'],
              callbacks=callbacks,
              initial_epoch=initial_epoch,
              verbose=fit_verbose)

    # Stop MLPerf timer
    if dist.rank == 0 and args.mlperf:
        mllogger.end(key=mllog.constants.RUN_STOP,
                     metadata={'status': 'success'})

    # Print training summary
    if dist.rank == 0:
        print_training_summary(config['output_dir'], args.print_fom)

    # Print GPU memory - not supported in TF 2.2?
    #if gpu is not None:
    #    device = tf.config.list_physical_devices('GPU')[gpu]
    #    #print(tf.config.experimental.get_memory_usage(device))
    #    #print(tf.config.experimental.get_memory_info(device))

    # Finalize
    if dist.rank == 0:
        logging.info('All done!')
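
StopAtTargetCallback is also project-local; given the target_max=target_mae call above, a plausible sketch (the monitor name and exact semantics are assumptions) is a callback that halts training once the validation error reaches the target:

class StopAtTargetCallback(tf.keras.callbacks.Callback):
    """Stop training when the monitored metric drops to a target value."""

    def __init__(self, target_max=None, monitor='val_mean_absolute_error'):
        super().__init__()
        self.target_max = target_max
        self.monitor = monitor

    def on_epoch_end(self, epoch, logs=None):
        if self.target_max is None:
            return
        value = (logs or {}).get(self.monitor)
        if value is not None and value <= self.target_max:
            logging.info('Reached target %s=%g, stopping', self.monitor, value)
            self.model.stop_training = True
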
Example #4
# Configure optimizer
# opt = get_optimizer(n_ranks=n_ranks, distributed=False,
#                     **config['optimizer'])

# Compile the model
model.compile(loss=train_config['loss'],
              # Passing the optimizer by name uses Keras defaults; the
              # get_optimizer call above is commented out.
              optimizer=config['optimizer']['name'],
              metrics=train_config['metrics'])
train_gen, valid_gen = get_datasets(batch_size=train_config['batch_size'],
                                    **config['data_and_model'],
                                    **config['data'])

steps_per_epoch = len(train_gen) // n_ranks

# Timing
callbacks = []
timing_callback = TimingCallback()
callbacks.append(timing_callback)
callbacks.append(keras.callbacks.EarlyStopping(monitor='val_loss', patience=5))

callbacks.append(keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(output_dir, output_file_name),
    monitor='val_mean_absolute_error',
    save_best_only=True,
    verbose=1))

history = model.fit_generator(train_gen,
                              epochs=train_config['n_epochs'],
                              steps_per_epoch=steps_per_epoch,
                              validation_data=valid_gen,
                              validation_steps=len(valid_gen),
                              callbacks=callbacks,
                              workers=4, verbose=1)
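
get_optimizer is another project helper. A hedged sketch consistent with the calls in these examples (build a Keras optimizer by name, scale the learning rate with the worker count, and wrap it for Horovod when distributed; the real signature and scaling rule may differ) is:

def get_optimizer(name, lr, n_ranks=1, distributed=False,
                  dist_wrapper=None, **opt_args):
    # Linear LR scaling with worker count is the usual Horovod recipe
    opt = getattr(keras.optimizers, name)(lr=lr * n_ranks, **opt_args)
    if distributed or dist_wrapper is not None:
        opt = (dist_wrapper or hvd.DistributedOptimizer)(opt)
    return opt
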
Example #5
def main():
    """Main function"""

    # Initialization
    args = parse_args()
    rank, local_rank, n_ranks = init_workers(args.distributed)
    config = load_config(args.config,
                         output_dir=args.output_dir,
                         data_config=args.data_config)

    os.makedirs(config['output_dir'], exist_ok=True)
    config_logging(verbose=args.verbose)
    logging.info('Initialized rank %i local_rank %i size %i', rank, local_rank,
                 n_ranks)
    if rank == 0:
        logging.info('Configuration: %s', config)

    # Device and session configuration
    gpu = local_rank if args.rank_gpu else None
    configure_session(gpu=gpu, **config.get('device', {}))

    # Load the data
    data_config = config['data']
    if rank == 0:
        logging.info('Loading data')
    datasets = get_datasets(rank=rank, n_ranks=n_ranks, **data_config)
    logging.debug('Datasets: %s', datasets)

    # Construct or reload the model
    if rank == 0:
        logging.info('Building the model')
    initial_epoch = 0
    checkpoint_format = os.path.join(config['output_dir'],
                                     'checkpoint-{epoch:03d}.h5')
    if args.resume:
        # Reload model from last checkpoint
        initial_epoch, model = reload_last_checkpoint(
            checkpoint_format,
            data_config['n_epochs'],
            distributed=args.distributed)
    else:
        # Build a new model
        model = get_model(**config['model'])
        # Configure the optimizer
        opt = get_optimizer(n_ranks=n_ranks,
                            distributed=args.distributed,
                            **config['optimizer'])
        # Compile the model
        train_config = config['train']
        model.compile(optimizer=opt,
                      loss=train_config['loss'],
                      metrics=train_config['metrics'])

    if rank == 0:
        model.summary()

    # Save configuration to output directory
    if rank == 0:
        data_config['n_train'] = datasets['n_train']
        data_config['n_valid'] = datasets['n_valid']
        save_config(config)

    # Prepare the callbacks
    if rank == 0:
        logging.info('Preparing callbacks')
    callbacks = []
    if args.distributed:

        # Broadcast initial variable states from rank 0 to all processes.
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))

        # Average metrics across workers
        callbacks.append(hvd.callbacks.MetricAverageCallback())

        # Learning rate warmup
        train_config = config['train']
        warmup_epochs = train_config.get('lr_warmup_epochs', 0)
        callbacks.append(
            hvd.callbacks.LearningRateWarmupCallback(
                warmup_epochs=warmup_epochs, verbose=1))

    # Learning rate decay schedule
    lr_schedule = train_config.get('lr_schedule', {})
    if rank == 0:
        logging.info('Adding LR decay schedule: %s', lr_schedule)
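    # Here lr_schedule maps epoch number -> multiplicative decay factor;
    # epochs not listed keep the current learning rate (factor defaults to 1).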
    callbacks.append(
        tf.keras.callbacks.LearningRateScheduler(
            schedule=lambda epoch, lr: lr * lr_schedule.get(epoch, 1)))

    # Timing
    timing_callback = TimingCallback()
    callbacks.append(timing_callback)

    # Checkpointing and CSV logging from rank 0 only
    if rank == 0:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(checkpoint_format))
        callbacks.append(
            tf.keras.callbacks.CSVLogger(os.path.join(config['output_dir'],
                                                      'history.csv'),
                                         append=args.resume))

    if rank == 0:
        logging.debug('Callbacks: %s', callbacks)

    # Train the model
    if rank == 0:
        logging.info('Beginning training')
    fit_verbose = 1 if (args.verbose and rank == 0) else 2
    model.fit(datasets['train_dataset'],
              steps_per_epoch=datasets['n_train_steps'],
              epochs=data_config['n_epochs'],
              validation_data=datasets['valid_dataset'],
              validation_steps=datasets['n_valid_steps'],
              callbacks=callbacks,
              initial_epoch=initial_epoch,
              verbose=fit_verbose)

    # Print training summary
    if rank == 0:
        print_training_summary(config['output_dir'])

    # Finalize
    if rank == 0:
        logging.info('All done!')
Example #6
def main():
    """Main function"""

    # Initialization
    args = parse_args()
    dist = init_workers(args.distributed)
    config = load_config(args)
    os.makedirs(config['output_dir'], exist_ok=True)
    config_logging(verbose=args.verbose)
    logging.info('Initialized rank %i size %i local_rank %i local_size %i',
                 dist.rank, dist.size, dist.local_rank, dist.local_size)
    if dist.rank == 0:
        logging.info('Configuration: %s', config)

    # Device and session configuration
    gpu = dist.local_rank if args.rank_gpu else None
    if gpu is not None:
        logging.info('Taking gpu %i', gpu)
    configure_session(gpu=gpu,
                      intra_threads=args.intra_threads,
                      inter_threads=args.inter_threads,
                      kmp_blocktime=args.kmp_blocktime,
                      kmp_affinity=args.kmp_affinity,
                      omp_num_threads=args.omp_num_threads)

    # Load the data
    data_config = config['data']
    if dist.rank == 0:
        logging.info('Loading data')
    datasets = get_datasets(dist=dist, **data_config)
    logging.debug('Datasets: %s', datasets)

    # Construct or reload the model
    if dist.rank == 0:
        logging.info('Building the model')
    train_config = config['train']
    initial_epoch = 0
    checkpoint_format = os.path.join(config['output_dir'],
                                     'checkpoint-{epoch:03d}.h5')
    if args.resume and os.path.exists(checkpoint_format.format(epoch=1)):
        # Reload model from last checkpoint
        initial_epoch, model = reload_last_checkpoint(
            checkpoint_format,
            data_config['n_epochs'],
            distributed=args.distributed)
    else:
        # Build a new model
        model = get_model(**config['model'])
        # Configure the optimizer
        opt = get_optimizer(distributed=args.distributed,
                            **config['optimizer'])
        # Compile the model
        model.compile(optimizer=opt,
                      loss=train_config['loss'],
                      metrics=train_config['metrics'])

    if dist.rank == 0:
        model.summary()

    # Save configuration to output directory
    if dist.rank == 0:
        config['n_ranks'] = dist.size
        save_config(config)

    # Prepare the callbacks
    if dist.rank == 0:
        logging.info('Preparing callbacks')
    callbacks = []
    if args.distributed:

        # Broadcast initial variable states from rank 0 to all processes.
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))

        # Average metrics across workers
        callbacks.append(hvd.callbacks.MetricAverageCallback())

    # Learning rate decay schedule
    if 'lr_schedule' in config:
        global_batch_size = data_config['batch_size'] * dist.size
        callbacks.append(
            tf.keras.callbacks.LearningRateScheduler(
                get_lr_schedule(global_batch_size=global_batch_size,
                                **config['lr_schedule'])))

    # Timing
    timing_callback = TimingCallback()
    callbacks.append(timing_callback)

    # Checkpointing and logging from rank 0 only
    if dist.rank == 0:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(checkpoint_format))
        callbacks.append(
            tf.keras.callbacks.CSVLogger(os.path.join(config['output_dir'],
                                                      'history.csv'),
                                         append=args.resume))
        if args.tensorboard:
            callbacks.append(
                tf.keras.callbacks.TensorBoard(
                    os.path.join(config['output_dir'], 'tensorboard')))

    # Early stopping
    patience = config.get('early_stopping_patience', None)
    if patience is not None:
        callbacks.append(
            tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                             min_delta=1e-5,
                                             patience=patience,
                                             verbose=1))

    if dist.rank == 0:
        logging.debug('Callbacks: %s', callbacks)

    # Train the model
    if dist.rank == 0:
        logging.info('Beginning training')
    fit_verbose = 1 if (args.verbose and dist.rank == 0) else 2
    model.fit(datasets['train_dataset'],
              steps_per_epoch=datasets['n_train_steps'],
              epochs=data_config['n_epochs'],
              validation_data=datasets['valid_dataset'],
              validation_steps=datasets['n_valid_steps'],
              callbacks=callbacks,
              initial_epoch=initial_epoch,
              verbose=fit_verbose)

    # Print training summary
    if dist.rank == 0:
        print_training_summary(config['output_dir'], args.print_fom)

    # Finalize
    if dist.rank == 0:
        logging.info('All done!')
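
get_lr_schedule is not shown either; from its use with LearningRateScheduler and a global batch size, a plausible sketch (parameter names and the linear-scaling rule are assumptions) is:

def get_lr_schedule(base_lr, global_batch_size, base_batch_size=32,
                    decay_schedule=None):
    # Scale the base LR linearly with the global batch size, then apply
    # stepwise decay factors keyed by epoch number.
    scaled_lr = base_lr * global_batch_size / base_batch_size
    decay_schedule = decay_schedule or {}

    def schedule(epoch):
        lr = scaled_lr
        for decay_epoch, factor in sorted(decay_schedule.items()):
            if epoch >= decay_epoch:
                lr = scaled_lr * factor
        return lr

    return schedule
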
Example #7
def main():
    """Main function"""
    # Initialization
    args = parse_args()
    rank, n_ranks = init_workers(args.distributed)

    # Load configuration
    config = load_config(args.config)
    train_config = config['training']
    output_dir = os.path.expandvars(config['output_dir'])
    checkpoint_format = os.path.join(output_dir, 'checkpoints',
                                     'checkpoint-{epoch}.h5')
    os.makedirs(output_dir, exist_ok=True)

    # Logging
    config_logging(verbose=args.verbose, output_dir=output_dir)
    logging.info('Initialized rank %i out of %i', rank, n_ranks)
    if args.show_config:
        logging.info('Command line config: %s', args)
    if rank == 0:
        logging.info('Job configuration: %s', config)
        logging.info('Saving job outputs to %s', output_dir)

    # Configure session
    if args.distributed:
        gpu = hvd.local_rank()
    else:
        gpu = args.gpu
    device_config = config.get('device', {})
    configure_session(gpu=gpu, **device_config)

    # Load the data
    train_gen, valid_gen = get_datasets(batch_size=train_config['batch_size'],
                                        **config['data_and_model'],
                                        **config['data'])

    # Build the model
    # if (type(config['data']['n_components']) is int):
    #     rho_length_in = config['data']['n_components']
    # else:
    rho_length_in = config['model']['rho_length_out']

    model = get_model(rho_length_in=rho_length_in, 
                      **config['data_and_model'],
                      **config['model'])
    # Configure optimizer
    opt = get_optimizer(n_ranks=n_ranks, distributed=args.distributed,
                        **config['optimizer'])
    # Compile the model
    model.compile(loss=train_config['loss'], optimizer=opt,
                  metrics=train_config['metrics'])
    if rank == 0:
        model.summary()

    # Prepare the training callbacks
    callbacks = []
    if args.distributed:

        # Broadcast initial variable states from rank 0 to all processes.
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))

        # # Learning rate warmup
        # warmup_epochs = train_config.get('lr_warmup_epochs', 0)
        # callbacks.append(hvd.callbacks.LearningRateWarmupCallback(
        #     warmup_epochs=warmup_epochs, verbose=1))

        # # Learning rate decay schedule
        # for lr_schedule in train_config.get('lr_schedule', []):
        #     if rank == 0:
        #         logging.info('Adding LR schedule: %s', lr_schedule)
        #     callbacks.append(hvd.callbacks.LearningRateScheduleCallback(**lr_schedule))

    # Checkpoint only from rank 0
    if rank == 0:
        #os.makedirs(os.path.dirname(checkpoint_format), exist_ok=True)
        #callbacks.append(keras.callbacks.ModelCheckpoint(checkpoint_format))
        #callbacks.append(keras.callbacks.EarlyStopping(monitor='val_loss',
        #                                           patience=5))
        callbacks.append(keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(output_dir, 'model.h5'),
            monitor='val_mean_absolute_error',
            save_best_only=False,
            verbose=2))

    # Timing
    timing_callback = TimingCallback()
    callbacks.append(timing_callback)

    # Train the model
    steps_per_epoch = len(train_gen) // n_ranks

    history = model.fit_generator(train_gen,
                                  epochs=train_config['n_epochs'],
                                  steps_per_epoch=steps_per_epoch,
                                  validation_data=valid_gen,
                                  validation_steps=len(valid_gen),
                                  callbacks=callbacks,
                                  workers=4, verbose=1)

    # Save training history
    if rank == 0:
        # Print some best-found metrics
        if 'val_acc' in history.history:
            logging.info('Best validation accuracy: %.3f',
                         max(history.history['val_acc']))
        if 'val_top_k_categorical_accuracy' in history.history:
            logging.info('Best top-5 validation accuracy: %.3f',
                         max(history.history['val_top_k_categorical_accuracy']))
        if 'val_mean_absolute_error' in history.history:
            logging.info('Best validation MAE: %.3f',
                         min(history.history['val_mean_absolute_error']))

        logging.info('Average time per epoch: %.3f s',
                     np.mean(timing_callback.times))
        np.savez(os.path.join(output_dir, 'history'),
                 n_ranks=n_ranks, **history.history)

    # Drop to IPython interactive shell
    if args.interactive and (rank == 0):
        logging.info('Starting IPython interactive session')
        import IPython
        IPython.embed()

    if rank == 0:
        logging.info('All done!')
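
Assuming parse_args exposes a --distributed flag (as args.distributed suggests), these scripts would typically be launched with one process per GPU via Horovod's runner; for example (script name and remaining arguments assumed):

horovodrun -np 8 python train.py --distributed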