Example no. 1
# Standard imports (assumed; the snippet omits its import block):
import logging

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

# Project-specific modules referenced below (module paths assumed from context):
# class_parser, utils, df (dataset factory), ts (task sampler), mf (model
# factory), experiment, and MetaLearingClassification.

def main():
    # Parse the command-line arguments once and reuse the result.
    p = class_parser.Parser()
    flags = p.parse_known_args()[0]
    total_seeds = len(flags.seed)
    rank = flags.rank
    all_args = vars(flags)
    print("All args = ", all_args)

    args = utils.get_run(all_args, rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'], args, "../results/", commit_changes=False, rank=0, seed=1)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args['classes'] = list(range(963))

    args['traj_classes'] = list(range(int(963 / 2), 963))

    dataset = df.DatasetFactory.get_dataset(args['dataset'], background=True, train=True, path=args["path"], all=True)
    dataset_test = df.DatasetFactory.get_dataset(args['dataset'], background=True, train=False, path=args["path"], all=True)

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test, batch_size=5,
                                                shuffle=True, num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset, batch_size=5,
                                                 shuffle=True, num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(args['dataset'], args['classes'], dataset, dataset_test)

    config = mf.ModelFactory.get_model("na", args['dataset'], output_dimension=1000)

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)


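    # Meta-training loop: sample a set of trajectory classes, build a support
    # trajectory and a query set from them, and take one meta-update step.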
    for step in range(args['steps']):

        t1 = np.random.choice(args['traj_classes'], args['tasks'], replace=False)
        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))
        d_rand_iterator = sampler.get_complete_iterator()
        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(d_traj_iterators, d_rand_iterator,
                                                               steps=args['update_step'], reset=not args['no_reset'])
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)
Example no. 2

import logging  # assumed; used by logging.getLogger below
import reg_parser  # assumed; the exact module path is project-specific

from experiment.experiment import experiment
from utils import utils
from torch import nn
from model import lstm
from copy import deepcopy

gamma = 0.9

logger = logging.getLogger('experiment')

p = reg_parser.Parser()
total_seeds = len(p.parse_known_args()[0].seed)
rank = p.parse_known_args()[0].run
all_args = vars(p.parse_known_args()[0])

args = utils.get_run(all_args, rank)

my_experiment = experiment(args["name"],
                           args,
                           args["output_dir"],
                           sql=True,
                           run=int(rank / total_seeds),
                           seed=total_seeds)

my_experiment.results["all_args"] = all_args
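# Create a SQL table keyed by (run, step) for logging per-step errors.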
my_experiment.make_table("error_table", {
    "run": 0,
    "step": 0,
    "error": 0.0
}, ("run", "step"))
my_experiment.make_table(
Example no. 3
# Assumed imports, as in the earlier examples: logging, random, torch,
# reg_parser, utils, experiment, and the project's Recurrent_Network.

def main():
    # Parse the command-line arguments once and reuse the result.
    p = reg_parser.Parser()
    flags = p.parse_known_args()[0]
    total_seeds = len(flags.seed)
    rank = flags.rank
    all_args = vars(flags)

    args = utils.get_run(all_args, rank)

    my_experiment = experiment(args["name"], args, args["output_dir"], commit_changes=False,
                               rank=int(rank / total_seeds),
                               seed=total_seeds)

    my_experiment.results["all_args"] = all_args

    logger = logging.getLogger('experiment')

    gradient_error_list = []
    gradient_alignment_list = []

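    # For each seed, compare the network's online gradient estimate (built up
    # via update_TH/accumulate_gradients) against the true gradient from full
    # backpropagation through time.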
    for seed in range(args["runs"]):
        utils.set_seed(args["seed"] + seed + seed*args["seed"])
        n = Recurrent_Network(50, args['columns'], args["width"],
                              args["sparsity"])
        error_grad_mc = 0

        rnn_state = torch.zeros(args['columns'])
        n.reset_TH()

        for ind in range(50):

            x = torch.bernoulli(torch.zeros(1, 50) + 0.5)

            _, _, grads = n.forward(x, rnn_state, grad=True, retain_graph=False, bptt=False)

            value_prediction, rnn_state, _ = n.forward(x, rnn_state, grad=False,
                                                       retain_graph=False, bptt=True)

            n.update_TH(grads)

            target_random = random.random() * 100 - 50
            real_error = 0.5 * (target_random - value_prediction) ** 2
            error_grad_mc += real_error

            n.accumulate_gradients(target_random, value_prediction, hidden_state=rnn_state)

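        # Ground-truth gradients of the accumulated error, computed by
        # ordinary backpropagation through the whole 50-step sequence.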
        grads = torch.autograd.grad(error_grad_mc, n.parameters())

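        # Per-parameter comparison: absolute difference and sign agreement
        # between the approximate and true gradients.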
        counter = 0
        total_sum = 0
        positive_sum = 0
        dif = 0

        for named, param in n.named_parameters():
            # if "prediction" in named:
            #     counter+=1
            #     continue
            # print(named)
            # print(grads[counter], n.grads[named])
            dif += torch.abs(n.grads[named] - grads[counter]).sum()
            positive = ((n.grads[named] * grads[counter]) > 1e-10).float().sum()
            total = positive + ((n.grads[named] * grads[counter]) < - 1e-10).float().sum()
            total_sum += total
            positive_sum += positive

            counter += 1

        logger.error("Difference = %s", (float(dif) / total_sum).item())
        gradient_error_list.append((float(dif) / total_sum).item())
        gradient_alignment_list.append(str(float(positive_sum) / float(total_sum)))
        logger.error("Grad alignment %s", str(float(positive_sum) / float(total_sum)))




        my_experiment.add_result("abs_error", str(gradient_error_list))
        my_experiment.add_result("alignment", str(gradient_alignment_list))

        my_experiment.store_json()
Example no. 4
# Assumed imports, as in Example no. 1, plus: import os.path as osp

def main():
    # Parse the command-line arguments once and reuse the result.
    p = class_parser.Parser()
    flags = p.parse_known_args()[0]
    total_seeds = len(flags.seed)
    rank = flags.rank
    all_args = vars(flags)
    print("All args = ", all_args)

    args = utils.get_run(all_args, rank)

    utils.set_seed(args["seed"])

    if args["log_root"]:
        log_root = osp.join("./results", args["log_root"]) + "/"
    else:
        log_root = osp.join("./results/")

    my_experiment = experiment(
        args["name"],
        args,
        log_root,
        commit_changes=False,
        rank=0,
        seed=args["seed"],
    )
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger("experiment")

    # Using first 963 classes of the omniglot as the meta-training set
    # args["classes"] = list(range(963))
    args["classes"] = list(range(args["num_classes"]))
    print("Using classes:", args["num_classes"])
    # logger.info("Using classes:", str(args["num_classes"]))

    # args["traj_classes"] = list(range(int(963 / 2), 963))

    if torch.cuda.is_available():
        device = torch.device("cuda")
        use_cuda = True
    else:
        device = torch.device("cpu")
        use_cuda = False
    dataset_spt = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        # all=True,
        # all=False,
        all=args["all"],
        prefetch_gpu=args["prefetch_gpu"],
        device=device,
        resize=args["resize"],
        augment=args["augment_spt"],
    )
    dataset_qry = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        # all=True,
        # all=False,
        all=args["all"],
        prefetch_gpu=args["prefetch_gpu"],
        device=device,
        resize=args["resize"],
        augment=args["augment_qry"],
    )
    dataset_test = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=False,
        path=args["path"],
        # all=True,
        # all=False,
        all=args["all"],
        resize=args["resize"],
        # augment=args["augment"],
    )

    logger.info(
        f"Support size: {len(dataset_spt)}, Query size: {len(dataset_qry)}, test size: {len(dataset_test)}"
    )
    # print(f"Support size: {len(dataset_spt)}, Query size: {len(dataset_qry)}, test size: {len(dataset_test)}")

    pin_memory = use_cuda
    if args["prefetch_gpu"]:
        num_workers = 0
        pin_memory = False
    else:
        num_workers = args["num_workers"]
    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=5,
        shuffle=True,
        num_workers=0,
        # pin_memory=pin_memory,
    )

    iterator_train = torch.utils.data.DataLoader(
        dataset_spt,
        batch_size=5,
        shuffle=True,
        num_workers=0,
        # pin_memory=pin_memory,
    )

    logger.info("Support sampler:")
    sampler_spt = ts.SamplerFactory.get_sampler(
        args["dataset"],
        args["classes"],
        dataset_spt,
        dataset_test,
        prefetch_gpu=args["prefetch_gpu"],
        use_cuda=use_cuda,
        num_workers=0,
    )
    logger.info("Query sampler:")
    sampler_qry = ts.SamplerFactory.get_sampler(
        args["dataset"],
        args["classes"],
        dataset_qry,
        dataset_test,
        prefetch_gpu=args["prefetch_gpu"],
        use_cuda=use_cuda,
        num_workers=0,
    )

    config = mf.ModelFactory.get_model(
        "na",
        args["dataset"],
        output_dimension=1000,
        resize=args["resize"],
    )

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(gpu_to_use))
        logger.info("Using gpu : %s", "cuda:" + str(gpu_to_use))
    else:
        device = torch.device("cpu")

    maml = MetaLearingClassification(args, config).to(device)

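    # Meta-training loop: for each step, sample a set of classes, draw support
    # and query trajectories for them, and take one meta-update step.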
    for step in range(args["steps"]):

        t1 = np.random.choice(args["classes"], args["tasks"], replace=False)

        d_traj_iterators_spt = []
        d_traj_iterators_qry = []
        for t in t1:
            d_traj_iterators_spt.append(sampler_spt.sample_task([t]))
            d_traj_iterators_qry.append(sampler_qry.sample_task([t]))

        d_rand_iterator = sampler_spt.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data_paper(
            d_traj_iterators_spt,
            d_traj_iterators_qry,
            d_rand_iterator,
            steps=args["update_step"],
            reset=not args["no_reset"],
        )
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = (
                x_spt.to(device),
                y_spt.to(device),
                x_qry.to(device),
                y_qry.to(device),
            )

        #
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 5:
            writer.add_scalar("/metatrain/train/accuracy", accs[-1], step)
            writer.add_scalar("/metatrain/train/loss", loss[-1], step)
            writer.add_scalar("/metatrain/train/accuracy0", accs[0], step)
            writer.add_scalar("/metatrain/train/loss0", loss[0], step)
            logger.info("step: %d \t training acc %s", step, str(accs))
            logger.info("step: %d \t training loss %s", step, str(loss))
        # Currently useless
        if (step % 300 == 3) or ((step + 1) == args["steps"]):
            torch.save(maml.net, my_experiment.path + "learner.model")
Example no. 5
# Assumed imports, as in the earlier examples: logging, numpy as np, torch,
# torch.nn.functional as F, reg_parser, utils, ts (task sampler), mf (model
# factory), experiment, and the project's MetaLearnerRegression /
# MetaLearnerRegressionCol classes.

def main():
    # Parse the command-line arguments once and reuse the result.
    p = reg_parser.Parser()
    flags = p.parse_known_args()[0]
    total_seeds = len(flags.seed)
    run = flags.run
    all_args = vars(flags)

    args = utils.get_run(all_args, run)

    my_experiment = experiment(args["name"],
                               args,
                               args["output_dir"],
                               sql=True,
                               run=int(run / total_seeds),
                               seed=total_seeds)

    my_experiment.results["all_args"] = all_args

    my_experiment.make_table("metrics", {
        "run": 0,
        "meta_loss": 0.0,
        "step": 0
    }, ("run", "step"))

    metrics_keys = ["run", "meta_loss", "step"]

    logger = logging.getLogger('experiment')
    tasks = list(range(400))

    sampler = ts.SamplerFactory.get_sampler("Sin",
                                            tasks,
                                            None,
                                            capacity=args["capacity"] + 1)
    model_config = mf.ModelFactory.get_model(args["model"],
                                             "Sin",
                                             input_dimension=args["capacity"] + 1,
                                             output_dimension=1,
                                             width=args["width"],
                                             cols=args["cols"])
    gpu_to_use = run % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')
    if args.get('update_rule') == "RTRL":
        logger.info("Columnar Net based gradient approximation...")
        metalearner = MetaLearnerRegressionCol(args,
                                               model_config,
                                               device=device).to(device)
    else:
        logger.info("BPTT update rule...")
        metalearner = MetaLearnerRegression(args, model_config,
                                            device=device).to(device)
    tmp = filter(lambda x: x.requires_grad, metalearner.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logger.info('Total trainable tensors: %d', num)

    running_meta_loss = 0
    adaptation_loss = 0
    loss_history = []
    metrics_list = []
    adaptation_loss_history = []
    adaptation_running_loss_history = []
    meta_steps_counter = 0
    LOG_INTERVAL = 2
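    # Outer meta-training loop over randomly sampled sine-regression tasks.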
    for step in range(args["epoch"]):
        if step % LOG_INTERVAL == 0:
            logger.debug("####\t STEP %d \t####", step)
        net = metalearner.net
        meta_steps_counter += 1
        t1 = np.random.choice(tasks, args["tasks"], replace=False)
        iterators = []

        for t in t1:
            iterators.append(sampler.sample_task([t]))

        x_traj_meta, y_traj_meta, x_rand_meta, y_rand_meta = utils.construct_set(
            iterators, sampler, steps=1)
        x_traj_meta = x_traj_meta.view(-1, 51)
        x_rand_meta = x_rand_meta.view(-1, 51)
        y_traj_meta = y_traj_meta.view(-1, 2)
        y_rand_meta = y_rand_meta.view(-1, 2)
        if torch.cuda.is_available():
            x_traj_meta, y_traj_meta = x_traj_meta.to(device), y_traj_meta.to(device)
            x_rand_meta, y_rand_meta = x_rand_meta.to(device), y_rand_meta.to(device)

        meta_loss = metalearner(x_traj_meta, y_traj_meta, x_rand_meta,
                                y_rand_meta)
        loss_history.append(meta_loss[-1].detach().cpu().item())

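        # Bias-corrected exponential moving average of the meta loss
        # (the same correction used in Adam).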
        running_meta_loss = running_meta_loss * 0.97 + 0.03 * meta_loss[-1].detach().cpu()
        running_meta_loss_fixed = running_meta_loss / (1 - 0.97 ** meta_steps_counter)
        metrics_list.append((run, running_meta_loss_fixed.item(), step))

        if step % LOG_INTERVAL == 0:
            if running_meta_loss > 0:
                logger.info("Running meta loss = %f",
                            running_meta_loss_fixed.item())

            with torch.no_grad():
                t1 = np.random.choice(tasks, args["tasks"], replace=False)

                iterators = []
                for t in t1:
                    iterators.append(sampler.sample_task([t]))

                x_traj, y_traj, x_rand, y_rand = utils.construct_set(iterators,
                                                                     sampler,
                                                                     steps=1)
                x_traj = x_traj.view(-1, 51)
                x_rand = x_rand.view(-1, 51)
                y_traj = y_traj.view(-1, 2)
                y_rand = y_rand.view(-1, 2)
                if torch.cuda.is_available():
                    x_traj, y_traj = x_traj.to(device), y_traj.to(device)
                    x_rand, y_rand = x_rand.to(device), y_rand.to(device)
                logits_select = []
                for i in range(len(x_rand)):
                    l, _, _ = net.forward_col(x_rand[i], vars=None, grad=False)
                    logits_select.append(l)

                logits = torch.stack(logits_select).unsqueeze(1)

                current_adaptation_loss = F.mse_loss(logits, y_rand[:, 0].unsqueeze(1))
                adaptation_loss_history.append(current_adaptation_loss.detach().item())
                adaptation_loss = adaptation_loss * 0.97 + current_adaptation_loss.detach().cpu().item() * 0.03
                adaptation_loss_fixed = adaptation_loss / (1 - 0.97 ** (step + 1))
                adaptation_running_loss_history.append(adaptation_loss_fixed)

                logger.info("Adaptation loss = %f", current_adaptation_loss)

                if step % LOG_INTERVAL == 0:
                    logger.info("Running adaptation loss = %f",
                                adaptation_loss_fixed)

        if (step + 1) % (LOG_INTERVAL * 500) == 0:
            if not args["no_save"]:
                torch.save(metalearner.net, my_experiment.path + "net.model")
            dict_names = {}
            for (name, param) in metalearner.net.named_parameters():
                dict_names[name] = param.adaptation

            my_experiment.insert_values("metrics", metrics_keys, metrics_list)
            metrics_list = []

            my_experiment.add_result("Layers meta values", dict_names)
            my_experiment.add_result("Meta loss", loss_history)
            my_experiment.add_result("Adaptation loss",
                                     adaptation_loss_history)
            my_experiment.add_result("Running adaption loss",
                                     adaptation_running_loss_history)
            my_experiment.store_json()
Example no. 6
# Assumed imports, as in the earlier examples: logging, os, json, ast,
# numpy as np, seaborn as sns, matplotlib.pyplot as plt, reg_parser, utils,
# and experiment as exp.

def main():
    # Parse the command-line arguments once and reuse the result.
    p = reg_parser.Parser()
    flags = p.parse_known_args()[0]
    total_seeds = len(flags.seed)
    rank = flags.rank
    all_args = vars(flags)

    args = utils.get_run(all_args, rank)

    my_experiment = exp.experiment(args["name"],
                                   args,
                                   args["output_dir"],
                                   commit_changes=False,
                                   rank=int(rank / total_seeds),
                                   seed=total_seeds)

    my_experiment.results["all_args"] = all_args

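    # Walk the results directory and collect the chosen metric from every
    # run's metadata.json.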
    all_experiments = {}  # metric lists keyed by each run's parameter settings
    param_dict = {}
    results_dir = args["path"]
    for experiment in os.listdir(results_dir):
        if "DS_St" not in experiment:
            # print(experiment)
            for run in os.listdir(os.path.join(results_dir, experiment)):
                if "DS_St" not in run:

                    print(
                        os.path.join(results_dir, experiment, run,
                                     "metadata.json"))
                    try:
                        path = os.path.join(results_dir, experiment, run,
                                            "metadata.json")
                        with open(path) as json_file:

                            data_temp = json.load(json_file)
                            experiment_name = str(data_temp['params'])
                            param_dict[experiment_name] = data_temp['params']
                            if experiment_name in all_experiments:
                                all_experiments[experiment_name].append(
                                    data_temp['results'][args["metric"]])
                                # print(data_temp['results']['Real_Error_list'])
                            else:

                                all_experiments[experiment_name] = [
                                    data_temp['results'][args["metric"]]
                                ]
                                # print(data_temp['results']['Real_Error_list'])

                    except Exception:
                        # Skip runs with missing or malformed metadata.
                        pass

    sns.set(style="whitegrid")
    sns.set_context("paper", font_scale=0.4, rc={"lines.linewidth": 1.0})

    truncation_values = []
    for experiment_name in all_experiments:
        experiment_params = param_dict[experiment_name]
        temp = experiment_params["truncation"]
        if temp not in truncation_values:
            truncation_values.append(temp)

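    # One curve per truncation value (for runs with width 50 and 20 columns):
    # mean metric vs. sparsity, with standard-error bands.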
    for a in truncation_values:
        counter = 0
        x = []
        y = []
        error = []
        for experiment_name in all_experiments:

            experiment_params = param_dict[experiment_name]

            if experiment_params["width"] == 50 and experiment_params[
                    "columns"] == 20 and experiment_params["truncation"] == a:

                for list_of_vals in all_experiments[experiment_name]:
                    # print(list_of_vals)
                    # print(d.strip("[").strip("]").strip("\,"))
                    y_pred = ast.literal_eval(list_of_vals)
                    y_pred = [float(x) for x in y_pred]

                    y_pred_mean = np.mean(y_pred)
                    y_pred_error = np.std(y_pred) / np.sqrt(len(y_pred))
                    x_cur = experiment_params["sparsity"]
                    x.append(x_cur)
                    y.append(y_pred_mean)
                    # print(x)
                    error.append(y_pred_error)
                    # d_sparse = []
                    # running_sum = y_pred[0]
                    # for number, value_in in enumerate(y_pred):
                    #     running_sum = running_sum * 0.96 + value_in * 0.04
                    #     if number % 50 == 0:
                    #         d_sparse.append(running_sum)
                    # print(d_new)

        x = np.array(x)
        if args["log"]:
            x = np.log10(x)
        y = np.array(y)
        arg_sort = np.argsort(x)
        x = np.array([x[p] for p in arg_sort])
        y = np.array([y[p] for p in arg_sort])
        error = np.array([error[p] for p in arg_sort])

        plt.fill_between(x, y - error, y + error, alpha=0.4)
        plt.plot(x, y)
        plt.ylim(0.4, 1)
        # plt.xlim(0, 10)

        plt.tight_layout()
        print(my_experiment.path + "result.pdf")
    plt.legend(truncation_values)
    plt.savefig(my_experiment.path + "result.pdf", format="pdf")
Example no. 7
# Assumed imports, as in the earlier examples, plus: from pprint import pprint

def main():
    # Parse the command-line arguments once and reuse the result.
    p = reg_parser.Parser()
    flags = p.parse_known_args()[0]
    total_seeds = len(flags.seed)
    rank = flags.rank
    all_args = vars(flags)
    args = utils.get_run(all_args, rank)
    utils.set_seed(args["seed"])

    my_experiment = experiment(args["name"],
                               args,
                               "../results/",
                               commit_changes=False,
                               rank=int(rank / total_seeds),
                               seed=total_seeds)

    my_experiment.results["all_args"] = all_args
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')
    pprint(args)

    tasks = list(range(400))

    sampler = ts.SamplerFactory.get_sampler("Sin",
                                            tasks,
                                            None,
                                            capacity=args["capacity"] + 1)

    model_config = mf.ModelFactory.get_model(args["model"],
                                             "Sin",
                                             input_dimension=args["capacity"] + 1,
                                             output_dimension=1,
                                             width=args["width"])
    context_backbone_config = None
    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    metalearner = MetaLearnerRegression(args, model_config,
                                        context_backbone_config).to(device)
    tmp = filter(lambda x: x.requires_grad, metalearner.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logger.info('Total trainable tensors: %d', num)
    #
    running_meta_loss = 0
    adaptation_loss = 0
    loss_history = []
    adaptation_loss_history = []
    adaptation_running_loss_history = []
    meta_steps_counter = 0
    LOG_INTERVAL = 50
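    # Outer meta-training loop over randomly sampled sine-regression tasks.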
    for step in range(args["epoch"]):
        if step % LOG_INTERVAL == 0:
            logger.debug("####\t STEP %d \t####", step)
        net = metalearner.net
        meta_steps_counter += 1
        t1 = np.random.choice(tasks, args["tasks"], replace=False)
        iterators = []

        for t in t1:
            iterators.append(sampler.sample_task([t]))

        x_traj_meta, y_traj_meta, x_rand_meta, y_rand_meta = utils.construct_set(
            iterators, sampler, steps=args["update_step"])
        x_traj_meta = x_traj_meta.view(-1, 1, 51)
        y_traj_meta = y_traj_meta.view(-1, 1, 2)

        if torch.cuda.is_available():
            x_traj_meta, y_traj_meta = x_traj_meta.to(device), y_traj_meta.to(device)
            x_rand_meta, y_rand_meta = x_rand_meta.to(device), y_rand_meta.to(device)

        meta_loss = metalearner(x_traj_meta, y_traj_meta, x_rand_meta,
                                y_rand_meta)
        loss_history.append(meta_loss[-1].detach().cpu().item())

        running_meta_loss = running_meta_loss * 0.97 + 0.03 * meta_loss[-1].detach().cpu()
        running_meta_loss_fixed = running_meta_loss / (1 - 0.97 ** meta_steps_counter)
        writer.add_scalar('/metatrain/train/accuracy',
                          meta_loss[-1].detach().cpu(), meta_steps_counter)
        writer.add_scalar('/metatrain/train/runningaccuracy',
                          running_meta_loss_fixed, meta_steps_counter)

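        # Periodic no-grad evaluation on freshly sampled tasks.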
        if step % LOG_INTERVAL == 0:
            if running_meta_loss > 0:
                logger.info("Running meta loss = %f",
                            running_meta_loss_fixed.item())

            with torch.no_grad():
                t1 = np.random.choice(tasks, args["tasks"], replace=False)

                iterators = []
                for t in t1:
                    iterators.append(sampler.sample_task([t]))

                x_traj, y_traj, x_rand, y_rand = utils.construct_set(
                    iterators, sampler, steps=args["update_step"])
                x_traj, y_traj = x_traj.view(-1, 1, 51), y_traj.view(-1, 1, 2)
                if torch.cuda.is_available():
                    x_traj, y_traj = x_traj.to(device), y_traj.to(device)
                    x_rand, y_rand = x_rand.to(device), y_rand.to(device)

                logits = net(x_rand[0], vars=None)
                logits_select = []
                assert y_rand[0, :, 1].sum() == 0
                for no, val in enumerate(y_rand[0, :, 1].long()):
                    logits_select.append(logits[no, val])
                logits = torch.stack(logits_select).unsqueeze(1)

                current_adaptation_loss = F.mse_loss(logits, y_rand[0, :, 0].unsqueeze(1))
                adaptation_loss_history.append(current_adaptation_loss.detach().item())
                adaptation_loss = adaptation_loss * 0.97 + current_adaptation_loss.detach().cpu().item() * 0.03
                adaptation_loss_fixed = adaptation_loss / (1 - 0.97 ** (step + 1))
                adaptation_running_loss_history.append(adaptation_loss_fixed)

                logger.info("Adaptation loss = %f", current_adaptation_loss)

                if step % LOG_INTERVAL == 0:
                    logger.info("Running adaptation loss = %f",
                                adaptation_loss_fixed)
                writer.add_scalar('/learn/test/adaptation_loss',
                                  current_adaptation_loss, step)

        if (step + 1) % (LOG_INTERVAL * 500) == 0:
            if not args["no_save"]:
                torch.save(metalearner.net, my_experiment.path + "net.model")
            dict_names = {}
            for (name, param) in metalearner.net.named_parameters():
                dict_names[name] = param.adaptation

            my_experiment.add_result("Layers meta values", dict_names)
            my_experiment.add_result("Meta loss", loss_history)
            my_experiment.add_result("Adaptation loss",
                                     adaptation_loss_history)
            my_experiment.add_result("Running adaption loss",
                                     adaptation_running_loss_history)
            my_experiment.store_json()