# NOTE: these examples are excerpts from a larger project. The snippets assume
# the imports below, plus project-local helpers (Task, MetaLearner, LSTMModel,
# LSTMDecoder, Lambda, MultimodalLearner, LinearModel, SimpleDataset,
# EarlyStopping, to_torch, mae, get_task_encoder_input, split_idx_50_50)
# defined elsewhere in the code base.
import copy
import json
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import learn2learn as l2l


def test(test_data_ML, meta_learner, model, device, horizon=10):

    total_tasks_test = len(test_data_ML)
    task_size = test_data_ML.x.shape[-3]
    input_dim = test_data_ML.x.shape[-1]
    window_size = test_data_ML.x.shape[-2]
    output_dim = test_data_ML.y.shape[-1]

    accum_error = 0.0
    count = 0

    # Evaluate roughly 100 evenly spaced tasks; guard against a zero step
    # when there are fewer than 100 test tasks.
    step = max(1, total_tasks_test // 100)
    for task in range(0, (total_tasks_test - horizon - 1), step):

        x_spt, y_spt = test_data_ML[task]
        x_qry = test_data_ML.x[(task + 1):(task + 1 + horizon)].reshape(
            -1, window_size, input_dim)
        y_qry = test_data_ML.y[(task + 1):(task + 1 + horizon)].reshape(
            -1, output_dim)

        x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
        x_qry = to_torch(x_qry)
        y_qry = to_torch(y_qry)

        train_task = [Task(model.encoder(x_spt), y_spt)]
        val_task = [Task(model.encoder(x_qry), y_qry)]

        adapted_params = meta_learner.adapt(train_task)
        mean_loss = meta_learner.step(adapted_params, val_task, is_training=False)

        count += 1
        accum_error += mean_loss.cpu().detach().numpy()

    return accum_error / count
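
The snippets rely on a project-local `to_torch` helper that is not shown. A
minimal sketch consistent with its usage (NumPy array in, float tensor on the
GPU when available out); the real helper may differ:

def to_torch(array):
    # Hedged sketch, inferred from how to_torch is used in these examples.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.as_tensor(array, dtype=torch.float32, device=device)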
Example #2
def test(data_ML, multimodal_learner, meta_learner, task_data, horizon=10):
    total_tasks = len(data_ML)
    task_size = data_ML.x.shape[-3]
    input_dim = data_ML.x.shape[-1]
    window_size = data_ML.x.shape[-2]
    output_dim = data_ML.y.shape[-1]

    accum_error = 0.0
    count = 0

    # Roughly 100 evenly spaced tasks; avoid a zero step for small task sets.
    step = max(1, total_tasks // 100)
    for task_id in range(0, (total_tasks - horizon - 1), step):
        x_spt, y_spt = data_ML[task_id]
        x_qry = data_ML.x[(task_id + 1):(task_id + 1 + horizon)].reshape(
            -1, window_size, input_dim)
        y_qry = data_ML.y[(task_id + 1):(task_id + 1 + horizon)].reshape(
            -1, output_dim)
        task = task_data[task_id:task_id + 1].cuda()

        x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
        x_qry = to_torch(x_qry)
        y_qry = to_torch(y_qry)
        x_spt_encod, _ = multimodal_learner(x_spt, task, output_encoding=True)
        x_qry_encod, _ = multimodal_learner(x_qry, task, output_encoding=True)

        train_task = [Task(x_spt_encod, y_spt)]
        val_task = [Task(x_qry_encod, y_qry)]

        adapted_params = meta_learner.adapt(train_task)
        mean_loss = meta_learner.step(adapted_params, val_task, is_training=False)

        count += 1
        accum_error += mean_loss.cpu().detach().numpy()

    return accum_error / count
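
`Task` wraps a support or query set for the MetaLearner. A hedged sketch
consistent with how it is constructed above (an inputs/targets pair); the
project's class may carry additional state:

from collections import namedtuple

# Hedged sketch: a Task is just an (inputs, targets) pair here.
Task = namedtuple("Task", ["x", "y"])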
def test(test_data_ML,
         meta_learner,
         model,
         device,
         noise_level,
         noise_type="additive",
         horizon=10):

    total_tasks_test = len(test_data_ML)
    task_size = test_data_ML.x.shape[-3]
    input_dim = test_data_ML.x.shape[-1]
    window_size = test_data_ML.x.shape[-2]
    output_dim = test_data_ML.y.shape[-1]
    grid = [0., noise_level]

    accum_error = 0.0
    count = 0

    step = max(1, total_tasks_test // 100)
    for task in range(0, (total_tasks_test - horizon - 1), step):

        x_spt, y_spt = test_data_ML[task]
        x_qry = test_data_ML.x[(task + 1):(task + 1 + horizon)].reshape(
            -1, window_size, input_dim)
        y_qry = test_data_ML.y[(task + 1):(task + 1 + horizon)].reshape(
            -1, output_dim)

        x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
        x_qry = to_torch(x_qry)
        y_qry = to_torch(y_qry)

        epsilon = grid[np.random.randint(0, len(grid))]

        if noise_type == "additive":
            y_spt = y_spt + epsilon
            y_qry = y_qry + epsilon

        else:
            y_spt = y_spt * (1 + epsilon)
            y_qry = y_qry * (1 + epsilon)

        train_task = [Task(model.encoder(x_spt), y_spt)]
        val_task = [Task(model.encoder(x_qry), y_qry)]

        adapted_params = meta_learner.adapt(train_task)
        mean_loss = meta_learner.step(adapted_params, val_task, is_training=False)

        count += 1
        accum_error += mean_loss.cpu().detach().numpy()

    return accum_error / count
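
`mae` is the mean-absolute-error loss used as `loss_func` throughout. A
minimal sketch, assuming plain elementwise MAE:

def mae(pred, target):
    # Mean absolute error over all elements.
    return torch.mean(torch.abs(pred - target))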
def main(args):

    dataset_name = args.dataset
    model_name = args.model
    n_inner_iter = args.adaptation_steps
    batch_size = args.batch_size
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    is_test = args.is_test
    stopping_patience = args.stopping_patience
    epochs = args.epochs
    fast_lr = args.learning_rate
    slow_lr = args.meta_learning_rate
    noise_level = args.noise_level
    noise_type = args.noise_type
    resume = args.resume

    first_order = False
    inner_loop_grad_clip = 20
    task_size = 50
    output_dim = 1
    checkpoint_freq = 10
    horizon = 10

    meta_info = {
        "POLLUTION": [5, 50, 14],
        "HR": [32, 50, 13],
        "BATTERY": [20, 50, 3]
    }

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, task_size, input_dim = meta_info[dataset_name]

    grid = [0., noise_level]
    output_directory = "output/"

    train_data_ML = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-ML.pickle", "rb"))
    test_data_ML = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))

    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MAML/" + str(
            trial) + "/"
        save_model_file_ = output_directory + save_model_file
        save_model_file_encoder = output_directory + "encoder_" + save_model_file
        load_model_file_ = output_directory + load_model_file
        checkpoint_file = output_directory + "checkpoint_" + save_model_file.split(
            ".")[0]

        try:
            os.mkdir(output_directory)
        except OSError as error:
            print(error)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Learning rate :%f \n" % fast_lr)
            f.write("Meta-learning rate: %f \n" % slow_lr)
            f.write("Adaptation steps: %f \n" % n_inner_iter)
            f.write("Noise level: %f \n" % noise_level)

        if model_name == "LSTM":
            model = LSTMModel(batch_size=batch_size,
                              seq_len=window_size,
                              input_dim=input_dim,
                              n_layers=2,
                              hidden_dim=120,
                              output_dim=output_dim)
            model2 = LinearModel(120, 1)
        else:
            # Only the LSTM branch is implemented in this example; fail fast
            # instead of hitting a NameError below.
            raise NotImplementedError("model %s is not supported here" %
                                      model_name)
        optimizer = torch.optim.Adam(list(model.parameters()) +
                                     list(model2.parameters()),
                                     lr=slow_lr)
        loss_func = mae
        #loss_func = nn.SmoothL1Loss()
        #loss_func = nn.MSELoss()
        initial_epoch = 0

        #torch.backends.cudnn.enabled = False

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        meta_learner = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)
        model.to(device)

        early_stopping = EarlyStopping(patience=stopping_patience,
                                       model_file=save_model_file_encoder,
                                       verbose=True)
        early_stopping2 = EarlyStopping(patience=stopping_patience,
                                        model_file=save_model_file_,
                                        verbose=True)

        if resume:
            checkpoint = torch.load(checkpoint_file)
            model.load_state_dict(checkpoint["model"])
            meta_learner.load_state_dict(checkpoint["meta_learner"])
            initial_epoch = checkpoint["epoch"]
            best_score = checkpoint["best_score"]
            counter = checkpoint["counter_stopping"]

            early_stopping.best_score = best_score
            early_stopping2.best_score = best_score

            early_stopping.counter = counter
            early_stopping2.counter = counter

        total_tasks, task_size, window_size, input_dim = train_data_ML.x.shape

        for epoch in range(initial_epoch, epochs):

            model.zero_grad()
            meta_learner._model.zero_grad()

            #train
            batch_idx = np.random.randint(0, total_tasks - 1, batch_size)

            #for batch_idx in range(0, total_tasks-1, batch_size):

            x_spt, y_spt = train_data_ML[batch_idx]
            x_qry, y_qry = train_data_ML[batch_idx + 1]

            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry = to_torch(x_qry)
            y_qry = to_torch(y_qry)

            # data augmentation
            epsilon = grid[np.random.randint(0, len(grid))]

            if noise_type == "additive":
                y_spt = y_spt + epsilon
                y_qry = y_qry + epsilon
            else:
                y_spt = y_spt * (1 + epsilon)
                y_qry = y_qry * (1 + epsilon)

            train_tasks = [
                Task(model.encoder(x_spt[i]), y_spt[i])
                for i in range(x_spt.shape[0])
            ]
            val_tasks = [
                Task(model.encoder(x_qry[i]), y_qry[i])
                for i in range(x_qry.shape[0])
            ]

            adapted_params = meta_learner.adapt(train_tasks)
            mean_loss = meta_learner.step(adapted_params,
                                          val_tasks,
                                          is_training=True)

            #test

            val_error = test(validation_data_ML, meta_learner, model, device,
                             noise_level)
            test_error = test(test_data_ML, meta_learner, model, device, 0.0)
            print("Epoch:", epoch)
            print("Val error:", val_error)
            print("Test error:", test_error)

            early_stopping(val_error, model)
            early_stopping2(val_error, meta_learner)

            #checkpointing
            if epoch % checkpoint_freq == 0:
                torch.save(
                    {
                        "epoch": epoch,
                        "model": model.state_dict(),
                        "meta_learner": meta_learner.state_dict(),
                        "best_score": early_stopping2.best_score,
                        "counter_stopping": early_stopping2.counter
                    }, checkpoint_file)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        validation_error = test(validation_data_ML,
                                meta_learner,
                                model,
                                device,
                                noise_level=0.0)
        test_error = test(test_data_ML,
                          meta_learner,
                          model,
                          device,
                          noise_level=0.0)

        validation_error_h1 = test(validation_data_ML,
                                   meta_learner,
                                   model,
                                   device,
                                   noise_level=0.0,
                                   horizon=1)
        test_error_h1 = test(test_data_ML,
                             meta_learner,
                             model,
                             device,
                             noise_level=0.0,
                             horizon=1)

        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner2 = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                    first_order, 0, inner_loop_grad_clip,
                                    device)

        validation_error_h0 = test(validation_data_ML,
                                   meta_learner2,
                                   model,
                                   device,
                                   noise_level=0.0,
                                   horizon=1)
        test_error_h0 = test(test_data_ML,
                             meta_learner2,
                             model,
                             device,
                             noise_level=0.0,
                             horizon=1)

        model.load_state_dict(torch.load(save_model_file_encoder))
        model2.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner2 = MetaLearner(model2, optimizer, fast_lr, loss_func,
                                    first_order, n_inner_iter,
                                    inner_loop_grad_clip, device)
        validation_error_mae = test(validation_data_ML, meta_learner2, model,
                                    device, 0.0)
        test_error_mae = test(test_data_ML, meta_learner2, model, device, 0.0)
        print("test_error_mae", test_error_mae)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Test error: %f \n" % test_error)
            f.write("Validation error: %f \n" % validation_error)
            f.write("Test error h1: %f \n" % test_error_h1)
            f.write("Validation error h1: %f \n" % validation_error_h1)
            f.write("Test error h0: %f \n" % test_error_h0)
            f.write("Validation error h0: %f \n" % validation_error_h0)
            f.write("Test error mae: %f \n" % test_error_mae)
            f.write("Validation error mae: %f \n" % validation_error_mae)

        print(test_error)
        print(validation_error)
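
The `main` functions in these examples read a common set of attributes from
`args`. A hedged sketch of a matching argparse setup for the MAML `main`
above; the flag names mirror the attributes it reads, while the defaults are
illustrative assumptions, not the project's actual values:

import argparse

def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="POLLUTION")
    parser.add_argument("--model", default="LSTM")
    parser.add_argument("--adaptation_steps", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=20)
    parser.add_argument("--save_model_file", default="model.pt")
    parser.add_argument("--load_model_file", default="model.pt")
    parser.add_argument("--lower_trial", type=int, default=0)
    parser.add_argument("--upper_trial", type=int, default=3)
    parser.add_argument("--is_test", type=int, default=0)
    parser.add_argument("--stopping_patience", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--meta_learning_rate", type=float, default=1e-3)
    parser.add_argument("--noise_level", type=float, default=0.0)
    parser.add_argument("--noise_type", default="additive")
    parser.add_argument("--resume", action="store_true")
    return parser

if __name__ == "__main__":
    main(build_parser().parse_args())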
Example #5
def main(args):
    dataset_name = args.dataset
    model_name = args.model
    n_inner_iter = args.adaptation_steps
    meta_learning_rate = args.meta_learning_rate
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    task_size = args.task_size
    noise_level = args.noise_level
    noise_type = args.noise_type
    epochs = args.epochs
    loss_fcn_str = args.loss
    modulate_task_net = args.modulate_task_net
    weight_vrae = args.weight_vrae
    stopping_patience = args.stopping_patience

    meta_info = {"POLLUTION": [5, 14], "HR": [32, 13], "BATTERY": [20, 3]}

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, input_dim = meta_info[dataset_name]

    grid = [0., noise_level]

    train_data_ML = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-ML.pickle", "rb"))
    test_data_ML = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))

    total_tasks = len(train_data_ML)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    loss_fn = mae if loss_fcn_str == "MAE" else nn.SmoothL1Loss()

    ## multimodal learner parameters
    # parameters to increase the capacity of the model
    n_layers_task_net = 2
    n_layers_task_encoder = 2
    n_layers_task_decoder = 2

    hidden_dim_task_net = 120
    hidden_dim_encoder = 120
    hidden_dim_decoder = 120

    # fixed values
    input_dim_task_net = input_dim
    input_dim_task_encoder = input_dim + 1
    output_dim_task_net = 1
    output_dim_task_decoder = input_dim + 1

    first_order = False
    inner_loop_grad_clip = 20

    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MMAML/" + str(
            trial) + "/"
        save_model_file_ = output_directory + save_model_file
        save_model_file_encoder = output_directory + "encoder_" + save_model_file
        load_model_file_ = output_directory + load_model_file
        checkpoint_file = output_directory + "checkpoint_" + save_model_file.split(
            ".")[0]

        writer = SummaryWriter()

        try:
            os.mkdir(output_directory)
        except OSError as error:
            print(error)

        task_net = LSTMModel(batch_size=batch_size,
                             seq_len=window_size,
                             input_dim=input_dim_task_net,
                             n_layers=n_layers_task_net,
                             hidden_dim=hidden_dim_task_net,
                             output_dim=output_dim_task_net)

        task_encoder = LSTMModel(batch_size=batch_size,
                                 seq_len=task_size,
                                 input_dim=input_dim_task_encoder,
                                 n_layers=n_layers_task_encoder,
                                 hidden_dim=hidden_dim_encoder,
                                 output_dim=1)

        task_decoder = LSTMDecoder(batch_size=1,
                                   n_layers=n_layers_task_decoder,
                                   seq_len=task_size,
                                   output_dim=output_dim_task_decoder,
                                   hidden_dim=hidden_dim_encoder,
                                   latent_dim=hidden_dim_decoder,
                                   device=device)

        lmbd = Lambda(hidden_dim_encoder, hidden_dim_task_net)

        multimodal_learner = MultimodalLearner(task_net, task_encoder,
                                               task_decoder, lmbd,
                                               modulate_task_net)
        multimodal_learner.to(device)

        output_layer = LinearModel(120, 1)
        opt = torch.optim.Adam(list(multimodal_learner.parameters()) +
                               list(output_layer.parameters()),
                               lr=meta_learning_rate)

        meta_learner = MetaLearner(output_layer, opt, learning_rate, loss_fn,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        early_stopping = EarlyStopping(patience=stopping_patience,
                                       model_file=save_model_file_,
                                       verbose=True)
        early_stopping_encoder = EarlyStopping(
            patience=stopping_patience,
            model_file=save_model_file_encoder,
            verbose=True)

        task_data_train = torch.FloatTensor(
            get_task_encoder_input(train_data_ML))
        task_data_validation = torch.FloatTensor(
            get_task_encoder_input(validation_data_ML))
        task_data_test = torch.FloatTensor(
            get_task_encoder_input(test_data_ML))

        val_loss_hist = []
        test_loss_hist = []

        for epoch in range(epochs):

            multimodal_learner.train()

            batch_idx = np.random.randint(0, total_tasks - 1, batch_size)
            task = task_data_train[batch_idx].to(device)

            x_spt, y_spt = train_data_ML[batch_idx]
            x_qry, y_qry = train_data_ML[batch_idx + 1]

            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry = to_torch(x_qry)
            y_qry = to_torch(y_qry)

            # data augmentation
            epsilon = grid[np.random.randint(0, len(grid))]

            if noise_type == "additive":
                y_spt = y_spt + epsilon
                y_qry = y_qry + epsilon
            else:
                y_spt = y_spt * (1 + epsilon)
                y_qry = y_qry * (1 + epsilon)

            x_spt_encodings = []
            x_qry_encodings = []
            vrae_loss_accum = 0.0
            for i in range(batch_size):
                x_spt_encoding, (vrae_loss, kl_loss,
                                 rec_loss) = multimodal_learner(
                                     x_spt[i],
                                     task[i:i + 1],
                                     output_encoding=True)
                x_spt_encodings.append(x_spt_encoding)
                vrae_loss_accum += vrae_loss

                x_qry_encoding, _ = multimodal_learner(x_qry[i],
                                                       task[i:i + 1],
                                                       output_encoding=True)
                x_qry_encodings.append(x_qry_encoding)

            train_tasks = [
                Task(x_spt_encodings[i], y_spt[i])
                for i in range(x_spt.shape[0])
            ]
            val_tasks = [
                Task(x_qry_encodings[i], y_qry[i])
                for i in range(x_qry.shape[0])
            ]

            # print(vrae_loss)

            adapted_params = meta_learner.adapt(train_tasks)
            mean_loss = meta_learner.step(adapted_params,
                                          val_tasks,
                                          is_training=True,
                                          additional_loss_term=weight_vrae *
                                          vrae_loss_accum / batch_size)

            ##plotting grad of output layer
            for tag, parm in output_layer.linear.named_parameters():
                writer.add_histogram("Grads_output_layer_" + tag,
                                     parm.grad.data.cpu().numpy(), epoch)

            multimodal_learner.eval()
            val_loss = test(validation_data_ML, multimodal_learner,
                            meta_learner, task_data_validation)
            test_loss = test(test_data_ML, multimodal_learner, meta_learner,
                             task_data_test)

            print("Epoch:", epoch)
            print("Train loss:", mean_loss)
            print("Val error:", val_loss)
            print("Test error:", test_loss)

            early_stopping(val_loss, meta_learner)
            early_stopping_encoder(val_loss, multimodal_learner)

            val_loss_hist.append(val_loss)
            test_loss_hist.append(test_loss)

            if early_stopping.early_stop:
                print("Early stopping")
                break

            writer.add_scalar("Loss/train",
                              mean_loss.cpu().detach().numpy(), epoch)
            writer.add_scalar("Loss/val", val_loss, epoch)
            writer.add_scalar("Loss/test", test_loss, epoch)

        multimodal_learner.load_state_dict(torch.load(save_model_file_encoder))
        output_layer.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner = MetaLearner(output_layer, opt, learning_rate, loss_fn,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        val_loss = test(validation_data_ML, multimodal_learner, meta_learner,
                        task_data_validation)
        test_loss = test(test_data_ML, multimodal_learner, meta_learner,
                         task_data_test)

        with open(output_directory + "/results3.txt", "a+") as f:
            f.write("Dataset :%s \n" % dataset_name)
            f.write("Test error: %f \n" % test_loss)
            f.write("Val error: %f \n" % val_loss)
            f.write("\n")

        writer.add_hparams(
            {
                "fast_lr": learning_rate,
                "slow_lr": meta_learning_rate,
                "adaption_steps": n_inner_iter,
                "patience": stopping_patience,
                "weight_vrae": weight_vrae,
                "noise_level": noise_level,
                "dataset": dataset_name,
                "trial": trial
            }, {
                "val_loss": val_loss,
                "test_loss": test_loss
            })
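
`EarlyStopping` tracks validation improvements and checkpoints the model. A
hedged sketch consistent with its usage above (callable with a score and a
model, exposing best_score, counter, and early_stop); the real class may
serialize either a raw state_dict or a wrapper dict, since both load patterns
appear in these examples:

class EarlyStopping:
    def __init__(self, patience, model_file, verbose=False):
        self.patience = patience
        self.model_file = model_file
        self.verbose = verbose
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_score is None or val_loss < self.best_score:
            # Improvement: reset the counter and checkpoint the model.
            self.best_score = val_loss
            self.counter = 0
            torch.save({"model_state_dict": model.state_dict()},
                       self.model_file)
            if self.verbose:
                print("Validation loss improved; checkpoint saved.")
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True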
Example #6
def main(args):

    meta_info = {
        "POLLUTION": [5, 50, 14],
        "HR": [32, 50, 13],
        "BATTERY": [20, 50, 3]
    }

    output_directory = "output/"
    verbose = True
    batch_size = 64
    freeze_model_flag = True

    params = {'batch_size': batch_size, 'shuffle': True, 'num_workers': 0}

    dataset_name = args.dataset
    model_name = args.model
    learning_rate = args.learning_rate
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    is_test = args.is_test
    epochs = args.epochs
    experiment_id = args.experiment_id
    adaptation_steps = args.adaptation_steps

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, task_size, input_dim = meta_info[dataset_name]
    batch_size = 64

    train_data = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-NOML.pickle", "rb"))
    train_data_ML = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-NOML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-ML.pickle", "rb"))
    test_data = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-NOML.pickle", "rb"))
    test_data_ML = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))

    if is_test == 0:
        test_data = validation_data

    train_idx, val_idx, test_idx = split_idx_50_50(
        test_data.file_idx) if is_test else split_idx_50_50(
            validation_data.file_idx)
    n_domains_in_test = np.max(test_data.file_idx) + 1

    test_loss_list = []
    initial_test_loss_list = []

    trials_loss_list = []

    #trial = 0
    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MAML/" + str(
            trial) + "/"

        #save_model_file_ = output_directory + "encoder_"+save_model_file
        #save_model_file_2 = output_directory + save_model_file
        save_model_file_ = output_directory + experiment_id + "_encoder_model.pt"
        save_model_file_2 = output_directory + experiment_id + "_model.pt"
        load_model_file_ = output_directory + load_model_file

        model = LSTMModel(batch_size=batch_size,
                          seq_len=window_size,
                          input_dim=input_dim,
                          n_layers=2,
                          hidden_dim=120,
                          output_dim=1)
        model2 = nn.Linear(120, 1)

        model.cuda()
        model2.cuda()

        maml = l2l.algorithms.MAML(model2, lr=learning_rate, first_order=False)
        model.load_state_dict(torch.load(save_model_file_))
        maml.load_state_dict(torch.load(save_model_file_2))

        n_domains_in_test = np.max(test_data.file_idx) + 1

        error_list = []

        y_list = []

        for domain in range(n_domains_in_test):
            x_test = test_data.x
            y_test = test_data.y

            temp_train_data = SimpleDataset(
                x=np.concatenate([
                    x_test[np.concatenate([train_idx[domain],
                                           val_idx[domain]])][np.newaxis, :],
                    x_test[test_idx[domain]][np.newaxis, :]
                ]),
                y=np.concatenate([
                    y_test[np.concatenate([train_idx[domain],
                                           val_idx[domain]])][np.newaxis, :],
                    y_test[test_idx[domain]][np.newaxis, :]
                ]))

            total_tasks_test = len(test_data_ML)

            learner = maml.clone()  # Creates a clone of model
            learner.cuda()
            accum_error = 0.0
            accum_std = 0.0
            count = 0.0

            input_dim = test_data_ML.x.shape[-1]
            window_size = test_data_ML.x.shape[-2]
            output_dim = test_data_ML.y.shape[-1]

            task = 0

            model2 = nn.Linear(120, 1)
            model2.load_state_dict(copy.deepcopy(maml.module.state_dict()))

            model.cuda()
            model2.cuda()

            x_spt, y_spt = temp_train_data[task]
            x_qry = temp_train_data.x[(task + 1)]
            y_qry = temp_train_data.y[(task + 1)]

            if model_name == "FCN":
                x_qry = np.transpose(x_qry, [0, 2, 1])
                x_spt = np.transpose(x_spt, [0, 2, 1])

            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry = to_torch(x_qry)
            y_qry = to_torch(y_qry)

            opt2 = optim.SGD(list(model2.parameters()), lr=learning_rate)
            # Fine-tune the output layer on (at most) the most recent
            # size_back tasks of the support stream.
            size_back = 300
            step_size = task_size * size_back

            for step in range(adaptation_steps):

                # Guard against a negative start index when the support set
                # is shorter than task_size * size_back.
                start = max(0, x_spt.shape[0] - task_size * size_back)
                for idx in range(start, x_spt.shape[0], step_size):

                    pred = model2(model.encoder(x_spt[idx:idx + step_size]))
                    error = mae(pred, y_spt[idx:idx + step_size])
                    opt2.zero_grad()
                    error.backward()
                    opt2.step()

            step = max(1, x_qry.shape[0] // 255)
            for idx in range(0, x_qry.shape[0], step):
                pred = model2(model.encoder(x_qry[idx:idx + step]))
                error = mae(pred, y_qry[idx:idx + step])

                accum_error += error.data
                accum_std += error.data**2
                count += 1

            error = accum_error / count

            y_list.append(y_qry.cpu().numpy())
            error_list.append(float(error.cpu().numpy()))
            print(np.mean(error_list))
            print(error_list)

            trials_loss_list.append(np.mean(error_list))

        print("mean:", np.mean(trials_loss_list))
        print("std:", np.std(trials_loss_list))
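
`SimpleDataset` is a thin container for aligned arrays. A hedged sketch
consistent with its usage above (keyword construction, `.x`/`.y` attributes,
tuple indexing); the project's class likely carries more, e.g. `file_idx`:

class SimpleDataset:
    def __init__(self, x, y, file_idx=None):
        self.x = x
        self.y = y
        self.file_idx = file_idx

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # Return an aligned (inputs, targets) pair.
        return self.x[idx], self.y[idx]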
Example #7
def test(loss_fn, maml, multimodal_model, task_data, dataset_name, data_ML,
         adaptation_steps, learning_rate, noise_level, noise_type,
         is_test=True, horizon=10):

    total_tasks = len(data_ML)
    task_size = data_ML.x.shape[-3]
    input_dim = data_ML.x.shape[-1]
    window_size = data_ML.x.shape[-2]
    output_dim = data_ML.y.shape[-1]

    # Subsample roughly 100 tasks at test time; evaluate every task otherwise.
    if is_test:
        step = max(1, total_tasks // 100)
    else:
        step = 1

    grid = [0., noise_level]
    accum_error = 0.0
    count = 0.0

    for task_idx in range(0, (total_tasks - horizon - 1), step):

        # Skip evaluation windows that straddle a file boundary.
        temp_file_idx = data_ML.file_idx[task_idx:task_idx + horizon + 1]
        if len(np.unique(temp_file_idx)) > 1:
            continue

        learner = maml.clone()

        x_spt, y_spt = data_ML[task_idx]
        x_qry = data_ML.x[(task_idx + 1):(task_idx + 1 + horizon)].reshape(
            -1, window_size, input_dim)
        y_qry = data_ML.y[(task_idx + 1):(task_idx + 1 + horizon)].reshape(
            -1, output_dim)
        task = task_data[task_idx:task_idx + 1].cuda()

        x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
        x_qry = to_torch(x_qry)
        y_qry = to_torch(y_qry)

        epsilon = grid[np.random.randint(0, len(grid))]

        if noise_type == "additive":
            y_spt = y_spt + epsilon
            y_qry = y_qry + epsilon
        else:
            y_spt = y_spt * (1 + epsilon)
            y_qry = y_qry * (1 + epsilon)

        # Use a throwaway loop variable so the outer stride `step` is not
        # shadowed during adaptation.
        for _ in range(adaptation_steps):
            x_encoding, _ = multimodal_model(x_spt, task, output_encoding=True)
            pred = learner(x_encoding)
            error = loss_fn(pred, y_spt)
            learner.adapt(error)

        x_encoding, _ = multimodal_model(x_qry, task, output_encoding=True)
        y_pred = learner(x_encoding)

        y_pred = torch.clamp(y_pred, 0, 1)
        error = mae(y_pred, y_qry)

        accum_error += error.data
        count += 1

    error = accum_error / max(count, 1.0)

    return error.cpu().numpy()
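
For reference, learn2learn's clone/adapt pattern used above, isolated with a
toy head (shapes and data here are purely illustrative):

# maml.clone() yields a differentiable copy; adapt() takes an inner-loop
# gradient step; the adapted clone's loss on query data is the meta-loss.
head = nn.Linear(120, 1)
maml_demo = l2l.algorithms.MAML(head, lr=0.01, first_order=False)

learner = maml_demo.clone()
x_spt, y_spt = torch.randn(8, 120), torch.randn(8, 1)
learner.adapt(torch.mean(torch.abs(learner(x_spt) - y_spt)))  # inner step

x_qry, y_qry = torch.randn(8, 120), torch.randn(8, 1)
meta_loss = torch.mean(torch.abs(learner(x_qry) - y_qry))
meta_loss.backward()  # outer gradients reach maml_demo's parameters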
Example #8
def main(args):
    dataset_name = args.dataset
    model_name = args.model
    adaptation_steps = args.adaptation_steps
    meta_learning_rate = args.meta_learning_rate
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    task_size = args.task_size
    noise_level = args.noise_level
    noise_type = args.noise_type
    epochs = args.epochs
    loss_fcn_str = args.loss
    modulate_task_net = args.modulate_task_net
    weight_vrae = args.weight_vrae
    stopping_patience = args.stopping_patience
    ml_horizon = args.ml_horizon
    experiment_id = args.experiment_id

    meta_info = {"POLLUTION": [5, 14],
                 "HR": [32, 13],
                 "BATTERY": [20, 3]}

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, input_dim = meta_info[dataset_name]

    grid = [0., noise_level]

    train_data_ML = pickle.load(
        open("../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open("../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb"))
    test_data_ML = pickle.load(
        open("../../Data/TEST-" + dataset_name + "-W" + str(window_size) + "-T" + str(task_size) + "-ML.pickle", "rb"))

    total_tasks = len(train_data_ML)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    loss_fn = mae if loss_fcn_str == "MAE" else nn.SmoothL1Loss()

    ## multimodal learner parameters
    # parameters to increase the capacity of the model
    n_layers_task_net = 2
    n_layers_task_encoder = 1
    n_layers_task_decoder = 1

    hidden_dim_task_net = 120
    hidden_dim_encoder = 120
    hidden_dim_decoder = 120

    # fixed values
    input_dim_task_net = input_dim
    input_dim_task_encoder = input_dim + 1
    output_dim_task_net = 1
    output_dim_task_decoder = input_dim + 1
    output_dim = 1

    results_list = []
    results_dict = {}
    results_dict["Experiment_id"] = experiment_id
    results_dict["Model"] = model_name
    results_dict["Dataset"] = dataset_name
    results_dict["Learning rate"] = learning_rate
    results_dict["Noise level"] = noise_level
    results_dict["Task size"] = task_size
    results_dict["Evaluation loss"] = "MAE Test"
    results_dict["Vrae weight"] = weight_vrae
    results_dict["Training"] = "MMAML"
    results_dict["Meta-learning rate"] = meta_learning_rate
    results_dict["ML-Horizon"] = ml_horizon

    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MMAML/" + str(trial) + "/"
        save_model_file_ = output_directory + experiment_id + "_" + save_model_file
        save_model_file_encoder = output_directory + experiment_id + "_"+ "encoder_" + save_model_file
        load_model_file_ = output_directory + load_model_file
        checkpoint_file = output_directory + "checkpoint_" + save_model_file.split(".")[0]

        writer = SummaryWriter()

        try:
            os.mkdir(output_directory)
        except OSError as error:
            print(error)

        task_net = LSTMModel(batch_size=batch_size,
                             seq_len=window_size,
                             input_dim=input_dim_task_net,
                             n_layers=n_layers_task_net,
                             hidden_dim=hidden_dim_task_net,
                             output_dim=output_dim_task_net)

        task_encoder = LSTMModel(batch_size=batch_size,
                                 seq_len=task_size,
                                 input_dim=input_dim_task_encoder,
                                 n_layers=n_layers_task_encoder,
                                 hidden_dim=hidden_dim_encoder,
                                 output_dim=1)

        task_decoder = LSTMDecoder(batch_size=1,
                                   n_layers=n_layers_task_decoder,
                                   seq_len=task_size,
                                   output_dim=output_dim_task_decoder,
                                   hidden_dim=hidden_dim_encoder,
                                   latent_dim=hidden_dim_decoder,
                                   device=device)

        lmbd = Lambda(hidden_dim_encoder, hidden_dim_task_net)

        multimodal_learner = MultimodalLearner(task_net, task_encoder, task_decoder, lmbd, modulate_task_net)
        multimodal_learner.to(device)

        output_layer = nn.Linear(120, 1)
        output_layer.to(device)

        maml = l2l.algorithms.MAML(output_layer, lr=learning_rate, first_order=False)
        opt = optim.Adam(list(maml.parameters()) + list(multimodal_learner.parameters()), lr=meta_learning_rate)

        early_stopping = EarlyStopping(patience=stopping_patience, model_file=save_model_file_, verbose=True)
        early_stopping_encoder = EarlyStopping(patience=stopping_patience, model_file=save_model_file_encoder, verbose=True)

        task_data_train = torch.FloatTensor(get_task_encoder_input(train_data_ML))
        task_data_validation = torch.FloatTensor(get_task_encoder_input(validation_data_ML))
        task_data_test = torch.FloatTensor(get_task_encoder_input(test_data_ML))

        val_loss_hist = []
        test_loss_hist = []
        total_num_tasks = train_data_ML.x.shape[0]

        for iteration in range(epochs):

            opt.zero_grad()
            iteration_error = 0.0
            vrae_loss_accum = 0.0

            multimodal_learner.train()

            for _ in range(batch_size):
                learner = maml.clone()
                task_idx = np.random.randint(0, total_num_tasks - ml_horizon - 1)
                task = task_data_train[task_idx:task_idx + 1].cuda()

                # Skip support/query pairs whose query falls in a different
                # source file than the support task.
                if (train_data_ML.file_idx[task_idx + ml_horizon] !=
                        train_data_ML.file_idx[task_idx]):
                    continue

                x_spt, y_spt = train_data_ML[task_idx]
                x_qry, y_qry = train_data_ML[task_idx + ml_horizon]
                x_qry = x_qry.reshape(-1, window_size, input_dim)
                y_qry = y_qry.reshape(-1, output_dim)

                x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
                x_qry = to_torch(x_qry)
                y_qry = to_torch(y_qry)

                # data augmentation
                epsilon = grid[np.random.randint(0, len(grid))]

                if noise_type == "additive":
                    y_spt = y_spt + epsilon
                    y_qry = y_qry + epsilon
                else:
                    y_spt = y_spt * (1 + epsilon)
                    y_qry = y_qry * (1 + epsilon)

                x_spt_encoding, (vrae_loss, _, _) = multimodal_learner(
                    x_spt, task, output_encoding=True)

                for _ in range(adaptation_steps):
                    pred = learner(x_spt_encoding)
                    error = loss_fn(pred, y_spt)
                    learner.adapt(error)

                # Accumulate over the whole meta-batch (initialized once per
                # iteration above); resetting it here would discard all but
                # the last task's VRAE loss.
                vrae_loss_accum += vrae_loss

                x_qry_encoding, _ = multimodal_learner(x_qry, task, output_encoding=True)
                pred = learner(x_qry_encoding)
                evaluation_error = loss_fn(pred, y_qry)
                iteration_error += evaluation_error

            iteration_error /= batch_size
            vrae_loss_accum /= batch_size
            iteration_error += weight_vrae*vrae_loss_accum
            iteration_error.backward()
            opt.step()

            multimodal_learner.eval()
            
            val_loss = test(loss_fn, maml, multimodal_learner,
                            task_data_validation, dataset_name,
                            validation_data_ML, adaptation_steps,
                            learning_rate, noise_level, noise_type, horizon=10)
            test_loss = test(loss_fn, maml, multimodal_learner, task_data_test,
                             dataset_name, test_data_ML, adaptation_steps,
                             learning_rate, 0, noise_type, horizon=10)
           
            early_stopping_encoder(val_loss, multimodal_learner)
            early_stopping(val_loss, maml)

            if early_stopping.early_stop:
                print("Early stopping")
                break

            print("Epoch:", iteration)

            print("Train loss:", iteration_error)
            print("Val error:", val_loss)
            print("Test error:", test_loss)

            val_loss_hist.append(val_loss)
            test_loss_hist.append(test_loss)

            writer.add_scalar("Loss/train", iteration_error.cpu().detach().numpy(), iteration)
            writer.add_scalar("Loss/val", val_loss, iteration)
            writer.add_scalar("Loss/test", test_loss, iteration)

        multimodal_learner.load_state_dict(torch.load(save_model_file_encoder))
        maml.load_state_dict(torch.load(save_model_file_))

        val_loss = test(loss_fn, maml, multimodal_learner,
                        task_data_validation, dataset_name, validation_data_ML,
                        adaptation_steps, learning_rate, noise_level,
                        noise_type, horizon=10)
        test_loss = test(loss_fn, maml, multimodal_learner, task_data_test,
                         dataset_name, test_data_ML, adaptation_steps,
                         learning_rate, noise_level, noise_type, horizon=10)

        val_loss1 = test(loss_fn, maml, multimodal_learner,
                         task_data_validation, dataset_name,
                         validation_data_ML, adaptation_steps, learning_rate,
                         noise_level, noise_type, horizon=1)
        test_loss1 = test(loss_fn, maml, multimodal_learner, task_data_test,
                          dataset_name, test_data_ML, adaptation_steps,
                          learning_rate, noise_level, noise_type, horizon=1)

        adaptation_steps_ = 0
        val_loss_0 = test(loss_fn, maml, multimodal_learner,
                          task_data_validation, dataset_name,
                          validation_data_ML, adaptation_steps_, learning_rate,
                          noise_level, noise_type, horizon=10)
        test_loss_0 = test(loss_fn, maml, multimodal_learner, task_data_test,
                           dataset_name, test_data_ML, adaptation_steps_,
                           learning_rate, noise_level, noise_type, horizon=10)

        val_loss1_0 = test(loss_fn, maml, multimodal_learner,
                           task_data_validation, dataset_name,
                           validation_data_ML, adaptation_steps_,
                           learning_rate, noise_level, noise_type, horizon=1)
        test_loss1_0 = test(loss_fn, maml, multimodal_learner, task_data_test,
                            dataset_name, test_data_ML, adaptation_steps_,
                            learning_rate, noise_level, noise_type, horizon=1)

        with open(output_directory + "/results3.txt", "a+") as f:
            f.write("\n \n Learning rate :%f \n"% learning_rate)
            f.write("Meta-learning rate: %f \n" % meta_learning_rate)
            f.write("Adaptation steps: %f \n" % adaptation_steps)
            f.write("Noise level: %f \n" % noise_level)
            f.write("vrae weight: %f \n" % weight_vrae)
            f.write("Test error: %f \n" % test_loss)
            f.write("Val error: %f \n" % val_loss)
            f.write("Test error 1: %f \n" % test_loss1)
            f.write("Val error 1: %f \n" % val_loss1)
            f.write("Test error 0: %f \n" % test_loss_0)
            f.write("Val error 0: %f \n" % val_loss_0)

        writer.add_hparams({"fast_lr": learning_rate,
                            "slow_lr": meta_learning_rate,
                            "adaption_steps": adaptation_steps,
                            "patience": stopping_patience,
                            "weight_vrae": weight_vrae,
                            "noise_level": noise_level,
                            "dataset": dataset_name,
                            "trial": trial},
                           {"val_loss": val_loss,
                            "test_loss": test_loss})

        temp_results_dict = copy.copy(results_dict)
        temp_results_dict["Trial"] = trial
        temp_results_dict["Adaptation steps"] = adaptation_steps
        temp_results_dict["Horizon"] = 10
        temp_results_dict["Value"] = float(test_loss)
        temp_results_dict["Val error"] = float(val_loss)
        temp_results_dict["Final_epoch"] = iteration
        results_list.append(temp_results_dict)

        temp_results_dict = copy.copy(results_dict)
        temp_results_dict["Trial"] = trial
        temp_results_dict["Adaptation steps"] = 0
        temp_results_dict["Horizon"] = 10
        temp_results_dict["Value"] = float(test_loss_0)
        temp_results_dict["Val error"] = float(val_loss_0)
        temp_results_dict["Final_epoch"] = iteration
        results_list.append(temp_results_dict)      

        temp_results_dict = copy.copy(results_dict)
        temp_results_dict["Trial"] = trial
        temp_results_dict["Adaptation steps"] = adaptation_steps
        temp_results_dict["Horizon"] = 1
        temp_results_dict["Value"] = float(test_loss1)
        temp_results_dict["Final_epoch"] = iteration
        results_list.append(temp_results_dict)

        temp_results_dict = copy.copy(results_dict)
        temp_results_dict["Trial"] = trial
        temp_results_dict["Adaptation steps"] = 0
        temp_results_dict["Horizon"] = 1
        temp_results_dict["Value"] = float(test_loss1_0)
        temp_results_dict["Final_epoch"] = iteration
        results_list.append(temp_results_dict)  


    try:
        os.mkdir("../../Results/json_files/")
    except OSError as error:
        print(error)
        
    with open("../../Results/json_files/"+experiment_id+".json", 'w') as outfile:
        json.dump(results_list, outfile)
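
`get_task_encoder_input` turns a meta-learning dataset into per-task sequences
for the task encoder. A hedged sketch that matches the encoder's expected
shapes above (seq_len=task_size, input_dim=input_dim + 1); the project's
actual construction may differ:

def get_task_encoder_input(data_ML):
    # Pair the last time step of every window with its target, giving one
    # (task_size, input_dim + 1) summary sequence per task.
    return np.concatenate([data_ML.x[:, :, -1, :], data_ML.y], axis=-1)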
def main(args):

    meta_info = {
        "POLLUTION": [5, 50, 14],
        "HR": [32, 50, 13],
        "BATTERY": [20, 50, 3]
    }

    output_directory = "output/"
    verbose = True
    batch_size = 64
    freeze_model_flag = True

    params = {'batch_size': batch_size, 'shuffle': True, 'num_workers': 0}

    dataset_name = args.dataset
    model_name = args.model
    learning_rate = args.learning_rate
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    is_test = args.is_test
    epochs = args.epochs
    experiment_id = args.experiment_id
    adaptation_steps = args.adaptation_steps

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, task_size, input_dim = meta_info[dataset_name]
    batch_size = 64
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    loss_fn = mae

    train_data = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-NOML.pickle", "rb"))
    train_data_ML = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-NOML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-ML.pickle", "rb"))
    test_data = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-NOML.pickle", "rb"))
    test_data_ML = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))

    # parameters to increase the capacity of the model
    n_layers_task_net = 2
    n_layers_task_encoder = 1
    n_layers_task_decoder = 1

    hidden_dim_task_net = 120
    hidden_dim_encoder = 120
    hidden_dim_decoder = 120

    input_dim_task_net = input_dim
    input_dim_task_encoder = input_dim + 1
    output_dim_task_net = 1
    output_dim_task_decoder = input_dim + 1
    output_dim = 1

    if is_test == 0:
        test_data = validation_data

    train_idx, val_idx, test_idx = split_idx_50_50(
        test_data.file_idx) if is_test else split_idx_50_50(
            validation_data.file_idx)
    n_domains_in_test = np.max(test_data.file_idx) + 1

    test_loss_list = []
    initial_test_loss_list = []

    trials_loss_list = []
    modulate_task_net = True

    #trial = 0
    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MMAML/" + str(
            trial) + "/"

        #save_model_file_ = output_directory + "encoder_"+save_model_file
        #save_model_file_2 = output_directory + save_model_file
        save_model_file_encoder = output_directory + experiment_id + "_encoder_model.pt"
        save_model_file_ = output_directory + experiment_id + "_model.pt"
        load_model_file_ = output_directory + load_model_file

        ##creating the network

        task_net = LSTMModel(batch_size=batch_size,
                             seq_len=window_size,
                             input_dim=input_dim_task_net,
                             n_layers=n_layers_task_net,
                             hidden_dim=hidden_dim_task_net,
                             output_dim=output_dim_task_net)

        task_encoder = LSTMModel(batch_size=batch_size,
                                 seq_len=task_size,
                                 input_dim=input_dim_task_encoder,
                                 n_layers=n_layers_task_encoder,
                                 hidden_dim=hidden_dim_encoder,
                                 output_dim=1)

        task_decoder = LSTMDecoder(batch_size=1,
                                   n_layers=n_layers_task_decoder,
                                   seq_len=task_size,
                                   output_dim=output_dim_task_decoder,
                                   hidden_dim=hidden_dim_encoder,
                                   latent_dim=hidden_dim_decoder,
                                   device=device)
        lmbd = Lambda(hidden_dim_encoder, hidden_dim_task_net)

        multimodal_learner = MultimodalLearner(task_net, task_encoder,
                                               task_decoder, lmbd,
                                               modulate_task_net)
        multimodal_learner.to(device)

        output_layer = nn.Linear(120, 1)
        output_layer.to(device)

        maml = l2l.algorithms.MAML(output_layer,
                                   lr=learning_rate,
                                   first_order=False)

        multimodal_learner.load_state_dict(torch.load(save_model_file_encoder))
        maml.load_state_dict(torch.load(save_model_file_))

        n_domains_in_test = np.max(test_data.file_idx) + 1

        error_list = []

        y_list = []

        for domain in range(n_domains_in_test):
            print("Domain:", domain)
            x_test = test_data.x
            y_test = test_data.y

            temp_train_data = SimpleDataset(
                x=np.concatenate([
                    x_test[np.concatenate([train_idx[domain],
                                           val_idx[domain]])][np.newaxis, :],
                    x_test[test_idx[domain]][np.newaxis, :]
                ]),
                y=np.concatenate([
                    y_test[np.concatenate([train_idx[domain],
                                           val_idx[domain]])][np.newaxis, :],
                    y_test[test_idx[domain]][np.newaxis, :]
                ]))

            total_tasks_test = len(test_data_ML)

            learner = maml.clone()  # Creates a clone of model
            learner.cuda()
            accum_error = 0.0
            accum_std = 0.0
            count = 0.0

            input_dim = test_data_ML.x.shape[-1]
            window_size = test_data_ML.x.shape[-2]
            output_dim = test_data_ML.y.shape[-1]

            task_id = 0

            output_layer = nn.Linear(120, 1)
            output_layer.load_state_dict(
                copy.deepcopy(maml.module.state_dict()))
            output_layer.to(device)

            x_spt, y_spt = temp_train_data[task_id]
            x_qry = temp_train_data.x[(task_id + 1)]
            y_qry = temp_train_data.y[(task_id + 1)]

            task = get_task_encoder_input(
                SimpleDataset(x=x_spt[-50:][np.newaxis, :],
                              y=y_spt[-50:][np.newaxis, :]))
            task = to_torch(task)

            if model_name == "FCN":
                x_qry = np.transpose(x_qry, [0, 2, 1])
                x_spt = np.transpose(x_spt, [0, 2, 1])

            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry = to_torch(x_qry)
            y_qry = to_torch(y_qry)

            opt2 = optim.SGD(list(output_layer.parameters()), lr=learning_rate)
            multimodal_learner.train()
            for step in range(adaptation_steps):

                step_size = 1
                error_accum = 0
                count = 0
                #model2.train()
                for idx in range(0, x_spt.shape[0], step_size):

                    x_spt_encoding, (vrae_loss, _, _) = multimodal_learner(
                        x_spt[idx:idx + step_size], task, output_encoding=True)
                    pred = output_layer(x_spt_encoding)
                    error_accum += mae(pred, y_spt[idx:idx + step_size])
                    count += 1

                opt2.zero_grad()
                error = error_accum / count
                error.backward()

                #learner.adapt(error)
                opt2.step()


            multimodal_learner.eval()
            step = max(1, x_qry.shape[0] // 255)
            error_accum = 0
            count = 0
            for idx in range(0, x_qry.shape[0], step):

                x_qry_encoding, (vrae_loss, _,
                                 _) = multimodal_learner(x_qry[idx:idx + step],
                                                         task,
                                                         output_encoding=True)
                pred = output_layer(x_qry_encoding)
                error = mae(pred, y_qry[idx:idx + step])

                accum_error += error.data
                accum_std += error.data**2
                count += 1

            error = accum_error / count

            y_list.append(y_qry.cpu().numpy())
            error_list.append(float(error.cpu().numpy()))
            print(np.mean(error_list))
            print(error_list)

            trials_loss_list.append(np.mean(error_list))

        print("mean:", np.mean(trials_loss_list))
        print("std:", np.std(trials_loss_list))
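
`split_idx_50_50` produces per-domain index splits for test-time adaptation.
A hedged sketch inferred from the name and usage (train and val indices are
concatenated for adaptation, the remainder is held out); the real logic may
differ:

def split_idx_50_50(file_idx):
    # Per domain: first quarter -> train, next quarter -> val (together 50%
    # for adaptation), second half -> held-out test.
    train_idx, val_idx, test_idx = [], [], []
    for domain in range(int(np.max(file_idx)) + 1):
        idx = np.where(np.asarray(file_idx) == domain)[0]
        quarter, half = len(idx) // 4, len(idx) // 2
        train_idx.append(idx[:quarter])
        val_idx.append(idx[quarter:half])
        test_idx.append(idx[half:])
    return train_idx, val_idx, test_idx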
def main(args):

    dataset_name = args.dataset
    model_name = args.model
    n_inner_iter = args.adaptation_steps
    batch_size = args.batch_size
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    is_test = args.is_test
    stopping_patience = args.stopping_patience
    epochs = args.epochs
    fast_lr = args.learning_rate
    slow_lr = args.meta_learning_rate

    first_order = False
    inner_loop_grad_clip = 20
    task_size = 50
    output_dim = 1

    horizon = 10

    meta_info = {
        "POLLUTION": [5, 50, 14],
        "HR": [32, 50, 13],
        "BATTERY": [20, 50, 3]
    }

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, task_size, input_dim = meta_info[dataset_name]

    output_directory = "output/"

    train_data_ML = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-ML.pickle", "rb"))
    test_data_ML = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))

    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MAML/" + str(
            trial) + "/"
        save_model_file_ = output_directory + save_model_file
        load_model_file_ = output_directory + load_model_file

        try:
            os.mkdir(output_directory)
        except OSError as error:
            print(error)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Learning rate :%f \n" % fast_lr)
            f.write("Meta-learning rate: %f \n" % slow_lr)
            f.write("Adaptation steps: %f \n" % n_inner_iter)
            f.write("\n")

        if model_name == "LSTM":
            model = LSTMModel(batch_size=batch_size,
                              seq_len=window_size,
                              input_dim=input_dim,
                              n_layers=2,
                              hidden_dim=120,
                              output_dim=output_dim)
        else:
            # Only the LSTM branch is implemented here; fail fast instead of
            # hitting a NameError when `model` is used below.
            raise NotImplementedError("model %s is not supported here" %
                                      model_name)

        optimizer = torch.optim.Adam(model.parameters(), lr=slow_lr)
        loss_func = mae

        torch.backends.cudnn.enabled = False

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        meta_learner = MetaLearner(model, optimizer, fast_lr, loss_func,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        total_tasks, task_size, window_size, input_dim = train_data_ML.x.shape

        early_stopping = EarlyStopping(patience=stopping_patience,
                                       model_file=save_model_file_,
                                       verbose=True)

        for _ in range(epochs):

            #train
            batch_idx = np.random.randint(0, total_tasks - 1, batch_size)
            x_spt, y_spt = train_data_ML[batch_idx]
            x_qry, y_qry = train_data_ML[batch_idx + 1]

            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry = to_torch(x_qry)
            y_qry = to_torch(y_qry)

            train_tasks = [
                Task(x_spt[i], y_spt[i]) for i in range(x_spt.shape[0])
            ]
            val_tasks = [
                Task(x_qry[i], y_qry[i]) for i in range(x_qry.shape[0])
            ]

            adapted_params = meta_learner.adapt(train_tasks)
            mean_loss = meta_learner.step(adapted_params,
                                          val_tasks,
                                          is_training=True)
            print(mean_loss)

            #test
            val_error = test(validation_data_ML, meta_learner, device)
            print(val_error)

            early_stopping(val_error, meta_learner)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        model.load_state_dict(torch.load(save_model_file_)["model_state_dict"])
        meta_learner = MetaLearner(model, optimizer, fast_lr, loss_func,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        validation_error = test(validation_data_ML, meta_learner, device)
        test_error = test(test_data_ML, meta_learner, device)

        with open(output_directory + "/results2.txt", "a+") as f:
            f.write("Test error: %f \n" % test_error)
            f.write("Validation error: %f \n" % validation_error)

        print(test_error)
        print(validation_error)
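
`LinearModel` is the small head meta-learned by the MetaLearner in several of
these examples. A hedged sketch consistent with its usage (constructed as
LinearModel(120, 1) and exposing a `.linear` attribute that the gradient
histogram logging relies on):

class LinearModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)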