Пример #1
0
def main(args):
    """Supervised pre-training baseline on the mini-ImageNet training split.

    Trains ``learner.Learner`` built from the ``"na"`` model config with Adam
    and cross-entropy for ``args.epoch`` epochs. The learning rate starts at
    ``args.lr`` and is dropped to 1e-5 once, at the start of epoch 50. The
    whole model is saved to
    ``<experiment path>/baseline_pretraining_imagenet.net`` after every epoch.

    Args:
        args: parsed command-line namespace; reads ``seed``, ``name``,
            ``dataset``, ``dataset_path``, ``lr`` and ``epoch`` and sets
            ``args.classes``.
    """
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    np.random.seed(args.seed)

    my_experiment = experiment(args.name, args, "./results/")

    args.classes = list(range(64))

    dataset = imgnet.MiniImagenet(args.dataset_path, mode='train')
    dataset_test = imgnet.MiniImagenet(args.dataset_path, mode='test')

    # Test loader, kept for the per-epoch test evaluation (currently disabled).
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator = torch.utils.data.DataLoader(dataset,
                                           batch_size=128,
                                           shuffle=True,
                                           num_workers=1)

    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args.dataset)
    maml = learner.Learner(config).to(device)

    # Learning-rate schedule: args.lr until lr_drop_epoch, late_lr afterwards.
    lr_drop_epoch = 50
    late_lr = 0.00001
    opt = torch.optim.Adam(maml.parameters(), lr=args.lr)

    for e in range(args.epoch):
        if e == lr_drop_epoch:
            # Rebuild the optimizer exactly once. Doing this inside the batch
            # loop recreated Adam (and discarded its moment estimates) on
            # every batch of this epoch.
            logger.info("Changing LR from %f to %f", args.lr, late_lr)
            opt = torch.optim.Adam(maml.parameters(), lr=late_lr)

        correct = 0
        for img, y in tqdm(iterator):
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)
            # Feature-only forward pass; its L1 penalty is not part of the
            # objective. NOTE(review): the call is kept in case it has side
            # effects (e.g. normalisation statistics) -- confirm before removal.
            maml(img, feature=True)

            opt.zero_grad()
            loss = F.cross_entropy(pred, y)
            loss.backward()
            opt.step()
            correct += (pred.argmax(1) == y).sum().float() / len(y)
        # Mean of per-batch training accuracies over the epoch.
        logger.info("Accuracy at epoch %d = %s", e,
                    str(float(correct / len(iterator))))

        torch.save(maml,
                   my_experiment.path + "baseline_pretraining_imagenet.net")
