def main(args):

    dataset_name = args.dataset
    model_name = args.model
    n_inner_iter = args.adaptation_steps
    batch_size = args.batch_size
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    is_test = args.is_test
    stopping_patience = args.stopping_patience
    epochs = args.epochs
    fast_lr = args.learning_rate
    slow_lr = args.meta_learning_rate
    noise_level = args.noise_level
    noise_type = args.noise_type
    resume = args.resume

    first_order = False  # use second-order MAML updates
    inner_loop_grad_clip = 20
    task_size = 50  # placeholder; overwritten from meta_info below
    output_dim = 1
    checkpoint_freq = 10
    horizon = 10

    meta_info = {
        "POLLUTION": [5, 50, 14],
        "HR": [32, 50, 13],
        "BATTERY": [20, 50, 3]
    }
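    # each entry is (window_size, task_size, input_dim) for that dataset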

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, task_size, input_dim = meta_info[dataset_name]

    grid = [0., noise_level]  # epsilon choices for target-noise augmentation

    def load_split(split):
        path = f"../../Data/{split}-{dataset_name}-W{window_size}-T{task_size}-ML.pickle"
        with open(path, "rb") as f:
            return pickle.load(f)

    train_data_ML = load_split("TRAIN")
    validation_data_ML = load_split("VAL")
    test_data_ML = load_split("TEST")

    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MAML/" + str(
            trial) + "/"
        save_model_file_ = output_directory + save_model_file
        save_model_file_encoder = output_directory + "encoder_" + save_model_file
        load_model_file_ = output_directory + load_model_file
        checkpoint_file = output_directory + "checkpoint_" + save_model_file.split(
            ".")[0]

        os.makedirs(output_directory, exist_ok=True)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Learning rate: %f \n" % fast_lr)
            f.write("Meta-learning rate: %f \n" % slow_lr)
            f.write("Adaptation steps: %d \n" % n_inner_iter)
            f.write("Noise level: %f \n" % noise_level)

        if model_name == "LSTM":
            model = LSTMModel(batch_size=batch_size,
                              seq_len=window_size,
                              input_dim=input_dim,
                              n_layers=2,
                              hidden_dim=120,
                              output_dim=output_dim)
            model2 = LinearModel(120, 1)
        optimizer = torch.optim.Adam(list(model.parameters()) +
                                     list(model2.parameters()),
                                     lr=slow_lr)
        loss_func = mae
        #loss_func = nn.SmoothL1Loss()
        #loss_func = nn.MSELoss()
        initial_epoch = 0


        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        meta_learner = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)
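        # the meta-learner adapts only the linear head (model2) in the inner loop;
        # the LSTM encoder is shared across tasks and trained by the outer meta-step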
        model.to(device)

        early_stopping = EarlyStopping(patience=stopping_patience,
                                       model_file=save_model_file_encoder,
                                       verbose=True)
        early_stopping2 = EarlyStopping(patience=stopping_patience,
                                        model_file=save_model_file_,
                                        verbose=True)
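        # both trackers monitor the same validation error: the first checkpoints
        # the encoder, the second the meta-learner head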

        if resume:
            checkpoint = torch.load(checkpoint_file)
            model.load_state_dict(checkpoint["model"])
            meta_learner.load_state_dict(checkpoint["meta_learner"])
            initial_epoch = checkpoint["epoch"]
            best_score = checkpoint["best_score"]
            counter = checkpoint["counter_stopping"]

            early_stopping.best_score = best_score
            early_stopping2.best_score = best_score

            early_stopping.counter = counter
            early_stopping2.counter = counter

        total_tasks, task_size, window_size, input_dim = train_data_ML.x.shape

        for epoch in range(initial_epoch, epochs):

            model.zero_grad()
            meta_learner._model.zero_grad()

            # train: sample a random meta-batch of task indices
            batch_idx = np.random.randint(0, total_tasks - 1, batch_size)

            x_spt, y_spt = train_data_ML[batch_idx]
            x_qry, y_qry = train_data_ML[batch_idx + 1]  # query = the next task in time

            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry, y_qry = to_torch(x_qry), to_torch(y_qry)

            # data augmentation: perturb targets with epsilon drawn from {0, noise_level}
            epsilon = grid[np.random.randint(0, len(grid))]

            if noise_type == "additive":
                y_spt = y_spt + epsilon
                y_qry = y_qry + epsilon
            else:  # multiplicative scaling
                y_spt = y_spt * (1 + epsilon)
                y_qry = y_qry * (1 + epsilon)

            train_tasks = [
                Task(model.encoder(x_spt[i]), y_spt[i])
                for i in range(x_spt.shape[0])
            ]
            val_tasks = [
                Task(model.encoder(x_qry[i]), y_qry[i])
                for i in range(x_qry.shape[0])
            ]
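            # support tasks (x_spt) drive the inner-loop adaptation; query tasks
            # (x_qry) supply the loss for the outer meta-update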

            adapted_params = meta_learner.adapt(train_tasks)
            mean_loss = meta_learner.step(adapted_params,
                                          val_tasks,
                                          is_training=True)

            # evaluate on the validation split (with noise) and the clean test split
            val_error = test(validation_data_ML, meta_learner, model, device,
                             noise_level)
            test_error = test(test_data_ML, meta_learner, model, device, 0.0)
            print("Epoch:", epoch)
            print("Val error:", val_error)
            print("Test error:", test_error)

            early_stopping(val_error, model)
            early_stopping2(val_error, meta_learner)

            # checkpointing (the original tested `epochs`, a constant, instead of `epoch`)
            if epoch % checkpoint_freq == 0:
                torch.save(
                    {
                        "epoch": epoch,
                        "model": model.state_dict(),
                        "meta_learner": meta_learner.state_dict(),
                        "best_score": early_stopping2.best_score,
                        "counter_stopping": early_stopping2.counter
                    }, checkpoint_file)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        print("hallo")
        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        validation_error = test(validation_data_ML,
                                meta_learner,
                                model,
                                device,
                                noise_level=0.0)
        test_error = test(test_data_ML,
                          meta_learner,
                          model,
                          device,
                          noise_level=0.0)

        validation_error_h1 = test(validation_data_ML,
                                   meta_learner,
                                   model,
                                   device,
                                   noise_level=0.0,
                                   horizon=1)
        test_error_h1 = test(test_data_ML,
                             meta_learner,
                             model,
                             device,
                             noise_level=0.0,
                             horizon=1)
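        # horizon=1 scores one-step-ahead forecasts; the unsuffixed errors above use
        # test()'s default horizon (presumably the 10-step horizon defined earlier)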

        # rebuild with 0 adaptation steps: scores the raw meta-initialization
        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner2 = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                    first_order, 0, inner_loop_grad_clip,
                                    device)

        validation_error_h0 = test(validation_data_ML,
                                   meta_learner2,
                                   model,
                                   device,
                                   noise_level=0.0,
                                   horizon=1)
        test_error_h0 = test(test_data_ML,
                             meta_learner2,
                             model,
                             device,
                             noise_level=0.0,
                             horizon=1)

        # reload once more and evaluate MAE with the full adaptation budget
        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner2 = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                    first_order, n_inner_iter,
                                    inner_loop_grad_clip, device)
        validation_error_mae = test(validation_data_ML, meta_learner2, model,
                                    device, 0.0)
        test_error_mae = test(test_data_ML, meta_learner2, model, device, 0.0)
        print("test_error_mae", test_error_mae)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Test error: %f \n" % test_error)
            f.write("Validation error: %f \n" % validation_error)
            f.write("Test error h1: %f \n" % test_error_h1)
            f.write("Validation error h1: %f \n" % validation_error_h1)
            f.write("Test error h0: %f \n" % test_error_h0)
            f.write("Validation error h0: %f \n" % validation_error_h0)
            f.write("Test error mae: %f \n" % test_error_mae)
            f.write("Validation error mae: %f \n" % validation_error_mae)

        print("Test error:", test_error)
        print("Validation error:", validation_error)
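
# For context, a minimal sketch of the command-line entry point this main()
# presumably consumes. Every flag name is inferred from the attribute accesses
# at the top of main(); the defaults and choices are illustrative assumptions,
# not the original configuration.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="POLLUTION", choices=["POLLUTION", "HR", "BATTERY"])
    parser.add_argument("--model", default="LSTM", choices=["FCN", "LSTM"])
    parser.add_argument("--adaptation_steps", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=20)
    parser.add_argument("--save_model_file", default="model.pt")
    parser.add_argument("--load_model_file", default="model.pt")
    parser.add_argument("--lower_trial", type=int, default=0)
    parser.add_argument("--upper_trial", type=int, default=3)
    parser.add_argument("--is_test", action="store_true")
    parser.add_argument("--stopping_patience", type=int, default=20)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--meta_learning_rate", type=float, default=1e-3)
    parser.add_argument("--noise_level", type=float, default=0.0)
    # "multiplicative" is a guessed name for the non-additive branch in main()
    parser.add_argument("--noise_type", default="additive")
    parser.add_argument("--resume", action="store_true")

    main(parser.parse_args())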
Example 2
def train_and_evaluate(model: nn.Module,
                       train_loader: DataLoader,
                       test_loader: DataLoader,
                       optimizer: optim.Optimizer,
                       loss_fn,
                       params: utils.Params,
                       restore_file: str = None) -> None:
    '''Train the model and evaluate after every epoch.
    Args:
        model: (torch.nn.Module) the DeepAR model
        train_loader: loads train data and labels
        test_loader: loads test data and labels
        optimizer: (torch.optim.Optimizer) optimizer for the model's parameters
        loss_fn: a function that takes outputs and labels per timestep, and then computes the loss for the batch
        params: (Params) hyperparameters
        restore_file: (string, optional) name of the checkpoint to restore from (without the .pth.tar extension)
    '''
    # reload weights from restore_file if specified
    restore_epoch = 0
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir, restore_file + '.pth.tar')
        logger.info('Restoring parameters from {}'.format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)
        # e.g. restore_file = 'epoch_7' -> restore_epoch = 8 (resume after the saved epoch)
        restore_epoch = int(restore_file[-2:].replace('_', '')) + 1
    logger.info('Restoring epoch: {}'.format(restore_epoch))
    logger.info('Begin training and evaluation')
    
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=25, verbose=True, delta=0.0001, folder=params.model_dir)
    
    # seed the early-stopping baseline with the best ND achieved so far (if any),
    # so a resumed run does not immediately overwrite a better checkpoint
    if os.path.exists(os.path.join(params.model_dir, 'metrics_test_best_weights.json')):
        with open(os.path.join(params.model_dir, 'metrics_test_best_weights.json')) as json_file:
            best_test_ND = json.load(json_file)['ND']
    else:
        best_test_ND = float('inf')
    early_stopping.best_score = best_test_ND
    
    train_len = len(train_loader)
    ND_summary = np.zeros(params.num_epochs)
    loss_summary = np.zeros((train_len * params.num_epochs))
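    # loss_summary keeps one entry per training iteration (train_len per epoch),
    # filled slice-by-slice in the epoch loop below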
    
    for epoch in range(restore_epoch, params.num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, params.num_epochs))
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = train(model, optimizer, loss_fn, train_loader,
                                                                        test_loader, params, epoch)
        test_metrics = evaluate(model, loss_fn, test_loader, params, epoch, sample=args.sampling)
        # guard against NaN metrics; the original commented-out check used
        # `== float('nan')`, which is never true; np.isnan is the correct test
        if np.isnan(test_metrics['ND']):
            logger.info('NaN ND encountered; penalizing with a large value')
            test_metrics['ND'] = 1000

        ND_summary[epoch] = test_metrics['ND']
        is_best = ND_summary[epoch] <= best_test_ND

        # Save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                              epoch=epoch,
                              is_best=is_best,
                              checkpoint=params.model_dir)

        if is_best:
            logger.info('- Found new best ND')
            best_test_ND = ND_summary[epoch]
            best_json_path = os.path.join(params.model_dir, 'metrics_test_best_weights.json')
            utils.save_dict_to_json(test_metrics, best_json_path)

        logger.info('Current best ND: %.5f' % best_test_ND)

        utils.plot_all_epoch(ND_summary[:epoch + 1], args.dataset + '_ND', params.plot_dir)
        utils.plot_all_epoch(loss_summary[:(epoch + 1) * train_len], args.dataset + '_loss', params.plot_dir)

        last_json_path = os.path.join(params.model_dir, 'metrics_test_last_weights.json')
        utils.save_dict_to_json(test_metrics, last_json_path)
        
        
        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        logger.info('ND : %.5f ' % test_metrics['ND'])
        early_stopping(test_metrics['ND'], model)
        
        if early_stopping.early_stop:
            logger.info('Early stopping')
            break
        
#     # load the last checkpoint with the best model
#     model.load_state_dict(torch.load('checkpoint.pt'))

    if args.save_best:
        list_of_params = args.search_params.split(',')
        print_params = ''
        for param in list_of_params:
            param_value = getattr(params, param)
            print_params += f'{param}: {param_value:.2f}, '
        print_params = print_params[:-2]  # drop the trailing ', ' separator
        with open('./param_search.txt', 'w') as f:
            f.write('-----------\n')
            f.write(print_params + '\n')
            f.write('Best ND: ' + str(best_test_ND) + '\n')
        logger.info(print_params)
        logger.info(f'Best ND: {best_test_ND}')
        utils.plot_all_epoch(ND_summary, print_params + '_ND', location=params.plot_dir)
        utils.plot_all_epoch(loss_summary, print_params + '_loss', location=params.plot_dir)
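
# A hedged usage sketch for train_and_evaluate(). The construction below follows
# the cs230-style utils.Params pattern that the type hints suggest; the model
# class, loaders, and loss function are placeholders, not the original code.
#
#     params = utils.Params(os.path.join('experiments/base_model', 'params.json'))
#     model = net.Net(params).to(params.device)
#     optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
#     train_and_evaluate(model, train_loader, test_loader, optimizer,
#                        net.loss_fn, params, restore_file=None)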