# NOTE: the imports below assume the early pytorch-lightning + test-tube APIs
# this template was written against; exact module paths vary across versions,
# so adjust them to the version you have installed.
import torch
from time import sleep
from test_tube import Experiment
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from lightning_module_template import LightningTemplateModel  # project-local module (path assumed)


def main(hparams):
    """
    Main training routine specific for this project.

    :param hparams: namespace of hyperparameters parsed from the command line
    :return:
    """
    # init experiment
    exp = Experiment(
        name=hparams.tt_name,
        debug=hparams.debug,
        save_dir=hparams.tt_save_path,
        version=hparams.hpc_exp_number,
        autosave=False,
        description=hparams.tt_description
    )
    exp.argparse(hparams)
    exp.save()

    # build model
    model = LightningTemplateModel(hparams)

    # callbacks
    # val_acc should be maximized; the original used mode='min', which would
    # stop training (and keep checkpoints) when accuracy is at its worst.
    early_stop = EarlyStopping(
        monitor='val_acc',
        patience=3,
        mode='max',
        verbose=True,
    )

    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
    checkpoint = ModelCheckpoint(
        filepath=model_save_path,
        save_best_only=True,
        verbose=True,
        monitor='val_acc',
        mode='max'
    )

    # configure trainer
    trainer = Trainer(
        experiment=exp,
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stop,
    )

    # train model
    trainer.fit(model)
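
# Hedged usage sketch (an assumption, not part of the original template): one
# way to drive the single-process main() above from the command line. The flag
# names mirror the hparams attributes read inside main(); the defaults are
# placeholders, and the real template typically also merges in model-specific
# flags (e.g. via a method like LightningTemplateModel.add_model_specific_args).
def run_single_process():
    from test_tube import HyperOptArgumentParser

    parser = HyperOptArgumentParser(strategy='grid_search')
    parser.add_argument('--tt_name', default='lightning_demo')
    parser.add_argument('--tt_description', default='template run')
    parser.add_argument('--tt_save_path', default='logs')
    parser.add_argument('--model_save_path', default='weights')
    parser.add_argument('--hpc_exp_number', type=int, default=None)
    parser.add_argument('--debug', action='store_true')
    hparams = parser.parse_args()

    main(hparams)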
# TRAINING_MODEL is referenced but never defined in the original snippet; the
# alias below is an assumption about how it was set elsewhere in the project.
TRAINING_MODEL = LightningTemplateModel


def main(hparams, cluster, results_dict):
    """
    Main training routine specific for this project.

    :param hparams: namespace of hyperparameters for this trial
    :param cluster: cluster manager passed in by test-tube
    :param results_dict: dict used to report trial results
    :return:
    """
    on_gpu = torch.cuda.is_available()
    if hparams.disable_cuda:
        on_gpu = False

    device = 'cuda' if on_gpu else 'cpu'
    hparams.device = device
    hparams.on_gpu = on_gpu
    hparams.nb_gpus = torch.cuda.device_count()
    hparams.inference_mode = hparams.model_load_weights_path is not None

    # delay each training start so parallel trials do not overwrite each other's logs
    process_position, current_gpu = TRAINING_MODEL.get_process_position(hparams.gpus)
    sleep(process_position + 1)

    # init experiment
    exp = Experiment(
        name=hparams.tt_name,
        debug=hparams.debug,
        save_dir=hparams.tt_save_path,
        version=hparams.hpc_exp_number,
        autosave=False,
        description=hparams.tt_description
    )
    exp.argparse(hparams)
    exp.save()

    # build model
    print('loading model...')
    model = TRAINING_MODEL(hparams)
    print('model built')

    # callbacks
    early_stop = EarlyStopping(
        monitor=hparams.early_stop_metric,
        patience=hparams.early_stop_patience,
        verbose=True,
        mode=hparams.early_stop_mode
    )

    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
    checkpoint = ModelCheckpoint(
        filepath=model_save_path,
        save_function=None,
        save_best_only=True,
        verbose=True,
        monitor=hparams.model_save_monitor_value,
        mode=hparams.model_save_monitor_mode
    )

    # configure trainer
    trainer = Trainer(
        experiment=exp,
        on_gpu=on_gpu,
        cluster=cluster,
        progress_bar=hparams.enable_tqdm,
        overfit_pct=hparams.overfit,
        track_grad_norm=hparams.track_grad_norm,
        fast_dev_run=hparams.fast_dev_run,
        check_val_every_n_epoch=hparams.check_val_every_n_epoch,
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        process_position=process_position,
        current_gpu_name=current_gpu,
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stop,
        enable_early_stop=hparams.enable_early_stop,
        max_nb_epochs=hparams.max_nb_epochs,
        min_nb_epochs=hparams.min_nb_epochs,
        train_percent_check=hparams.train_percent_check,
        val_percent_check=hparams.val_percent_check,
        test_percent_check=hparams.test_percent_check,
        val_check_interval=hparams.val_check_interval,
        log_save_interval=hparams.log_save_interval,
        add_log_row_interval=hparams.add_log_row_interval,
        lr_scheduler_milestones=hparams.lr_scheduler_milestones
    )

    # train model
    trainer.fit(model)
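
# Hedged usage sketch (an assumption, not part of the original snippet): launch
# the cluster variant of main() above with test-tube's SlurmCluster. Only a few
# of the flags read inside main() are shown; the remaining hparams attributes
# (early_stop_metric, max_nb_epochs, ...) must be added the same way. The
# three-argument main(hparams, cluster, results_dict) signature is assumed to
# match what the test-tube version this template targets passes to each trial.
def run_on_cluster():
    from test_tube import HyperOptArgumentParser, SlurmCluster

    parser = HyperOptArgumentParser(strategy='grid_search')
    parser.add_argument('--tt_name', default='lightning_cluster_demo')
    parser.add_argument('--tt_save_path', default='logs')
    parser.add_argument('--gpus', default='0')
    parser.add_argument('--disable_cuda', action='store_true')
    # ... plus the remaining flags consumed by main() ...
    hparams = parser.parse_args()

    cluster = SlurmCluster(
        hyperparam_optimizer=hparams,
        log_path=hparams.tt_save_path,
    )
    cluster.per_experiment_nb_gpus = 1  # placeholder resource request

    # submits one SLURM job per hyperparameter trial
    cluster.optimize_parallel_cluster_gpu(
        main,
        nb_trials=4,  # placeholder
        job_name=hparams.tt_name,
    )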