def main(args):
    """Meta-train a classifier (optionally with plasticity / feedback
    pathways) and periodically evaluate it as a plain classifier.

    Args:
        args: parsed command-line namespace. Besides the experiment fields
            (seed, name, commit, dataset, imagenet_path, steps, tasks,
            update_step, iid, ...), it selects the learner variant
            (``oja`` / ``hebb``), whether to resume from ``args.model``
            (``from_saved``), how feedback / plasticity parameters are
            (re)initialised, and which layers are frozen
            (``rln`` / ``rln_end``).

    Side effects:
        Writes TensorBoard scalars under ``<experiment path>/tensorboard`` and
        saves ``maml.net`` to ``<experiment path>/learner.model`` every 300
        steps (starting at step 0).
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    my_experiment = experiment(args.name,
                               args,
                               "/data5/jlindsey/continual/results",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    args.classes = list(range(963))

    print('dataset', args.dataset, args.dataset == "imagenet")

    if args.dataset != "imagenet":
        dataset = df.DatasetFactory.get_dataset(args.dataset,
                                                background=True,
                                                train=True,
                                                all=True)
        dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                     background=True,
                                                     train=False,
                                                     all=True)
    else:
        # Only class ids 0..63 are sampled as tasks for ImageNet.
        args.classes = list(range(64))
        dataset = imgnet.MiniImagenet(args.imagenet_path, mode='train')
        dataset_test = imgnet.MiniImagenet(args.imagenet_path, mode='test')

    # Loaders used only by the periodic plain-classifier evaluation below.
    eval_batch_size = 5
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=eval_batch_size,
                                                shuffle=True,
                                                num_workers=1)
    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=eval_batch_size,
                                                 shuffle=True,
                                                 num_workers=1)

    # Approximate sizes: the last batch may hold fewer than eval_batch_size.
    logger.info("Train set length = %d", len(iterator_train) * eval_batch_size)
    logger.info("Test set length = %d", len(iterator_test) * eval_batch_size)
    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model(
        args.model_type,
        args.dataset,
        width=args.width,
        num_extra_dense_layers=args.num_extra_dense_layers)

    # Every tensor created below goes to this device (instead of a hard-coded
    # .cuda()) so the CPU fallback chosen here actually works.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.oja or args.hebb:
        maml = OjaMetaLearingClassification(args, config).to(device)
    else:
        print('starting up')
        maml = MetaLearingClassification(args, config).to(device)

    if args.from_saved:
        # map_location allows a GPU-saved checkpoint to be resumed on CPU.
        maml.net = torch.load(args.model, map_location=device)
        if args.use_derivative:
            maml.net.use_derivative = True
        maml.net.optimize_out = args.optimize_out
        if maml.net.optimize_out:
            # NOTE(review): when optimize_out is set, another strength
            # variable is appended again unconditionally after this branch,
            # so a resumed model gains two new entries. Confirm intended.
            maml.net.feedback_strength_vars.append(
                torch.nn.Parameter(maml.net.init_feedback_strength *
                                   torch.ones(1, device=device)))

        if args.reset_feedback_strength:
            for fv in maml.net.feedback_strength_vars:
                w = nn.Parameter(torch.ones_like(fv) * args.feedback_strength)
                fv.data = w

        if args.reset_feedback_vars:

            print('howdy', maml.net.num_feedback_layers)

            # Rebuild every feedback / plasticity container from scratch by
            # walking the layer config in order.
            maml.net.feedback_vars = nn.ParameterList()
            maml.net.feedback_vars_bundled = []
            maml.net.vars_plasticity = nn.ParameterList()
            maml.net.plasticity = nn.ParameterList()
            maml.net.neuron_plasticity = nn.ParameterList()
            maml.net.layer_plasticity = nn.ParameterList()

            def _feedback_pair(out_dim, in_dim):
                """Return a Kaiming-initialised feedback weight and zero bias."""
                feedback_w = nn.Parameter(
                    torch.ones([out_dim, in_dim], device=device))
                feedback_b = nn.Parameter(torch.zeros(out_dim, device=device))
                torch.nn.init.kaiming_normal_(feedback_w)
                return feedback_w, feedback_b

            # NOTE(review): 84 is presumably the input image side length --
            # confirm for datasets other than mini-ImageNet.
            cur_width = 84
            num_outputs = maml.net.config[-1][1][0]
            last_fl = maml.net.num_feedback_layers - 1
            for i, (name, param) in enumerate(maml.net.config):
                print('yo', i, name, param)
                if name == 'conv2d':
                    print('in conv2d')
                    stride = param[4]
                    padding = param[5]
                    # Output side length floor((W + 2p - param[3]) / s) + 1,
                    # treating param[3] as the kernel size.
                    cur_width = (cur_width + 2 * padding - param[3] +
                                 stride) // stride

                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(*param[:4], device=device)))
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(param[0], device=device)))
                    maml.net.plasticity.append(
                        nn.Parameter(
                            maml.net.init_plasticity *
                            torch.ones(param[0],
                                       param[1] * param[2] * param[3],
                                       device=device)))  # not implemented
                    maml.net.neuron_plasticity.append(
                        nn.Parameter(torch.zeros(
                            1, device=device)))  # not implemented
                    maml.net.layer_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(1, device=device))
                    )  # not implemented

                    for fl in range(maml.net.num_feedback_layers):
                        print('doing fl')
                        in_dim = num_outputs if fl == 0 else maml.net.width
                        out_dim = (param[0] * cur_width * cur_width
                                   if fl == last_fl else maml.net.width)
                        feedback_w, feedback_b = _feedback_pair(out_dim, in_dim)
                        print('adding')
                        maml.net.feedback_vars.append(feedback_w)
                        maml.net.feedback_vars.append(feedback_b)

                    # Conv feedback is not bundled: two placeholder entries per
                    # layer (weight, bias), mirroring the linear branch.
                    maml.net.feedback_vars_bundled.append(
                        nn.Parameter(torch.zeros(1)))
                    maml.net.feedback_vars_bundled.append(
                        nn.Parameter(torch.zeros(1)))

                elif name == 'linear':
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(*param, device=device)))
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(param[0], device=device)))
                    maml.net.plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(*param, device=device)))
                    maml.net.neuron_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(param[0], device=device)))
                    maml.net.layer_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(1, device=device)))

                    feedback_var = []
                    for fl in range(maml.net.num_feedback_layers):
                        in_dim = num_outputs if fl == 0 else maml.net.width
                        out_dim = param[0] if fl == last_fl else maml.net.width
                        feedback_w, feedback_b = _feedback_pair(out_dim, in_dim)
                        feedback_var.append((feedback_w, feedback_b))
                        maml.net.feedback_vars.append(feedback_w)
                        maml.net.feedback_vars.append(feedback_b)
                    maml.net.feedback_vars_bundled.append(feedback_var)
                    maml.net.feedback_vars_bundled.append(
                        None)  # bias feedback -- not implemented

        maml.init_stuff(args)

    maml.net.optimize_out = args.optimize_out
    if maml.net.optimize_out:
        maml.net.feedback_strength_vars.append(
            torch.nn.Parameter(maml.net.init_feedback_strength *
                               torch.ones(1, device=device)))
    # Author's note: this section (down to maml.init_opt()) used to be nested
    # under `if args.from_saved:`. If results regress, try re-indenting it.
    if args.zero_non_output_plasticity:
        # Zero every vars_plasticity entry except the last two.
        for index in range(len(maml.net.vars_plasticity) - 2):
            maml.net.vars_plasticity[index] = torch.nn.Parameter(
                maml.net.vars_plasticity[index] * 0)
        if args.oja or args.hebb:
            for index in range(len(maml.net.plasticity) - 1):
                if args.plasticity_rank1:
                    maml.net.plasticity[index] = torch.nn.Parameter(
                        torch.zeros(1, device=device))
                else:
                    maml.net.plasticity[index] = torch.nn.Parameter(
                        maml.net.plasticity[index] * 0)
                    maml.net.layer_plasticity[index] = torch.nn.Parameter(
                        maml.net.layer_plasticity[index] * 0)
                    maml.net.neuron_plasticity[index] = torch.nn.Parameter(
                        maml.net.neuron_plasticity[index] * 0)

    if args.zero_all_plasticity:
        print('zeroing plasticity')
        for index in range(len(maml.net.vars_plasticity)):
            maml.net.vars_plasticity[index] = torch.nn.Parameter(
                maml.net.vars_plasticity[index] * 0)
        for index in range(len(maml.net.plasticity)):
            if args.plasticity_rank1:
                maml.net.plasticity[index] = torch.nn.Parameter(
                    torch.zeros(1, device=device))
            else:
                maml.net.plasticity[index] = torch.nn.Parameter(
                    maml.net.plasticity[index] * 0)
                maml.net.layer_plasticity[index] = torch.nn.Parameter(
                    maml.net.layer_plasticity[index] * 0)
                maml.net.neuron_plasticity[index] = torch.nn.Parameter(
                    maml.net.neuron_plasticity[index] * 0)

    print('heyy', maml.net.feedback_vars)
    maml.init_opt()
    for _, param in maml.named_parameters():
        param.learn = True
    for _, param in maml.net.named_parameters():
        param.learn = True

    if args.freeze_out_plasticity:
        maml.net.plasticity[-1].requires_grad = False

    # NOTE(review): assumes 6 + 2 + num_extra_dense_layers layers, each with
    # two entries in net.vars -- confirm against the model config.
    total_ff_vars = 2 * (6 + 2 + args.num_extra_dense_layers)
    frozen_layers = ["net.vars." + str(k) for k in range(args.rln * 2)]
    frozen_layers += [
        "net.vars." + str(total_ff_vars - 1 - k)
        for k in range(args.rln_end * 2)
    ]
    for name, param in maml.named_parameters():
        if name in frozen_layers:
            logger.info("RLN layer %s", str(name))
            param.learn = False

    for name, param in maml.named_parameters():
        if param.learn:
            logger.info("TLN layer = %s", name)

    for step in range(args.steps):
        t1 = np.random.choice(args.classes, args.tasks, replace=False)

        d_traj_iterators = [sampler.sample_task([t]) for t in t1]
        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            iid=args.iid)

        # Relabel the sampled classes with a random permutation of
        # 0..tasks-1, numbering classes by first appearance in the support
        # labels. Every query label must also occur in the support labels.
        perm = np.random.permutation(args.tasks)

        old = []
        for i in range(y_spt.size()[0]):
            num = int(y_spt[i].cpu().numpy())
            if num not in old:
                old.append(num)
            y_spt[i] = torch.tensor(perm[old.index(num)])

        for i in range(y_qry.size()[1]):
            num = int(y_qry[0][i].cpu().numpy())
            y_qry[0][i] = torch.tensor(perm[old.index(num)])

        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        accs, _ = maml(x_spt, y_spt, x_qry, y_qry)

        writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
        logger.info('step: %d \t training acc %s', step, str(accs))

        if step % 300 == 0:
            torch.save(maml.net, my_experiment.path + "learner.model")

            correct = 0
            with torch.no_grad():
                for img, target in iterator_test:
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct += torch.eq(pred_q, target).sum().item() / len(img)
            writer.add_scalar('/metatrain/test/classifier/accuracy',
                              correct / len(iterator_test), step)
            logger.info("Test Accuracy = %s",
                        str(correct / len(iterator_test)))

            correct = 0
            with torch.no_grad():
                for img, target in iterator_train:
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)
                    pred_q = logits_q.argmax(dim=1)
                    correct += torch.eq(pred_q, target).sum().item() / len(img)

            logger.info("Train Accuracy = %s",
                        str(correct / len(iterator_train)))
            writer.add_scalar('/metatrain/train/classifier/accuracy',
                              correct / len(iterator_train), step)
Пример #3
0
def main(args):
    """Evaluate a saved model by continual fine-tuning on mini-ImageNet test
    classes, sweeping learning rates.

    For every class count in ``args.schedule`` and each of ``args.runs`` runs:
    sample that many of the first 20 test classes, fine-tune a fresh copy of
    ``args.model`` (or a scratch ``learner.Learner`` when ``args.scratch``) on
    the per-sample stream for ``args.epoch`` epochs, optionally with reservoir
    replay of size ``args.memory``. Then record the best test accuracy and the
    LR that achieved it. Results are logged, written to TensorBoard and stored
    in the experiment JSON.

    Args:
        args: parsed command-line namespace (seed, name, commit, rln,
            schedule, runs, dataset, dataset_path, test, iid, memory, model,
            scratch, reset, no_freeze, epoch).
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)

    # Parameters vars.0 .. vars.(2*rln - 1) are excluded from fine-tuning.
    frozen_layers = ["vars." + str(k) for k in range(args.rln * 2)]
    logger.info("Frozen layers = %s", " ".join(frozen_layers))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _evaluate(model, loader):
        """Return the mean per-batch top-1 accuracy of ``model`` on ``loader``.

        Runs under ``torch.no_grad()`` so evaluation never builds an autograd
        graph.
        """
        correct = 0
        with torch.no_grad():
            for img, target in loader:
                img = img.to(device)
                target = target.to(device)
                logits_q = model(img,
                                 vars=None,
                                 bn_training=False,
                                 feature=False)
                pred_q = logits_q.argmax(dim=1)
                correct += torch.eq(pred_q, target).sum().item() / len(img)
        return correct / len(loader)

    final_results_all = []
    for tot_class in args.schedule:
        lr_list = [
            0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001,
            0.000003
        ]
        for aoo in range(0, args.runs):

            keep = np.random.choice(list(range(20)), tot_class, replace=False)

            dataset = imgnet.MiniImagenet(args.dataset_path,
                                          mode='test',
                                          elem_per_class=30,
                                          classes=keep,
                                          seed=aoo)

            dataset_test = imgnet.MiniImagenet(args.dataset_path,
                                               mode='test',
                                               elem_per_class=30,
                                               test=args.test,
                                               classes=keep,
                                               seed=aoo)

            # Evaluation loader, plus a batch-size-1 stream for fine-tuning
            # (shuffled only when args.iid).
            iterator = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=128,
                                                   shuffle=True,
                                                   num_workers=1)
            iterator_sorted = torch.utils.data.DataLoader(dataset,
                                                          batch_size=1,
                                                          shuffle=args.iid,
                                                          num_workers=1)

            print(args)

            results_mem_size = {}

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:

                    print(lr)
                    # Reload a fresh copy of the model for every LR tried.
                    maml = torch.load(args.model, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model("na", args.dataset)
                        maml = learner.Learner(config)

                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        if name in frozen_layers:
                            param.learn = False
                            continue
                        param.learn = True
                        if args.reset:
                            # Re-initialise: Kaiming for weight matrices,
                            # zeros for 1-D parameters.
                            if len(param.shape) > 1:
                                w = nn.Parameter(torch.ones_like(param))
                                torch.nn.init.kaiming_normal_(w)
                            else:
                                w = nn.Parameter(torch.zeros_like(param))
                            param.data = w

                    # Re-initialise the last weight / bias pair.
                    # NOTE(review): indexing parameters() relies on the project
                    # Learner returning an indexable list; nn.Module's default
                    # parameters() is a generator -- confirm.
                    torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    maml.parameters()[-1].data = w

                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")
                        if n == "vars_14":
                            w = nn.Parameter(torch.ones_like(a))
                            torch.nn.init.kaiming_normal_(w)
                            a.data = w
                        if n == "vars_15":
                            w = nn.Parameter(torch.zeros_like(a))
                            a.data = w

                    logger.info("Pre-epoch accuracy %s",
                                str(_evaluate(maml, iterator)))

                    filter_list = [
                        "vars.0", "vars.1", "vars.2", "vars.3", "vars.4",
                        "vars.5"
                    ]
                    logger.info("Filter list = %s", ",".join(filter_list))

                    list_of_params = [p for p in maml.parameters() if p.learn]
                    list_of_names = [(n, p)
                                     for n, p in maml.named_parameters()
                                     if p.learn]
                    if args.scratch or args.no_freeze:
                        print("Empty filter list")
                        list_of_params = maml.parameters()
                    for name, _ in list_of_names:
                        logger.info("Unfrozen layer = %s", str(name))
                    opt = torch.optim.Adam(list_of_params, lr=lr)
                    res_sampler = rep.ReservoirSampler(mem_size)
                    for _ in range(0, args.epoch):
                        for img, y in iterator_sorted:
                            if mem_size > 0:
                                # Mix the incoming sample with 8 replayed ones.
                                res_sampler.update_buffer(zip(img, y))
                                res_sampler.update_observations(len(img))
                                img = img.to(device)
                                y = y.to(device)
                                img2, y2 = res_sampler.sample_buffer(8)
                                img2 = img2.to(device)
                                y2 = y2.to(device)

                                img = torch.cat([img, img2], dim=0)
                                y = torch.cat([y, y2], dim=0)
                            else:
                                img = img.to(device)
                                y = y.to(device)

                            pred = maml(img)
                            opt.zero_grad()
                            loss = F.cross_entropy(pred, y)
                            loss.backward()
                            opt.step()

                    logger.info("Result after one epoch for LR = %f", lr)
                    acc = _evaluate(maml, iterator)
                    logger.info(str(acc))
                    if acc > max_acc:
                        max_acc = acc
                        max_lr = lr

                # Later runs for this class count re-test only this sweep's
                # best LR.
                lr_list = [max_lr]
                results_mem_size[mem_size] = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc,
                                  tot_class)
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            print("FINAL RESULTS = ", final_results_all)
    writer.close()
