def train_and_evaluate(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader,
                       optimizer: optim, args) -> None:
    """Run ``args.num_epochs`` epochs of training, evaluating on ``test_loader`` after each.

    For every epoch this trains, evaluates, saves a checkpoint into
    ``args.model_dir`` (flagged ``is_best`` when the test R2 is <= the best so
    far), writes the best/last metric JSON files and redraws the R2 and loss
    curves into ``args.plot_dir``.

    NOTE(review): "best" here means the *lowest* R2 (tracker starts at +inf and
    uses ``<=``). If evaluate() reports the coefficient of determination, higher
    is better and this keeps the worst epoch -- confirm against evaluate().
    NOTE(review): the R2 curve is plotted with the label ``args.dataset + '_ND'``.
    """
    logger.info('begin training and evaluation')
    best_r2 = float('inf')
    batches_per_epoch = len(train_loader)
    loss_history = np.zeros((batches_per_epoch * args.num_epochs))
    r2_history = np.zeros(args.num_epochs)

    for ep in range(args.num_epochs):
        logger.info('Epoch {}/{}'.format(ep + 1, args.num_epochs))
        # Slice of loss_history owned by this epoch (one entry per training batch).
        first_batch = ep * batches_per_epoch
        past_last_batch = (ep + 1) * batches_per_epoch
        loss_history[first_batch:past_last_batch] = train(
            model, optimizer, train_loader, test_loader, args, ep)

        metrics = evaluate(model, test_loader, args, ep)
        r2_history[ep] = metrics['R2']
        improved = r2_history[ep] <= best_r2

        # Save weights
        checkpoint_state = {
            'epoch': ep + 1,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        }
        utils.save_checkpoint(checkpoint_state, epoch=ep, is_best=improved,
                              checkpoint=args.model_dir)

        if improved:
            logger.info('- Found new best R2')
            best_r2 = r2_history[ep]
            best_metrics_file = os.path.join(args.model_dir, 'metrics_test_best_weights.json')
            utils.save_dict_to_json(metrics, best_metrics_file)

        logger.info('Current Best R2 is: %.5f' % best_r2)

        utils.plot_all_epoch(r2_history[:ep + 1], args.dataset + '_ND', args.plot_dir)
        utils.plot_all_epoch(loss_history[:past_last_batch], args.dataset + '_loss', args.plot_dir)

        last_metrics_file = os.path.join(args.model_dir, 'metrics_test_last_weights.json')
        utils.save_dict_to_json(metrics, last_metrics_file)
def train_and_evaluate(model: nn.Module,
                       train_loader: DataLoader,
                       test_loader: DataLoader,
                       optimizer: optim,
                       loss_fn,
                       params: utils.Params,
                       restore_file: str = None) -> None:
    '''Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the Deep AR model
        train_loader: load train data and labels
        test_loader: load test data and labels
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes outputs and labels per timestep, and then computes the loss for the batch
        params: (Params) hyperparameters
        restore_file: (string) optional- name of file to restore from (without its extension .pth.tar)

    An epoch counts as "best" when its test ND is <= the lowest ND seen so far.
    Besides its arguments, this function reads the module-level ``args``
    (``sampling``, ``dataset``, ``save_best``, ``search_params``).
    '''
    # reload weights from restore_file if specified
    # NOTE: this only reloads model/optimizer state; the epoch loop below still starts at 0.
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir, restore_file + '.pth.tar')
        logger.info('Restoring parameters from {}'.format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    logger.info('begin training and evaluation')
    best_test_ND = float('inf')
    train_len = len(train_loader)
    ND_summary = np.zeros(params.num_epochs)
    loss_summary = np.zeros((train_len * params.num_epochs))
    for epoch in range(params.num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, params.num_epochs))
        # One loss value per training batch is stored for this epoch's slice.
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = train(
            model, optimizer, loss_fn, train_loader, test_loader, params, epoch)
        test_metrics = evaluate(model, loss_fn, test_loader, params, epoch, sample=args.sampling)
        ND_summary[epoch] = test_metrics['ND']
        is_best = ND_summary[epoch] <= best_test_ND

        # Save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                              epoch=epoch,
                              is_best=is_best,
                              checkpoint=params.model_dir)

        if is_best:
            logger.info('- Found new best ND')
            best_test_ND = ND_summary[epoch]
            best_json_path = os.path.join(params.model_dir, 'metrics_test_best_weights.json')
            utils.save_dict_to_json(test_metrics, best_json_path)

        logger.info('Current Best ND is: %.5f' % best_test_ND)

        utils.plot_all_epoch(ND_summary[:epoch + 1], args.dataset + '_ND', params.plot_dir)
        utils.plot_all_epoch(loss_summary[:(epoch + 1) * train_len], args.dataset + '_loss', params.plot_dir)

        last_json_path = os.path.join(params.model_dir, 'metrics_test_last_weights.json')
        utils.save_dict_to_json(test_metrics, last_json_path)

    if args.save_best:
        list_of_params = args.search_params.split(',')
        # Join with a separator. The previous "concatenate then [:-1]" chopped the
        # last digit off the final hyperparameter value.
        print_params = '_'.join(
            f'{param}: {getattr(params, param):.2f}' for param in list_of_params)
        # `with` guarantees the file is closed even if a write fails.
        with open('./param_search.txt', 'w') as f:
            f.write('-----------\n')
            f.write(print_params + '\n')
            f.write('Best ND: ' + str(best_test_ND) + '\n')
        logger.info(print_params)
        logger.info(f'Best ND: {best_test_ND}')
        utils.plot_all_epoch(ND_summary, print_params + '_ND', location=params.plot_dir)
        utils.plot_all_epoch(loss_summary, print_params + '_loss', location=params.plot_dir)
