def main(args):
    dataset_name = args.dataset
    model_name = args.model
    n_inner_iter = args.adaptation_steps
    batch_size = args.batch_size
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    is_test = args.is_test
    stopping_patience = args.stopping_patience
    epochs = args.epochs
    fast_lr = args.learning_rate
    slow_lr = args.meta_learning_rate
    noise_level = args.noise_level
    noise_type = args.noise_type
    resume = args.resume

    first_order = False
    inner_loop_grad_clip = 20
    task_size = 50
    output_dim = 1
    checkpoint_freq = 10
    horizon = 10

    # per-dataset (window_size, task_size, input_dim)
    meta_info = {"POLLUTION": [5, 50, 14],
                 "HR": [32, 50, 13],
                 "BATTERY": [20, 50, 3]}

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, task_size, input_dim = meta_info[dataset_name]
    grid = [0.0, noise_level]
    output_directory = "output/"

    train_data_ML = pickle.load(
        open("../../Data/TRAIN-" + dataset_name + "-W" + str(window_size)
             + "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open("../../Data/VAL-" + dataset_name + "-W" + str(window_size)
             + "-T" + str(task_size) + "-ML.pickle", "rb"))
    test_data_ML = pickle.load(
        open("../../Data/TEST-" + dataset_name + "-W" + str(window_size)
             + "-T" + str(task_size) + "-ML.pickle", "rb"))

    for trial in range(lower_trial, upper_trial):
        output_directory = ("../../Models/" + dataset_name + "_" + model_name
                            + "_MAML/" + str(trial) + "/")
        save_model_file_ = output_directory + save_model_file
        save_model_file_encoder = output_directory + "encoder_" + save_model_file
        load_model_file_ = output_directory + load_model_file
        checkpoint_file = output_directory + "checkpoint_" + save_model_file.split(".")[0]

        try:
            os.mkdir(output_directory)
        except OSError as error:
            print(error)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Learning rate: %f \n" % fast_lr)
            f.write("Meta-learning rate: %f \n" % slow_lr)
            f.write("Adaptation steps: %d \n" % n_inner_iter)
            f.write("Noise level: %f \n" % noise_level)

        if model_name == "LSTM":
            model = LSTMModel(batch_size=batch_size, seq_len=window_size,
                              input_dim=input_dim, n_layers=2, hidden_dim=120,
                              output_dim=output_dim)
        model2 = LinearModel(120, 1)
        optimizer = torch.optim.Adam(list(model.parameters()) + list(model2.parameters()),
                                     lr=slow_lr)
        loss_func = mae

        initial_epoch = 0
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        meta_learner = MetaLearner(model2, optimizer, fast_lr, loss_func, first_order,
                                   n_inner_iter, inner_loop_grad_clip, device)
        model.to(device)

        early_stopping = EarlyStopping(patience=stopping_patience,
                                       model_file=save_model_file_encoder, verbose=True)
        early_stopping2 = EarlyStopping(patience=stopping_patience,
                                        model_file=save_model_file_, verbose=True)

        if resume:
            # restore model, meta-learner, and early-stopping state from the checkpoint
            checkpoint = torch.load(checkpoint_file)
            model.load_state_dict(checkpoint["model"])
            meta_learner.load_state_dict(checkpoint["meta_learner"])
            initial_epoch = checkpoint["epoch"]
            early_stopping.best_score = checkpoint["best_score"]
            early_stopping2.best_score = checkpoint["best_score"]
            early_stopping.counter = checkpoint["counter_stopping"]
            early_stopping2.counter = checkpoint["counter_stopping"]

        total_tasks, task_size, window_size, input_dim = train_data_ML.x.shape

        for epoch in range(initial_epoch, epochs):
            model.zero_grad()
            meta_learner._model.zero_grad()

            # train: sample a batch of support tasks; the query set of each task
            # is the chronologically next task (index + 1)
            batch_idx = np.random.randint(0, total_tasks - 1, batch_size)
            x_spt, y_spt = train_data_ML[batch_idx]
            x_qry, y_qry = train_data_ML[batch_idx + 1]
            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry, y_qry = to_torch(x_qry), to_torch(y_qry)

            # data augmentation: perturb the targets with additive or multiplicative noise
            epsilon = grid[np.random.randint(0, len(grid))]
            if noise_type == "additive":
                y_spt = y_spt + epsilon
                y_qry = y_qry + epsilon
            else:
                y_spt = y_spt * (1 + epsilon)
                y_qry = y_qry * (1 + epsilon)

            train_tasks = [Task(model.encoder(x_spt[i]), y_spt[i])
                           for i in range(x_spt.shape[0])]
            val_tasks = [Task(model.encoder(x_qry[i]), y_qry[i])
                         for i in range(x_qry.shape[0])]

            # inner-loop adaptation on the support tasks, outer-loop update on the query tasks
            adapted_params = meta_learner.adapt(train_tasks)
            mean_loss = meta_learner.step(adapted_params, val_tasks, is_training=True)

            # evaluate: noisy validation tasks, clean test tasks
            val_error = test(validation_data_ML, meta_learner, model, device, noise_level)
            test_error = test(test_data_ML, meta_learner, model, device, 0.0)

            print("Epoch:", epoch)
            print("Val error:", val_error)
            print("Test error:", test_error)

            early_stopping(val_error, model)
            early_stopping2(val_error, meta_learner)

            # checkpointing (fixed: the original tested `epochs % checkpoint_freq`,
            # which either saved every epoch or never)
            if epoch % checkpoint_freq == 0:
                torch.save({"epoch": epoch,
                            "model": model.state_dict(),
                            "meta_learner": meta_learner.state_dict(),
                            "best_score": early_stopping2.best_score,
                            "counter_stopping": early_stopping2.counter},
                           checkpoint_file)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        # reload the best encoder and head, then evaluate with the usual number of
        # adaptation steps at the full horizon and at horizon 1
        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(torch.load(save_model_file_)["model_state_dict"])
        meta_learner = MetaLearner(model2, optimizer, fast_lr, loss_func, first_order,
                                   n_inner_iter, inner_loop_grad_clip, device)
        validation_error = test(validation_data_ML, meta_learner, model, device, noise_level=0.0)
        test_error = test(test_data_ML, meta_learner, model, device, noise_level=0.0)
        validation_error_h1 = test(validation_data_ML, meta_learner, model, device,
                                   noise_level=0.0, horizon=1)
        test_error_h1 = test(test_data_ML, meta_learner, model, device,
                             noise_level=0.0, horizon=1)

        # horizon-1 evaluation with zero adaptation steps
        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(torch.load(save_model_file_)["model_state_dict"])
        meta_learner2 = MetaLearner(model2, optimizer, fast_lr, loss_func, first_order,
                                    0, inner_loop_grad_clip, device)
        validation_error_h0 = test(validation_data_ML, meta_learner2, model, device,
                                   noise_level=0.0, horizon=1)
        test_error_h0 = test(test_data_ML, meta_learner2, model, device,
                             noise_level=0.0, horizon=1)

        # full-horizon re-evaluation with the full number of adaptation steps
        # (reported below as the MAE errors)
        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(torch.load(save_model_file_)["model_state_dict"])
        meta_learner2 = MetaLearner(model2, optimizer, fast_lr, loss_func, first_order,
                                    n_inner_iter, inner_loop_grad_clip, device)
        validation_error_mae = test(validation_data_ML, meta_learner2, model, device, 0.0)
        test_error_mae = test(test_data_ML, meta_learner2, model, device, 0.0)
        print("test_error_mae", test_error_mae)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Test error: %f \n" % test_error)
            f.write("Validation error: %f \n" % validation_error)
            f.write("Test error h1: %f \n" % test_error_h1)
            f.write("Validation error h1: %f \n" % validation_error_h1)
            f.write("Test error h0: %f \n" % test_error_h0)
            f.write("Validation error h0: %f \n" % validation_error_h0)
            f.write("Test error mae: %f \n" % test_error_mae)
            f.write("Validation error mae: %f \n" % validation_error_mae)

        print(test_error)
        print(validation_error)
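# --- Hedged usage sketch ------------------------------------------------------
# A minimal command-line entry point for main(). The attribute names mirror the
# args.* fields read at the top of main(); the flags, types, and defaults below
# are assumptions for illustration, not the project's original CLI.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="MAML training for time-series forecasting")
    parser.add_argument("--dataset", choices=["POLLUTION", "HR", "BATTERY"], required=True)
    parser.add_argument("--model", choices=["FCN", "LSTM"], default="LSTM")
    parser.add_argument("--adaptation_steps", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--save_model_file", default="model.pt")
    parser.add_argument("--load_model_file", default="model.pt")
    parser.add_argument("--lower_trial", type=int, default=0)
    parser.add_argument("--upper_trial", type=int, default=3)
    parser.add_argument("--is_test", action="store_true")
    parser.add_argument("--stopping_patience", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--meta_learning_rate", type=float, default=1e-3)
    parser.add_argument("--noise_level", type=float, default=0.0)
    parser.add_argument("--noise_type", choices=["additive", "multiplicative"], default="additive")
    parser.add_argument("--resume", action="store_true")
    main(parser.parse_args())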
def train_and_evaluate(model: nn.Module,
                       train_loader: DataLoader,
                       test_loader: DataLoader,
                       optimizer: optim.Optimizer,
                       loss_fn,
                       params: utils.Params,
                       restore_file: str = None) -> None:
    '''Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the DeepAR model
        train_loader: loads train data and labels
        test_loader: loads test data and labels
        optimizer: (torch.optim) optimizer for the model parameters
        loss_fn: a function that takes outputs and labels per timestep,
            and then computes the loss for the batch
        params: (Params) hyperparameters
        restore_file: (string) optional: name of the file to restore from
            (without its .pth.tar extension)
    '''
    # reload weights from restore_file if specified
    restore_epoch = 0
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir, restore_file + '.pth.tar')
        logger.info('Restoring parameters from {}'.format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)
        restore_epoch = int(restore_file[-2:].replace('_', '')) + 1
        logger.info('Restoring epoch: {}'.format(restore_epoch))

    logger.info('Begin training and evaluation')

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=25, verbose=True, delta=0.0001,
                                   folder=params.model_dir)

    # seed the best score from a previous run if its metrics file exists
    if os.path.exists(os.path.join(params.model_dir, 'metrics_test_best_weights.json')):
        with open(os.path.join(params.model_dir, 'metrics_test_best_weights.json')) as json_file:
            best_test_ND = json.load(json_file)['ND']
    else:
        best_test_ND = float('inf')
    early_stopping.best_score = best_test_ND

    train_len = len(train_loader)
    ND_summary = np.zeros(params.num_epochs)
    loss_summary = np.zeros(train_len * params.num_epochs)

    for epoch in range(restore_epoch, params.num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, params.num_epochs))
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = train(
            model, optimizer, loss_fn, train_loader, test_loader, params, epoch)
        test_metrics = evaluate(model, loss_fn, test_loader, params, epoch,
                                sample=args.sampling)
        ND_summary[epoch] = test_metrics['ND']
        is_best = ND_summary[epoch] <= best_test_ND

        # save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                              epoch=epoch,
                              is_best=is_best,
                              checkpoint=params.model_dir)

        if is_best:
            logger.info('- Found new best ND')
            best_test_ND = ND_summary[epoch]
            best_json_path = os.path.join(params.model_dir,
                                          'metrics_test_best_weights.json')
            utils.save_dict_to_json(test_metrics, best_json_path)

        logger.info('Current best ND is: %.5f' % best_test_ND)

        utils.plot_all_epoch(ND_summary[:epoch + 1], args.dataset + '_ND', params.plot_dir)
        utils.plot_all_epoch(loss_summary[:(epoch + 1) * train_len],
                             args.dataset + '_loss', params.plot_dir)

        last_json_path = os.path.join(params.model_dir, 'metrics_test_last_weights.json')
        utils.save_dict_to_json(test_metrics, last_json_path)

        # early_stopping needs the validation metric to check whether it has
        # decreased, and if it has, it makes a checkpoint of the current model
        logger.info('ND: %.5f' % test_metrics['ND'])
        early_stopping(test_metrics['ND'], model)
        if early_stopping.early_stop:
            logger.info('Early stopping')
            break

    if args.save_best:
        # log the searched hyperparameters together with the best ND
        with open('./param_search.txt', 'w') as f:
            f.write('-----------\n')
            list_of_params = args.search_params.split(',')
            print_params = ''
            for param in list_of_params:
                param_value = getattr(params, param)
                print_params += f'{param}: {param_value:.2f} '
            print_params = print_params.rstrip()
            f.write(print_params + '\n')
            f.write('Best ND: ' + str(best_test_ND) + '\n')
            logger.info(print_params)
            logger.info(f'Best ND: {best_test_ND}')
        utils.plot_all_epoch(ND_summary, print_params + '_ND', location=params.plot_dir)
        utils.plot_all_epoch(loss_summary, print_params + '_loss', location=params.plot_dir)
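# --- Hedged interface sketch --------------------------------------------------
# Both training loops above depend on an EarlyStopping helper that is called as
# early_stopping(metric, model) and exposes best_score, counter, and early_stop.
# The class below is one plausible implementation of that interface, assuming a
# lower-is-better metric and a `folder` checkpoint directory (the first script
# passes model_file= instead); it is a sketch, not the project's actual class.
import os

import torch


class EarlyStopping:
    def __init__(self, patience=25, verbose=False, delta=0.0, folder='.'):
        self.patience = patience        # epochs to wait after the last improvement
        self.verbose = verbose
        self.delta = delta              # minimum change that counts as an improvement
        self.folder = folder
        self.best_score = float('inf')  # lower metric is better (e.g. ND, MAE)
        self.counter = 0
        self.early_stop = False

    def __call__(self, metric, model):
        if metric < self.best_score - self.delta:
            # improvement: reset the counter and checkpoint the current weights
            self.best_score = metric
            self.counter = 0
            torch.save(model.state_dict(), os.path.join(self.folder, 'checkpoint.pt'))
            if self.verbose:
                print(f'Metric improved to {metric:.5f}; checkpoint saved')
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True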