Пример #4
0
def main(args):
    """Meta-train ``MetaLearingClassification`` on mini-ImageNet.

    Tasks are drawn from class ids 0..63 (the only ids placed in
    ``args.classes``). Meta-training accuracies are logged every 40 steps.
    Every 300 steps ``utils.log_accuracy`` runs on the test loader and then
    on the training loader.

    Args:
        args: parsed command-line namespace (seed, name, commit, dataset,
            dataset_path, rln, steps, tasks, update_step, no_reset).
    """
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # Meta-training class pool: ids 0..63.
    args.classes = list(range(64))

    train_set = imgnet.MiniImagenet(args.dataset_path, mode='train')
    test_set = imgnet.MiniImagenet(args.dataset_path, mode='test')

    def make_eval_loader(ds):
        # Small shuffled batches, used only for the periodic accuracy checks.
        return torch.utils.data.DataLoader(ds,
                                           batch_size=5,
                                           shuffle=True,
                                           num_workers=1)

    test_loader = make_eval_loader(test_set)
    train_loader = make_eval_loader(train_set)

    task_sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                                 train_set)

    model_config = mf.ModelFactory.get_model("na", "imagenet")

    device = (torch.device('cuda')
              if torch.cuda.is_available() else torch.device('cpu'))

    maml = MetaLearingClassification(args, model_config).to(device)
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):
        chosen = np.random.choice(args.classes, args.tasks, replace=False)
        traj_iters = [task_sampler.sample_task([cls]) for cls in chosen]
        rand_iter = task_sampler.get_complete_iterator()

        batch = maml.sample_training_data(traj_iters,
                                          rand_iter,
                                          steps=args.update_step,
                                          reset=not args.no_reset)
        if torch.cuda.is_available():
            batch = tuple(tensor.cuda() for tensor in batch)
        x_spt, y_spt, x_qry, y_qry = batch

        accs, _ = maml(x_spt, y_spt, x_qry, y_qry)

        # Sanity-check logging on steps 39, 79, ... and 299, 599, ...
        if (step + 1) % 40 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if (step + 1) % 300 == 0:
            for loader in (test_loader, train_loader):
                utils.log_accuracy(maml, my_experiment, loader, device,
                                   writer, step)