def _epoch_from_checkpoint_name(restore_file: str) -> int:
    """Return the epoch number encoded as trailing digits of a checkpoint name.

    e.g. ``'epoch_7'`` -> 7, ``'epoch_123'`` -> 123.

    Raises:
        ValueError: if ``restore_file`` does not end in digits (e.g. ``'best'``).
    """
    import re

    match = re.search(r'(\d+)$', restore_file)
    if match is None:
        raise ValueError(
            f'cannot infer the epoch from restore_file {restore_file!r}; '
            'expected a name ending in an epoch number such as "epoch_12"')
    return int(match.group(1))


def train_and_evaluate(model: nn.Module,
                       train_loader: DataLoader,
                       test_loader: DataLoader,
                       optimizer: optim,
                       loss_fn,
                       params: utils.Params,
                       restore_file: str = None) -> None:
    '''Train the model and evaluate every epoch, with resume support and early stopping.

    Args:
        model: (torch.nn.Module) the Deep AR model
        train_loader: load train data and labels
        test_loader: load test data and labels
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes outputs and labels per timestep, and then computes the loss for the batch
        params: (Params) hyperparameters
        restore_file: (string) optional- name of file to restore from (without its extension .pth.tar).
            Its trailing digits are read as the checkpoint's epoch, and training
            resumes at the following epoch.

    An epoch counts as "best" when its test ND is <= the lowest ND seen so far.
    Besides its arguments, this function reads the module-level ``args``
    (``sampling``, ``dataset``, ``save_best``, ``search_params``).
    '''
    # reload weights from restore_file if specified
    restore_epoch = 0
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir, restore_file + '.pth.tar')
        logger.info('Restoring parameters from {}'.format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)
        # Parse every trailing digit. Reading only the last two characters
        # broke resumption from epoch 100 onwards.
        restore_epoch = _epoch_from_checkpoint_name(restore_file) + 1
        logger.info('Restoring epoch: {}'.format(restore_epoch))

    logger.info('Begin training and evaluation')

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=25, verbose=True, delta=0.0001, folder=params.model_dir)
    best_json_path = os.path.join(params.model_dir, 'metrics_test_best_weights.json')
    # Seed the best ND from an existing best-metrics file in model_dir (even when
    # not restoring), so an epoch is only "best" if it beats that earlier result.
    if os.path.exists(best_json_path):
        with open(best_json_path) as json_file:
            best_test_ND = json.load(json_file)['ND']
    else:
        best_test_ND = float('inf')
    early_stopping.best_score = best_test_ND

    train_len = len(train_loader)
    ND_summary = np.zeros(params.num_epochs)
    loss_summary = np.zeros((train_len * params.num_epochs))
    # One past the last epoch actually run (the loop may stop early).
    epochs_done = restore_epoch
    for epoch in range(restore_epoch, params.num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, params.num_epochs))
        # One loss value per training batch is stored for this epoch's slice.
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = train(
            model, optimizer, loss_fn, train_loader, test_loader, params, epoch)
        test_metrics = evaluate(model, loss_fn, test_loader, params, epoch, sample=args.sampling)
        ND_summary[epoch] = test_metrics['ND']
        epochs_done = epoch + 1
        # A NaN ND compares False here, so a NaN epoch is never marked best.
        is_best = ND_summary[epoch] <= best_test_ND

        # Save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                              epoch=epoch,
                              is_best=is_best,
                              checkpoint=params.model_dir)

        if is_best:
            logger.info('- Found new best ND')
            best_test_ND = ND_summary[epoch]
            utils.save_dict_to_json(test_metrics, best_json_path)

        logger.info('Current Best ND is: %.5f' % best_test_ND)

        utils.plot_all_epoch(ND_summary[:epoch + 1], args.dataset + '_ND', params.plot_dir)
        utils.plot_all_epoch(loss_summary[:(epoch + 1) * train_len], args.dataset + '_loss', params.plot_dir)

        last_json_path = os.path.join(params.model_dir, 'metrics_test_last_weights.json')
        utils.save_dict_to_json(test_metrics, last_json_path)

        # early_stopping needs the validation metric to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        logger.info('ND : %.5f ' % test_metrics['ND'])
        early_stopping(test_metrics['ND'], model)
        if early_stopping.early_stop:
            logger.info('Early stopping')
            break

    if args.save_best:
        list_of_params = args.search_params.split(',')
        # Join with a separator. The previous "concatenate then [:-1]" chopped the
        # last digit off the final hyperparameter value.
        print_params = '_'.join(
            f'{param}: {getattr(params, param):.2f}' for param in list_of_params)
        # `with` guarantees the file is closed even if a write fails.
        with open('./param_search.txt', 'w') as f:
            f.write('-----------\n')
            f.write(print_params + '\n')
            f.write('Best ND: ' + str(best_test_ND) + '\n')
        logger.info(print_params)
        logger.info(f'Best ND: {best_test_ND}')
        # Drop the unused all-zero tail left behind when early stopping fires.
        utils.plot_all_epoch(ND_summary[:epochs_done], print_params + '_ND', location=params.plot_dir)
        utils.plot_all_epoch(loss_summary[:epochs_done * train_len], print_params + '_loss',
                             location=params.plot_dir)