Example 1
# Imports assumed from the surrounding project (test_tube-era PyTorch Lightning;
# exact module paths vary by version):
from test_tube import Experiment
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# LightningTemplateModel is the project's LightningModule template (project-local import)


def main(hparams):
    """
    Main training routine specific for this project
    :param hparams:
    :return:
    """
    # init experiment
    exp = Experiment(
        name=hparams.tt_name,
        debug=hparams.debug,
        save_dir=hparams.tt_save_path,
        version=hparams.hpc_exp_number,
        autosave=False,
        description=hparams.tt_description
    )

    exp.argparse(hparams)
    exp.save()

    # build model
    model = LightningTemplateModel(hparams)

    # callbacks
    early_stop = EarlyStopping(
        monitor='val_acc',
        patience=3,
        mode='max',  # 'val_acc' should be maximized; 'min' would stop as accuracy improves
        verbose=True,
    )

    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
    checkpoint = ModelCheckpoint(
        filepath=model_save_path,
        save_best_only=True,
        verbose=True,
        monitor='val_acc',
        mode='max'  # keep the checkpoint with the highest validation accuracy
    )

    # configure trainer
    trainer = Trainer(
        experiment=exp,
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stop,
    )

    # train model
    trainer.fit(model)
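
For context, here is a minimal sketch of how this entry point might be invoked. The flag names mirror the attributes read above (tt_name, tt_save_path, model_save_path, ...); the parser itself and its defaults are illustrative assumptions, not the project's actual CLI:

from argparse import ArgumentParser

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--tt_name', default='example_run')
    parser.add_argument('--tt_description', default='demo of the training routine')
    parser.add_argument('--tt_save_path', default='logs')
    parser.add_argument('--model_save_path', default='checkpoints')
    parser.add_argument('--hpc_exp_number', type=int, default=None)
    parser.add_argument('--debug', action='store_true')
    hparams = parser.parse_args()

    main(hparams)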
Example 2
# Imports assumed from the surrounding project (exact module paths vary by version):
import torch
from time import sleep
from test_tube import Experiment
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# TRAINING_MODEL is a project-level alias for the LightningModule class being trained


def main(hparams, cluster, results_dict):
    """
    Main training routine specific for this project
    :param hparams:
    :return:
    """
    on_gpu = torch.cuda.is_available()
    if hparams.disable_cuda:
        on_gpu = False

    device = 'cuda' if on_gpu else 'cpu'
    hparams.device = device
    hparams.on_gpu = on_gpu
    hparams.nb_gpus = torch.cuda.device_count()
    hparams.inference_mode = hparams.model_load_weights_path is not None

    # stagger each training start so concurrent jobs do not overwrite each other's logs
    process_position, current_gpu = TRAINING_MODEL.get_process_position(
        hparams.gpus)
    sleep(process_position + 1)

    # init experiment
    exp = Experiment(name=hparams.tt_name,
                     debug=hparams.debug,
                     save_dir=hparams.tt_save_path,
                     version=hparams.hpc_exp_number,
                     autosave=False,
                     description=hparams.tt_description)

    exp.argparse(hparams)
    exp.save()

    # build model
    print('building model...')
    model = TRAINING_MODEL(hparams)
    print('model built')

    # callbacks
    early_stop = EarlyStopping(monitor=hparams.early_stop_metric,
                               patience=hparams.early_stop_patience,
                               verbose=True,
                               mode=hparams.early_stop_mode)

    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name,
                                        exp.version)
    checkpoint = ModelCheckpoint(filepath=model_save_path,
                                 save_function=None,
                                 save_best_only=True,
                                 verbose=True,
                                 monitor=hparams.model_save_monitor_value,
                                 mode=hparams.model_save_monitor_mode)

    # configure trainer
    trainer = Trainer(experiment=exp,
                      on_gpu=on_gpu,
                      cluster=cluster,
                      progress_bar=hparams.enable_tqdm,
                      overfit_pct=hparams.overfit,
                      track_grad_norm=hparams.track_grad_norm,
                      fast_dev_run=hparams.fast_dev_run,
                      check_val_every_n_epoch=hparams.check_val_every_n_epoch,
                      accumulate_grad_batches=hparams.accumulate_grad_batches,
                      process_position=process_position,
                      current_gpu_name=current_gpu,
                      checkpoint_callback=checkpoint,
                      early_stop_callback=early_stop,
                      enable_early_stop=hparams.enable_early_stop,
                      max_nb_epochs=hparams.max_nb_epochs,
                      min_nb_epochs=hparams.min_nb_epochs,
                      train_percent_check=hparams.train_percent_check,
                      val_percent_check=hparams.val_percent_check,
                      test_percent_check=hparams.test_percent_check,
                      val_check_interval=hparams.val_check_interval,
                      log_save_interval=hparams.log_save_interval,
                      add_log_row_interval=hparams.add_log_row_interval,
                      lr_scheduler_milestones=hparams.lr_scheduler_milestones)

    # train model
    trainer.fit(model)
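
The (hparams, cluster, results_dict) signature matches the callback convention used by test_tube's SlurmCluster, so a launcher along these lines is a plausible caller. The flag names, trial count, and job name are illustrative assumptions:

from test_tube import HyperOptArgumentParser, SlurmCluster

if __name__ == '__main__':
    parser = HyperOptArgumentParser(strategy='random_search')
    parser.add_argument('--tt_name', default='hpc_demo')
    # ... remaining project flags (gpus, early_stop_metric, ...) registered here ...
    hparams = parser.parse_args()

    # describe the SLURM job that will run the trials
    cluster = SlurmCluster(
        hyperparam_optimizer=hparams,
        log_path='slurm_logs',
    )
    cluster.per_experiment_nb_gpus = 1

    # SlurmCluster calls main(trial_hparams, cluster, results_dict) for each trial
    cluster.optimize_parallel_cluster_gpu(
        main,
        nb_trials=4,
        job_name=hparams.tt_name,
    )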