def _latest_model_path(prefix):
    """Return the path of ``learner.model`` in the newest ``<prefix>_<k>`` directory.

    Directories are probed as ``<prefix>_0``, ``<prefix>_1``, ... and the last
    existing one before the first gap is used.

    Raises:
        FileNotFoundError: if ``<prefix>_0`` does not exist.
    """
    ver = 0
    while os.path.exists(prefix + "_" + str(ver)):
        ver += 1
    if ver == 0:
        # Without this check the path would silently become "<prefix>_-1/...".
        raise FileNotFoundError("No trained model directory found at " + prefix + "_0")
    return prefix + "_" + str(ver - 1) + "/learner.model"


def _remap_labels(labels, label_of):
    """Translate raw dataset class ids into contiguous indices using ``label_of``."""
    return torch.tensor([label_of[int(c)] for c in labels.cpu().tolist()], dtype=torch.long)


def _make_eval_loaders(args, tot_class, run):
    """Draw the classes for one run and build its test and adaptation loaders.

    Returns:
        tuple: ``(keep, iterator, iterator_sorted)`` -- the sampled raw class ids,
        the loader used to measure accuracy, and the loader the model adapts on.

    Raises:
        ValueError: if ``args.dataset`` is not imagenet, omniglot or CIFAR100.
    """
    # Drawn for every dataset (even those that redraw below) so numpy's RNG
    # stream, and therefore the classes picked in later runs, is unchanged.
    keep = np.random.choice(list(range(650)), tot_class, replace=False)
    if args.dataset == "imagenet":
        keep = np.random.choice(list(range(20)), tot_class, replace=False)
        dataset = imgnet.MiniImagenet(args.imagenet_path, mode='test', elem_per_class=30,
                                      classes=keep, seed=run)
        dataset_test = imgnet.MiniImagenet(args.imagenet_path, mode='test', elem_per_class=30,
                                           classes=keep, test=args.test, seed=run)
        iterator = torch.utils.data.DataLoader(dataset_test, batch_size=128,
                                               shuffle=True, num_workers=1)
        iterator_sorted = torch.utils.data.DataLoader(dataset, batch_size=1,
                                                      shuffle=False, num_workers=1)
    elif args.dataset == "omniglot":
        dataset = utils.remove_classes_omni(
            df.DatasetFactory.get_dataset("omniglot", train=True, background=False), keep)
        iterator_sorted = torch.utils.data.DataLoader(
            utils.iterator_sorter_omni(dataset, False, classes=tot_class),
            batch_size=1, shuffle=False, num_workers=2)
        dataset = utils.remove_classes_omni(
            df.DatasetFactory.get_dataset("omniglot", train=not args.test, background=False), keep)
        iterator = torch.utils.data.DataLoader(dataset, batch_size=32,
                                               shuffle=False, num_workers=1)
    elif args.dataset == "CIFAR100":
        # replace=False: repeated ids would leave fewer than tot_class distinct
        # classes and make the label remapping inconsistent.
        keep = np.random.choice(list(range(50, 100)), tot_class, replace=False)
        dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=True), keep)
        iterator_sorted = torch.utils.data.DataLoader(
            utils.iterator_sorter(dataset, False, classes=tot_class),
            batch_size=16, shuffle=False, num_workers=2)
        dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=False), keep)
        iterator = torch.utils.data.DataLoader(dataset, batch_size=128,
                                               shuffle=False, num_workers=1)
    else:
        raise ValueError("Unsupported dataset %r" % (args.dataset,))
    return keep, iterator, iterator_sorted


def _stream_accuracy(model, loader, label_of, device, fast_weights=None):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Accuracy is weighted per sample (not averaged per batch), so a smaller
    final batch does not get extra weight. Returns 0.0 for an empty loader.
    """
    correct = 0
    seen = 0
    with torch.no_grad():
        for img, target in loader:
            target = _remap_labels(target, label_of).to(device)
            img = img.to(device)
            logits = model(img, vars=fast_weights, bn_training=False, feature=False)
            correct += torch.eq(logits.argmax(dim=1), target).sum().item()
            seen += len(img)
    return correct / seen if seen else 0.0


def _build_adaptation_stream(iterator_sorted, label_of, shots, iid):
    """Keep the first ``shots`` batches of every block of 15 from ``iterator_sorted``.

    NOTE(review): the block size of 15 is hard-coded; it assumes the sorted
    loader yields 15 consecutive batches per class -- confirm per dataset.

    Returns:
        list: ``(img, remapped_labels)`` pairs, shuffled when ``iid`` is true.
    """
    stream = []
    for position, (img, y) in enumerate(iterator_sorted):
        if position % 15 < shots:
            stream.append((img, _remap_labels(y, label_of)))
    # Drawn unconditionally so numpy's RNG advances identically with or without --iid.
    order = np.random.permutation(len(stream))
    if iid:
        stream = [stream[i] for i in order]
    return stream


def _adapt_and_evaluate(args, lr, mem_size, label_of, iterator, iterator_sorted,
                        frozen_layers, device):
    """Adapt a fresh copy of the model online on the training stream; return test accuracy."""
    logger = logging.getLogger('experiment')

    if args.scratch:
        config = mf.ModelFactory.get_model(args.model_type, args.dataset)
        maml = learner.Learner(config, lr)
    else:
        # NOTE(review): this unpickles a whole module; torch>=2.6 defaults to
        # weights_only=True and would reject it -- confirm the torch version in use.
        maml = torch.load(args.modelX, map_location='cpu')
    maml = maml.to(device)

    for name, param in maml.named_parameters():
        # ``learn`` is a custom per-parameter flag marking trainable parameters.
        param.learn = name not in frozen_layers
        if param.learn and args.reset:
            # Re-initialise unfrozen parameters: Kaiming for weights, zeros otherwise.
            if param.dim() > 1:
                fresh = torch.empty_like(param)
                torch.nn.init.kaiming_normal_(fresh)
            else:
                fresh = torch.zeros_like(param)
            param.data = fresh

    pre_acc = _stream_accuracy(maml, iterator, label_of, device)
    logger.debug("Pre-adaptation accuracy %s", str(pre_acc))

    fast_weights = None
    if args.randomize_plastic_weights:
        maml.randomize_plastic_weights()
    if args.zero_plastic_weights:
        maml.zero_plastic_weights()
    res_sampler = rep.ReservoirSampler(mem_size)
    stream = _build_adaptation_stream(iterator_sorted, label_of, args.shots, args.iid)

    for _ in range(args.epoch):
        imgs = []
        ys = []
        for img, y in stream:
            if args.memory == 0:
                img = img.to(device)
                y = y.to(device)
            else:
                # Mix 8 replayed samples from the reservoir into the current batch.
                res_sampler.update_buffer(zip(img, y))
                res_sampler.update_observations(len(img))
                img = img.to(device)
                y = y.to(device)
                img2, y2 = res_sampler.sample_buffer(8)
                img2 = img2.to(device)
                y2 = y2.to(device)
                img = torch.cat([img, img2], dim=0)
                y = torch.cat([y, y2], dim=0)

            if args.batch_learning:
                imgs.append(img)
                ys.append(y)
            else:
                logits = maml(img, vars=fast_weights)
                fast_weights = maml.getOjaUpdate(y, logits, fast_weights, hebbian=args.hebb)
        if args.batch_learning:
            y = torch.cat(ys, 0)
            img = torch.cat(imgs, 0)
            logits = maml(img, vars=fast_weights)
            fast_weights = maml.getOjaUpdate(y, logits, fast_weights, hebbian=args.hebb)

    return _stream_accuracy(maml, iterator, label_of, device, fast_weights)


def main(args):
    """Evaluate a meta-trained model on few-shot continual classification streams.

    For each class count in the selected schedule and each of ``args.runs``
    runs, a random subset of classes is drawn. The newest saved model is
    adapted online on a stream built from those classes, and its test
    accuracy is recorded. Results go to TensorBoard, the experiment JSON and
    ``evals/final_results_<orig_name>.npy``, and are re-saved after every run.

    Args:
        args: parsed command-line namespace; ``args.modelX`` is rewritten to the
            resolved ``learner.model`` path.

    Raises:
        ValueError: if ``args.dataset`` is not imagenet, omniglot or CIFAR100.
        FileNotFoundError: if no ``<args.modelX>_0`` directory exists.
    """
    if args.dataset not in ("imagenet", "omniglot", "CIFAR100"):
        raise ValueError("Unsupported dataset %r" % (args.dataset,))

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "./evals/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    try:
        args.modelX = _latest_model_path(args.modelX)

        logger = logging.getLogger('experiment')
        logger.setLevel(logging.INFO)

        if torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')

        # Freeze the first ``rln`` layers (weight + bias each) and, counting from
        # the end, the last ``rln_end`` layers. The set is built once so every
        # trial freezes the same parameters.
        # NOTE(review): the tail entries use a "net.vars." prefix while the head
        # uses "vars."; confirm which naming the loaded model actually exposes.
        total_ff_vars = 2 * (6 + 2 + args.num_extra_dense_layers)
        frozen_layers = set("vars." + str(i) for i in range(args.rln * 2))
        frozen_layers.update("net.vars." + str(total_ff_vars - 1 - i)
                             for i in range(args.rln_end * 2))

        # Later flags take precedence over earlier ones.
        total_clases = [5]
        if args.twentyclass:
            total_clases = [20]
        if args.twotask:
            total_clases = [2, 10]
        if args.fiftyclass:
            total_clases = [50]
        if args.tenclass:
            total_clases = [10]
        if args.fiveclass:
            total_clases = [5]
        logger.info("Class schedule = %s", str(total_clases))

        final_results_all = []
        for tot_class in total_clases:
            print('TOT_CLASS', tot_class)
            # Only lr=0 is tried; a sweep (e.g. 0.03 ... 1e-5) can be listed here.
            lr_list = [0]
            perf_sum = 0.0
            for aoo in range(args.runs):
                keep, iterator, iterator_sorted = _make_eval_loaders(args, tot_class, aoo)
                label_of = {int(c): i for i, c in enumerate(keep)}

                results_mem_size = {}
                for mem_size in [args.memory]:
                    max_acc = -10
                    max_lr = -10
                    for lr in lr_list:
                        acc = _adapt_and_evaluate(args, lr, mem_size, label_of, iterator,
                                                  iterator_sorted, frozen_layers, device)
                        if acc > max_acc:
                            max_acc = acc
                            max_lr = lr

                    # Subsequent runs reuse the best learning rate found so far.
                    lr_list = [max_lr]
                    results_mem_size[mem_size] = (max_acc, max_lr)
                    writer.add_scalar('/finetune/best_' + str(aoo), max_acc, tot_class)
                    # TODO: revisit if several memory sizes are ever evaluated per run.
                    perf_sum += max_acc
                    print('avg perf', perf_sum / (aoo + 1))

                final_results_all.append((tot_class, results_mem_size))
                # Checkpoint results after every run so partial sweeps are kept.
                my_experiment.results["Final Results"] = final_results_all
                my_experiment.store_json()
                np.save('evals/final_results_' + args.orig_name + '.npy', final_results_all)
    finally:
        writer.close()