Code example #1
File: eval_cl_omniglot.py Project: sebamenabar/mrcl
def run_eval(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using device", device)
    else:
        device = torch.device("cpu")

    args.model_path = osp.join(args.exp_dir, args.model_name)

    with open(osp.join(args.exp_dir, "metadata.json"), "r") as f:
        metadata = json.load(f)
    resize = metadata["params"]["resize"]

    config = mf.ModelFactory.get_model(
        # "na", args["dataset"], output_dimension=1000, resize=resize
        "na",
        args.dataset,
        output_dimension=1000,
        resize=resize)

    maml = load_model(args, config)
    maml = maml.to(device)

    reset_zero = False
    maml.reset_vars(zero=reset_zero)

    data_train = df.DatasetFactory.get_dataset(args.dataset,
                                               train=True,
                                               background=True,
                                               path=args.data_path,
                                               resize=resize)
    data_test = df.DatasetFactory.get_dataset(args.dataset,
                                              train=False,
                                              background=True,
                                              path=args.data_path,
                                              resize=resize)

    to_keep = np.arange(664, 964)
    trainset = utils.remove_classes_omni(data_train, to_keep)
    valset = utils.remove_classes_omni(data_test, to_keep)

    print(len(trainset), len(valset))

    # for optimizer in ["sgd", "adam"]:
    for optimizer in ["adam"]:  # adam worked best
        for reset_zero in [True, False]:
            eval_loop(
                trainset,
                valset,
                maml,
                optimizer,
                10,
                reset_zero,
                device,
                args.lr_sweep_range,
                args.prefix,
            )
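
A note on a helper used throughout these examples: utils.remove_classes_omni restricts an Omniglot dataset to a given set of class labels. Its implementation is not reproduced on this page; the sketch below shows one plausible version, assuming the dataset exposes parallel data/targets sequences (an assumption, not the project's verified code).

import copy

def remove_classes_omni_sketch(dataset, classes_to_keep):
    # Return a shallow copy of `dataset` containing only the samples
    # whose label is in `classes_to_keep`.
    keep = set(int(c) for c in classes_to_keep)
    filtered = copy.copy(dataset)
    pairs = [(x, y) for x, y in zip(dataset.data, dataset.targets)
             if int(y) in keep]
    filtered.data = [x for x, _ in pairs]
    filtered.targets = [y for _, y in pairs]
    return filtered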
Code example #2
File: pretraining_omniglot.py Project: yxue3357/mrcl
def main(args):
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    np.random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/")

    dataset = df.DatasetFactory.get_dataset(args.dataset)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset, train=False)

    if args.dataset == "CUB":
        args.classes = list(range(100))
    if args.dataset == "CIFAR100":
        args.classes = list(range(50))

    if args.dataset == "omniglot":
        iterator = torch.utils.data.DataLoader(utils.remove_classes_omni(
            dataset, list(range(963))),
                                               batch_size=256,
                                               shuffle=True,
                                               num_workers=1)
        iterator_test = torch.utils.data.DataLoader(utils.remove_classes_omni(
            dataset_test, list(range(963))),
                                                    batch_size=256,
                                                    shuffle=True,
                                                    num_workers=1)
    else:
        iterator = torch.utils.data.DataLoader(utils.remove_classes(
            dataset, args.classes),
                                               batch_size=12,
                                               shuffle=True,
                                               num_workers=1)
        iterator_test = torch.utils.data.DataLoader(utils.remove_classes(
            dataset_test, args.classes),
                                                    batch_size=12,
                                                    shuffle=True,
                                                    num_workers=1)

    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args.dataset)

    maml = learner.Learner(config).to(device)

    opt = torch.optim.Adam(maml.parameters(), lr=args.lr)

    for e in range(args.epoch):
        correct = 0
        if e == 50:
            # Drop the learning rate to 1e-5 once, at the start of epoch 50
            # (checked per epoch, so Adam is rebuilt only once).
            opt = torch.optim.Adam(maml.parameters(), lr=0.00001)
            logger.info("Changing LR from %f to %f", args.lr, 0.00001)
        for img, y in tqdm(iterator):
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)
            # Features are computed for an (optional) L1 sparsity term; the
            # backward pass for it is commented out below, so loss_rep is
            # currently unused.
            feature = maml(img, feature=True)
            loss_rep = torch.abs(feature).sum()

            opt.zero_grad()
            loss = F.cross_entropy(pred, y)
            # loss_rep.backward(retain_graph=True)
            # logger.info("L1 norm = %s", str(loss_rep.item()))
            loss.backward()
            opt.step()
            correct += (pred.argmax(1) == y).sum().float() / len(y)
        logger.info("Accuracy at epoch %d = %s", e,
                    str(correct / len(iterator)))

        correct = 0
        with torch.no_grad():
            for img, y in tqdm(iterator_test):

                img = img.to(device)
                y = y.to(device)
                pred = maml(img)
                feature = maml(img, feature=True)
                loss_rep = torch.abs(feature).sum()

                correct += (pred.argmax(1) == y).sum().float() / len(y)
            logger.info("Accuracy Test at epoch %d = %s", e,
                        str(correct / len(iterator_test)))

        torch.save(maml, my_experiment.path + "model.net")
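
The learning-rate drop above (rebuilding Adam at epoch 50) can also be expressed with a scheduler that adjusts the existing optimizer in place. A minimal sketch using torch.optim.lr_scheduler.MultiStepLR, with one milestone at epoch 50 and gamma 0.1 (1e-4 down to 1e-5):

import torch

def make_optimizer(params, lr=0.0001):
    # Adam plus a one-milestone schedule: the LR drops from 1e-4 to 1e-5
    # after 50 epochs, mirroring the manual optimizer swap above.
    opt = torch.optim.Adam(params, lr=lr)
    sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[50], gamma=0.1)
    return opt, sched

# usage:
#   opt, sched = make_optimizer(maml.parameters())
#   for e in range(args.epoch):
#       ...train one epoch with opt...
#       sched.step()  # called once per epoch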
Code example #3
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "/data5/jlindsey/continual/results", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)
    total_clases = 10

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))

    #
    final_results_all = []
    total_clases = [10, 50, 75, 100, 150, 200]
    if args.twentyclass:
        total_clases = [20, 50]
    if args.fiveclass:
        total_clases = [5]
    for tot_class in total_clases:
        avg_perf = 0.0
        lr_list = [0.03]
        for aoo in range(0, args.runs):

            keep = np.random.choice(list(range(200)), tot_class, replace=False)
            
            print('keep', keep)

            if args.dataset == "omniglot":

                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot", train=True, background=False), keep)
                print('lenbefore', len(dataset.data))
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter_omni(dataset, False, classes=total_clases),
                    batch_size=1,
                    shuffle=args.iid, num_workers=2)
                print("LEN", len(iterator_sorted), len(dataset.data))
                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot", train=not args.test, background=False), keep)
                iterator = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                       shuffle=False, num_workers=1)
            elif args.dataset == "CIFAR100":
                keep = np.random.choice(list(range(50, 100)), tot_class)
                dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=True), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter(dataset, False, classes=tot_class),
                    batch_size=16,
                    shuffle=args.iid, num_workers=2)
                dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=False), keep)
                iterator = torch.utils.data.DataLoader(dataset, batch_size=128,
                                                       shuffle=False, num_workers=1)
            # sampler = ts.MNISTSampler(list(range(0, total_clases)), dataset)
            #
            print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}
            

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:

                    print(lr)
                    # for lr in [0.001, 0.0003, 0.0001, 0.00003, 0.00001]:
                    maml = torch.load(args.model, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model("na", args.dataset)
                        maml = learner.Learner(config)
                        # maml = MetaLearingClassification(args, config).to(device).net

                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        # logger.info(name)
                        if name in frozen_layers:
                            # logger.info("Freeezing name %s", str(name))
                            param.learn = False
                            # logger.info(str(param.requires_grad))
                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                # logger.info("W shape = %s", str(len(w.shape)))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                        
                    '''
                    torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    maml.parameters()[-1].data = w

                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")
                        # logger.info("Name = %s", n)
                        if n == "vars_14":
                            w = nn.Parameter(torch.ones_like(a))
                            # logger.info("W shape = %s", str(w.shape))
                            torch.nn.init.kaiming_normal_(w)
                            a.data = w
                        if n == "vars_15":
                            w = nn.Parameter(torch.zeros_like(a))
                            a.data = w
                            
                    '''


                    correct = 0

                    for img, target in iterator:
                        with torch.no_grad():
                            img = img.to(device)
                            target = target.to(device)
                            logits_q = maml(img, vars=None, bn_training=False, feature=False)
                            pred_q = (logits_q).argmax(dim=1)
                            correct += torch.eq(pred_q, target).sum().item() / len(img)

                    logger.info("Pre-epoch accuracy %s", str(correct / len(iterator)))

                    filter_list = ["vars.0", "vars.1", "vars.2", "vars.3", "vars.4", "vars.5"]

                    logger.info("Filter list = %s", ",".join(filter_list))
                    list_of_names = list(
                        map(lambda x: x[1], list(filter(lambda x: x[0] not in filter_list, maml.named_parameters()))))

                    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(filter(lambda x: x[1].learn, maml.named_parameters()))
                    if args.scratch or args.no_freeze:
                        print("Empty filter list")
                        list_of_params = maml.parameters()
                    #
                    for x in list_of_names:
                        logger.info("Unfrozen layer = %s", str(x[0]))
                    opt = torch.optim.Adam(list_of_params, lr=lr)

                    fast_weights = maml.vars
                    if args.randomize_plastic_weights:
                        maml.randomize_plastic_weights()
                    if args.zero_plastic_weights:
                        maml.zero_plastic_weights()
                    for iter in range(0, args.epoch):
                        iter_count = 0
                        imgs = []
                        ys = []
                        for img, y in iterator_sorted:
                            #print(iter_count, y)
                            if iter_count % 15 >= args.shots:
                                iter_count += 1
                                continue
                            iter_count += 1
                            img = img.to(device)
                            y = y.to(device)
                            
                            imgs.append(img)
                            ys.append(y)
                            

                            if not args.batch_learning:
                                pred = maml(img, vars=fast_weights)
                                opt.zero_grad()
                                loss = F.cross_entropy(pred, y)
                                grad = torch.autograd.grad(loss, fast_weights)
                                # fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

                                if args.plastic_update:
                                    fast_weights = list(
                                        map(lambda p: p[1] - p[0] * p[2] if p[1].learn else p[1], zip(grad, fast_weights, maml.vars_plasticity)))       
                                else:
                                    fast_weights = list(
                                        map(lambda p: p[1] - args.update_lr * p[0] if p[1].learn else p[1], zip(grad, fast_weights)))
                                for params_old, params_new in zip(maml.parameters(), fast_weights):
                                    params_new.learn = params_old.learn
                        if args.batch_learning:
                            y = torch.cat(ys, 0)
                            img = torch.cat(imgs, 0)
                            pred = maml(img, vars=fast_weights)
                            opt.zero_grad()
                            loss = F.cross_entropy(pred, y)
                            grad = torch.autograd.grad(loss, fast_weights)
                            # fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

                            if args.plastic_update:
                                fast_weights = list(
                                    map(lambda p: p[1] - p[0] * p[2] if p[1].learn else p[1], zip(grad, fast_weights, maml.vars_plasticity)))       
                            else:
                                fast_weights = list(
                                    map(lambda p: p[1] - args.update_lr * p[0] if p[1].learn else p[1], zip(grad, fast_weights)))
                            for params_old, params_new in zip(maml.parameters(), fast_weights):
                                params_new.learn = params_old.learn
                            #loss.backward()
                            #opt.step()

                    logger.info("Result after one epoch for LR = %f", lr)
                    correct = 0
                    for img, target in iterator:
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img, vars=fast_weights, bn_training=False, feature=False)

                        pred_q = (logits_q).argmax(dim=1)

                        correct += torch.eq(pred_q, target).sum().item() / len(img)

                    logger.info(str(correct / len(iterator)))
                    if (correct / len(iterator) > max_acc):
                        max_acc = correct / len(iterator)
                        max_lr = lr

                lr_list = [max_lr]
                results_mem_size[mem_size] = (max_acc, max_lr)
                avg_perf += max_acc / args.runs
                print('avg perf', avg_perf * args.runs / (1+aoo))
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc, tot_class)
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            np.save('evals/final_results_'+args.orig_name+'.npy', final_results_all) 
            print("FINAL RESULTS = ", final_results_all)
    writer.close()
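
The inner loop in this example never calls opt.step() on the adapted weights; instead it takes gradients with torch.autograd.grad and rebuilds the fast_weights list functionally, skipping parameters whose learn flag is False. Stripped of the plastic-update branch, one step reduces to roughly the following sketch (not the repository's exact code):

import torch
import torch.nn.functional as F

def sgd_fast_weight_step(model, fast_weights, img, y, update_lr):
    # One functional SGD step: a new weight list is built, so the module's
    # stored parameters are never mutated. Parameters flagged learn == False
    # are passed through unchanged (but still participate in autograd).
    pred = model(img, vars=fast_weights)
    loss = F.cross_entropy(pred, y)
    grads = torch.autograd.grad(loss, fast_weights)
    new_weights = [w - update_lr * g if w.learn else w
                   for g, w in zip(grads, fast_weights)]
    for old, new in zip(fast_weights, new_weights):
        new.learn = old.learn  # propagate the freeze flag, as above
    return new_weights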
Code example #4
File: evaluate_omniglot.py Project: ybyangjing/mrcl
def main():
    p = class_parser_eval.Parser()
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])
    print("All args = ", all_args)

    args = utils.get_run(vars(p.parse_known_args()[0]), rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'], args, "../results/", commit_changes=False, rank=0, seed=1)

    data_train = df.DatasetFactory.get_dataset("omniglot", train=True, background=False, path=args['path'])
    data_test = df.DatasetFactory.get_dataset("omniglot", train=False, background=False, path=args['path'])
    final_results_train = []
    final_results_test = []
    lr_sweep_results = []

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    config = mf.ModelFactory.get_model("na", args['dataset'], output_dimension=1000)

    maml = load_model(args, config)
    maml = maml.to(device)

    args['schedule'] = [int(x) for x in args['schedule'].split(":")]
    no_of_classes_schedule = args['schedule']
    print(args["schedule"])
    for total_classes in no_of_classes_schedule:
        lr_sweep_range = [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        lr_all = []
        for lr_search_runs in range(0, 5):

            classes_to_keep = np.random.choice(list(range(650)), total_classes, replace=False)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)

            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset, False, classes=no_of_classes_schedule),
                batch_size=1,
                shuffle=args['iid'], num_workers=2)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)
            iterator_train = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                         shuffle=False, num_workers=1)

            max_acc = -1000
            for lr in lr_sweep_range:

                maml.reset_vars()

                opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

                train_iterator(iterator_sorted, device, maml, opt)

                correct = eval_iterator(iterator_train, device, maml)
                if (correct > max_acc):
                    max_acc = correct
                    max_lr = lr

            lr_all.append(max_lr)
            results_mem_size = (max_acc, max_lr)
            lr_sweep_results.append((total_classes, results_mem_size))

            my_experiment.results["LR Search Results"] = lr_sweep_results
            my_experiment.store_json()
            logger.debug("LR RESULTS = %s", str(lr_sweep_results))

        from scipy import stats
        best_lr = float(stats.mode(lr_all)[0][0])

        logger.info("BEST LR %s= ", str(best_lr))

        for current_run in range(0, args['runs']):

            classes_to_keep = np.random.choice(list(range(650)), total_classes, replace=False)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)

            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset, False, classes=no_of_classes_schedule),
                batch_size=1,
                shuffle=args['iid'], num_workers=2)

            dataset = utils.remove_classes_omni(data_test, classes_to_keep)
            iterator_test = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                        shuffle=False, num_workers=1)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)
            iterator_train = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                         shuffle=False, num_workers=1)

            lr = best_lr

            maml.reset_vars()

            opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

            train_iterator(iterator_sorted, device, maml, opt)

            logger.info("Result after one epoch for LR = %f", lr)

            correct = eval_iterator(iterator_train, device, maml)

            correct_test = eval_iterator(iterator_test, device, maml)

            results_mem_size = (correct, best_lr, "train")
            logger.info("Final Max Result train = %s", str(correct))
            final_results_train.append((total_classes, results_mem_size))

            results_mem_size = (correct_test, best_lr, "test")
            logger.info("Final Max Result test= %s", str(correct_test))
            final_results_test.append((total_classes, results_mem_size))

            my_experiment.results["Final Results"] = final_results_train
            my_experiment.results["Final Results Test"] = final_results_test
            my_experiment.store_json()
            logger.debug("FINAL RESULTS = %s", str(final_results_train))
            logger.debug("FINAL RESULTS = %s", str(final_results_test))
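
best_lr is chosen as the mode of the winning LRs across five search runs. stats.mode(lr_all)[0][0] matches SciPy's legacy return value (arrays of modes and counts); on SciPy >= 1.9 mode returns a scalar by default and this indexing breaks. A version-independent alternative (a sketch, not what the repository ships):

from collections import Counter

def most_common_lr(lr_all):
    # Mode of lr_all without depending on scipy's return-shape change;
    # ties break toward the LR encountered first.
    return float(Counter(lr_all).most_common(1)[0][0])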
Code example #5
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)
    total_clases = 10

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))

    #
    final_results_all = []
    total_clases = [10, 50, 75, 100, 150, 200]
    for tot_class in total_clases:
        lr_list = [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        for aoo in range(0, 20):

            keep = np.random.choice(list(range(200)), tot_class)

            if args.dataset == "omniglot":

                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot",
                                                  train=True,
                                                  background=False), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter_omni(dataset,
                                               False,
                                               classes=total_clases),
                    batch_size=1,
                    shuffle=args.iid,
                    num_workers=2)
                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot",
                                                  train=not args.test,
                                                  background=False), keep)
                iterator = torch.utils.data.DataLoader(dataset,
                                                       batch_size=32,
                                                       shuffle=False,
                                                       num_workers=1)
            elif args.dataset == "CIFAR100":
                keep = np.random.choice(list(range(50, 100)), tot_class)
                dataset = utils.remove_classes(
                    df.DatasetFactory.get_dataset(args.dataset, train=True),
                    keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter(dataset, False, classes=tot_class),
                    batch_size=16,
                    shuffle=args.iid,
                    num_workers=2)
                dataset = utils.remove_classes(
                    df.DatasetFactory.get_dataset(args.dataset, train=False),
                    keep)
                iterator = torch.utils.data.DataLoader(dataset,
                                                       batch_size=128,
                                                       shuffle=False,
                                                       num_workers=1)
            # sampler = ts.MNISTSampler(list(range(0, total_clases)), dataset)
            #
            print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:

                    print(lr)
                    # for lr in [0.001, 0.0003, 0.0001, 0.00003, 0.00001]:
                    maml = torch.load(args.model, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model("na", args.dataset)
                        maml = learner.Learner(config)
                        # maml = MetaLearingClassification(args, config).to(device).net

                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        # logger.info(name)
                        if name in frozen_layers:
                            # logger.info("Freeezing name %s", str(name))
                            param.learn = False
                            # logger.info(str(param.requires_grad))
                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                # logger.info("W shape = %s", str(len(w.shape)))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                    torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    maml.parameters()[-1].data = w

                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")
                        # logger.info("Name = %s", n)
                        if n == "vars_14":
                            w = nn.Parameter(torch.ones_like(a))
                            # logger.info("W shape = %s", str(w.shape))
                            torch.nn.init.kaiming_normal_(w)
                            a.data = w
                        if n == "vars_15":
                            w = nn.Parameter(torch.zeros_like(a))
                            a.data = w

                    correct = 0

                    for img, target in iterator:
                        with torch.no_grad():
                            img = img.to(device)
                            target = target.to(device)
                            logits_q = maml(img,
                                            vars=None,
                                            bn_training=False,
                                            feature=False)
                            pred_q = (logits_q).argmax(dim=1)
                            correct += torch.eq(pred_q,
                                                target).sum().item() / len(img)

                    logger.info("Pre-epoch accuracy %s",
                                str(correct / len(iterator)))

                    filter_list = [
                        "vars.0", "vars.1", "vars.2", "vars.3", "vars.4",
                        "vars.5"
                    ]

                    logger.info("Filter list = %s", ",".join(filter_list))
                    list_of_names = list(
                        map(
                            lambda x: x[1],
                            list(
                                filter(lambda x: x[0] not in filter_list,
                                       maml.named_parameters()))))

                    list_of_params = list(
                        filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(
                        filter(lambda x: x[1].learn, maml.named_parameters()))
                    if args.scratch or args.no_freeze:
                        print("Empty filter list")
                        list_of_params = maml.parameters()
                    #
                    for x in list_of_names:
                        logger.info("Unfrozen layer = %s", str(x[0]))
                    opt = torch.optim.Adam(list_of_params, lr=lr)
                    import module.replay as rep
                    res_sampler = rep.ReservoirSampler(mem_size)
                    for _ in range(0, args.epoch):
                        for img, y in iterator_sorted:

                            if mem_size > 0:
                                res_sampler.update_buffer(zip(img, y))
                                res_sampler.update_observations(len(img))
                                img = img.to(device)
                                y = y.to(device)
                                img2, y2 = res_sampler.sample_buffer(16)
                                img2 = img2.to(device)
                                y2 = y2.to(device)

                                img = torch.cat([img, img2], dim=0)
                                y = torch.cat([y, y2], dim=0)
                            else:
                                img = img.to(device)
                                y = y.to(device)

                            pred = maml(img)
                            opt.zero_grad()
                            loss = F.cross_entropy(pred, y)
                            loss.backward()
                            opt.step()

                    logger.info("Result after one epoch for LR = %f", lr)
                    correct = 0
                    for img, target in iterator:
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)

                        pred_q = (logits_q).argmax(dim=1)

                        correct += torch.eq(pred_q,
                                            target).sum().item() / len(img)

                    logger.info(str(correct / len(iterator)))
                    if (correct / len(iterator) > max_acc):
                        max_acc = correct / len(iterator)
                        max_lr = lr

                lr_list = [max_lr]
                results_mem_size[mem_size] = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc,
                                  tot_class)
                # quit()
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            print("FINAL RESULTS = ", final_results_all)
    writer.close()
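
module.replay.ReservoirSampler supplies the rehearsal memory here but is not shown on this page. The calls used above (update_buffer over (img, y) pairs, update_observations, sample_buffer(n)) are consistent with standard reservoir sampling; the sketch below is an assumption about that interface, not the project's verified class.

import random
import torch

class ReservoirSamplerSketch:
    # Keeps a uniform random sample of at most `capacity` (img, y) pairs
    # from a stream, via classic reservoir sampling.
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.observed = 0

    def update_buffer(self, pairs):
        for img, y in pairs:
            self.observed += 1
            if len(self.buffer) < self.capacity:
                self.buffer.append((img, y))
            else:
                j = random.randrange(self.observed)
                if j < self.capacity:
                    self.buffer[j] = (img, y)

    def update_observations(self, n):
        pass  # the observation count is already tracked in update_buffer here

    def sample_buffer(self, n):
        batch = random.sample(self.buffer, min(n, len(self.buffer)))
        imgs = torch.stack([img for img, _ in batch])
        ys = torch.stack([y for _, y in batch])
        return imgs, ys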
Code example #6
def run_eval(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using device", device)
    else:
        device = torch.device("cpu")

    args.model_path = osp.join(args.exp_dir, args.model_name)

    with open(osp.join(args.exp_dir, "metadata.json"), "r") as f:
        metadata = json.load(f)
    resize = metadata["params"]["resize"]

    config = mf.ModelFactory.get_model(
        # "na", args["dataset"], output_dimension=1000, resize=resize
        "na",
        args.dataset,
        output_dimension=1000,
        resize=resize)

    maml = load_model(args, config)
    maml = maml.to(device)

    reset_zero = False
    maml.reset_vars(zero=reset_zero)

    data_train = df.DatasetFactory.get_dataset(args.dataset,
                                               train=True,
                                               background=True,
                                               path=args.data_path,
                                               resize=resize)
    data_test = df.DatasetFactory.get_dataset(args.dataset,
                                              train=False,
                                              background=True,
                                              path=args.data_path,
                                              resize=resize)

    to_keep = np.arange(664, 964)
    trainset = utils.remove_classes_omni(data_train, to_keep)
    valset = utils.remove_classes_omni(data_test, to_keep)

    print(len(trainset), len(valset))

    train_iterator = torch.utils.data.DataLoader(trainset,
                                                 batch_size=64,
                                                 shuffle=True,
                                                 num_workers=0,
                                                 drop_last=True)
    val_iterator = torch.utils.data.DataLoader(valset,
                                               batch_size=64,
                                               shuffle=True,
                                               num_workers=0,
                                               drop_last=False)

    for param in maml.parameters():
        param.requires_grad = False
    maml.parameters()[-2].requires_grad = True
    maml.parameters()[-1].requires_grad = True

    print("Features only")
    for optimizer in ["sgd", "adam"]:
        # for optimizer in ["adam"]: # adam worked best
        print(optimizer)
        # for reset_zero in [True, False]:
        for reset_zero in [True]:
            print(reset_zero)
            eval_loop(
                train_iterator,
                val_iterator,
                maml,
                optimizer,
                4,
                reset_zero,
                device,
                "features",
            )

    for param in maml.parameters():
        param.requires_grad = True
    # maml.parameters()[-2].requires_grad = True
    # maml.parameters()[-1].requires_grad = True

    print("full finetune")
    # for optimizer in ["sgd", "adam"]:
    for optimizer in ["adam"]:  # adam worker best for full finetune
        print(optimizer)
        # for reset_zero in [True, False]:
        for reset_zero in [True]:  # resetting to zero worked best
            print(reset_zero)
            eval_loop(
                train_iterator,
                val_iterator,
                maml,
                optimizer,
                4,
                reset_zero,
                device,
                "full_finetune",
            )
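
maml.reset_vars(zero=reset_zero), used in this example and in examples #1 and #9, reinitializes the adaptation (head) parameters before each run. The Learner class is not reproduced on this page; a plausible sketch of such a reset, assuming Kaiming init for weight matrices and zeros for vectors when zero=False (an assumption, not mrcl's verified code):

import torch

def reset_vars_sketch(params, zero=False):
    # Reinitialize a list of nn.Parameter in place: zero everything when
    # zero=True, otherwise Kaiming-init matrices and zero out biases.
    with torch.no_grad():
        for p in params:
            if zero or p.dim() <= 1:
                p.zero_()
            else:
                torch.nn.init.kaiming_normal_(p)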
Code example #7
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/", args.commit)

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)

    total_clases = [900]

    keep = list(range(total_clases[0]))

    dataset = utils.remove_classes_omni(
        df.DatasetFactory.get_dataset("omniglot",
                                      train=True,
                                      path=args.data_path,
                                      all=True), keep)
    iterator_sorted = torch.utils.data.DataLoader(utils.iterator_sorter_omni(
        dataset, False, classes=total_clases),
                                                  batch_size=128,
                                                  shuffle=True,
                                                  num_workers=2)

    iterator = iterator_sorted

    print(args)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = torch.load(args.model, map_location='cpu')

    if args.scratch:
        config = mf.ModelFactory.get_model("na", args.dataset)
        maml = learner.Learner(config)

    maml = maml.to(device)

    reps = []
    counter = 0

    fig, axes = plt.subplots(9, 4)
    with torch.no_grad():
        for img, target in iterator:
            print(counter)

            img = img.to(device)
            target = target.to(device)
            # print(target)
            rep = maml(img, vars=None, bn_training=False, feature=True)
            rep = rep.view((-1, 32, 72)).detach().cpu().numpy()
            rep_instance = rep[0]
            if args.binary:
                rep_instance = (rep_instance > 0).astype(int)
            if args.max:
                rep = rep / np.max(rep)
            else:
                rep = (rep > 0).astype(int)
            if counter < 36:
                print("Adding plot")
                axes[int(counter / 4), counter % 4].imshow(rep_instance,
                                                           cmap=args.color)
                axes[int(counter / 4), counter % 4].set_yticklabels([])
                axes[int(counter / 4), counter % 4].set_xticklabels([])
                axes[int(counter / 4), counter % 4].set_aspect('equal')

            counter += 1
            reps.append(rep)

    plt.subplots_adjust(wspace=0.0, hspace=0.0)

    plt.savefig(my_experiment.path + "instance_" + str(counter) + ".pdf",
                format="pdf")
    plt.clf()

    rep = np.concatenate(reps)
    average_activation = np.mean(rep, 0)
    plt.imshow(average_activation, cmap=args.color)
    plt.colorbar()
    plt.clim(0, np.max(average_activation))
    plt.savefig(my_experiment.path + "average_activation.pdf", format="pdf")
    plt.clf()
    # Note: 64 * 36 == 32 * 72, the number of units per instance.
    instance_sparsity = np.mean((np.sum(np.sum(rep, 1), 1)) / (64 * 36))
    print("Instance sparsity = ", instance_sparsity)
    my_experiment.results["instance_sparsity"] = str(instance_sparsity)
    lifetime_sparsity = (np.sum(rep, 0) / len(rep)).flatten()
    mean_lifetime = np.mean(lifetime_sparsity)
    print("Lifetime sparsity = ", mean_lifetime)
    my_experiment.results["lifetime_sparisty"] = str(mean_lifetime)
    dead_neurons = float(np.sum(
        (lifetime_sparsity == 0).astype(int))) / len(lifetime_sparsity)
    print("Dead neurons percentage = ", dead_neurons)
    my_experiment.results["dead_neurons"] = str(dead_neurons)
    plt.hist(lifetime_sparsity)

    plt.savefig(my_experiment.path + "histogram.pdf", format="pdf")
    my_experiment.store_json()
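
The three statistics computed at the end (instance sparsity, lifetime sparsity, dead-unit fraction) can be read off a binarized activation tensor of shape (N, H, W); the standalone version below reproduces the same arithmetic on dummy data:

import numpy as np

def sparsity_stats(rep):
    # rep: binarized activations, shape (N, H, W) with entries in {0, 1}.
    n, h, w = rep.shape
    instance_sparsity = rep.reshape(n, -1).mean(1).mean()  # active fraction per image
    lifetime = rep.mean(0).flatten()                       # firing rate per unit
    dead_fraction = float((lifetime == 0).mean())          # units that never fire
    return instance_sparsity, lifetime.mean(), dead_fraction

rep = (np.random.rand(100, 32, 72) > 0.9).astype(int)  # dummy data
print(sparsity_stats(rep))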
Code example #8
def main():
    p = class_parser_eval.Parser()
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])
    print("All args = ", all_args)

    args = utils.get_run(vars(p.parse_known_args()[0]), rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'],
                               args,
                               "../results/",
                               commit_changes=False,
                               rank=0,
                               seed=1)

    final_results_all = []
    temp_result = []
    args['schedule'] = [int(x) for x in args['schedule'].split(":")]
    total_clases = args['schedule']
    print(args["schedule"])
    for tot_class in total_clases:
        print("Classes current step = ", tot_class)
        lr_list = [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        lr_all = []
        for lr_search in range(0, 5):

            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=True,
                                              background=False,
                                              path=args['path']), keep)
            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset,
                                           False,
                                           classes=total_clases),
                batch_size=1,
                shuffle=args['iid'],
                num_workers=2)
            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=not args['test'],
                                              background=False,
                                              path=args['path']), keep)
            iterator = torch.utils.data.DataLoader(dataset,
                                                   batch_size=32,
                                                   shuffle=False,
                                                   num_workers=1)

            gpu_to_use = rank % args["gpus"]
            if torch.cuda.is_available():
                device = torch.device('cuda:' + str(gpu_to_use))
                logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
            else:
                device = torch.device('cpu')

            config = mf.ModelFactory.get_model("na",
                                               args['dataset'],
                                               output_dimension=1000)
            max_acc = -1000
            for lr in lr_list:

                print(lr)
                maml = load_model(args, config)
                maml = maml.to(device)

                opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

                for _ in range(0, 1):
                    for img, y in iterator_sorted:
                        img = img.to(device)
                        y = y.to(device)

                        pred = maml(img)
                        opt.zero_grad()
                        loss = F.cross_entropy(pred, y)
                        loss.backward()
                        opt.step()

                logger.info("Result after one epoch for LR = %f", lr)
                correct = 0
                for img, target in iterator:
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml(img)

                    pred_q = (logits_q).argmax(dim=1)

                    correct += torch.eq(pred_q, target).sum().item() / len(img)

                logger.info(str(correct / len(iterator)))
                if (correct / len(iterator) > max_acc):
                    max_acc = correct / len(iterator)
                    max_lr = lr

            lr_all.append(max_lr)
            logger.info("Final Max Result = %s", str(max_acc))
            results_mem_size = (max_acc, max_lr)
            temp_result.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Temp Results = %s", str(results_mem_size))

            my_experiment.results["Temp Results"] = temp_result
            my_experiment.store_json()
            print("LR RESULTS = ", temp_result)

        from scipy import stats
        best_lr = float(stats.mode(lr_all)[0][0])

        logger.info("BEST LR %s= ", str(best_lr))

        for aoo in range(0, args['runs']):

            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=True,
                                              background=False), keep)
            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset,
                                           False,
                                           classes=total_clases),
                batch_size=1,
                shuffle=args['iid'],
                num_workers=2)
            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=not args['test'],
                                              background=False), keep)
            iterator = torch.utils.data.DataLoader(dataset,
                                                   batch_size=32,
                                                   shuffle=False,
                                                   num_workers=1)

            for mem_size in [args['memory']]:
                max_acc = -10
                max_lr = -10

                lr = best_lr

                # for lr in [0.001, 0.0003, 0.0001, 0.00003, 0.00001]:
                maml = load_model(args, config)
                maml = maml.to(device)

                opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

                for _ in range(0, 1):
                    for img, y in iterator_sorted:
                        img = img.to(device)
                        y = y.to(device)

                        pred = maml(img)
                        opt.zero_grad()
                        loss = F.cross_entropy(pred, y)
                        loss.backward()
                        opt.step()

                logger.info("Result after one epoch for LR = %f", lr)
                correct = 0
                for img, target in iterator:
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml(img,
                                    vars=None,
                                    bn_training=False,
                                    feature=False)

                    pred_q = (logits_q).argmax(dim=1)

                    correct += torch.eq(pred_q, target).sum().item() / len(img)

                logger.info(str(correct / len(iterator)))
                if (correct / len(iterator) > max_acc):
                    max_acc = correct / len(iterator)
                    max_lr = lr

                lr_list = [max_lr]
                results_mem_size = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            print("FINAL RESULTS = ", final_results_all)
Code example #9
File: eval_cl_omniglot.py Project: sebamenabar/mrcl
def eval_loop(
    # args,
    trainset,
    valset,
    learner,
    optimizer,
    runs,
    reset_zero,
    device,
    lr_sweep_range,
    prefix="",
):

    print(f"Eval loop {optimizer} reset zero {reset_zero}")

    final_results_train = []
    final_results_test = []
    lr_sweep_results = []

    # args['schedule'] = [int(x) for x in args['schedule'].split(":")]
    # args.schedule = [10, 25, 50, 75, 100, 200, 300]
    schedule = [25, 100, 200, 300]
    # no_of_classes_schedule = args["schedule"]
    no_of_classes_schedule = schedule
    # print(args.schedule)
    for total_classes in schedule:

        print(
            f"\n\n--------  Beggining schedule {total_classes} ---------- \n ")

        # lr_sweep_range = [0.03, 0.01, 0.003,0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        lr_all = []
        for lr_search_runs in range(0, 5):

            classes_to_keep = np.random.choice(list(range(664, 964)),
                                               total_classes,
                                               replace=False)

            dataset = utils.remove_classes_omni(trainset, classes_to_keep)

            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(
                    dataset,
                    False,
                    # classes=no_of_classes_schedule,
                ),
                batch_size=1,
                # shuffle=args["iid"],
                shuffle=False,
                num_workers=1,
            )

            dataset = utils.remove_classes_omni(trainset, classes_to_keep)
            iterator_train = torch.utils.data.DataLoader(dataset,
                                                         batch_size=64,
                                                         shuffle=False,
                                                         num_workers=1)

            max_acc = -1000
            # lr_sweep_range = [
            #     # 0.03,
            #     # 0.01,
            #     # 0.003,
            #     # 0.001,
            #     0.0003,
            #     # 0.0001,
            #     # 0.00003,
            #     # 0.00001,
            #     # 5e-6,
            #     # 1e-6,
            # ]
            for lr in lr_sweep_range:

                learner.reset_vars(zero=reset_zero)

                if optimizer == "adam":
                    opt = torch.optim.Adam(learner.get_adaptation_parameters(),
                                           lr=lr)
                elif optimizer == "sgd":
                    opt = torch.optim.SGD(
                        learner.get_adaptation_parameters(),
                        lr=lr,
                        weight_decay=5e-4,
                        momentum=0.9,
                    )

                train_iterator(iterator_sorted, device, learner, opt)

                correct = eval_iterator(iterator_train, device, learner)
                if correct > max_acc:
                    max_acc = correct
                    max_lr = lr

                print(f"Accuracy LR {lr}: {correct}")

            lr_all.append(max_lr)
            results_mem_size = (max_acc, max_lr)
            lr_sweep_results.append((total_classes, results_mem_size))

            # my_experiment.results["LR Search Results"] = lr_sweep_results
            # my_experiment.store_json()
            # logger.debug("LR RESULTS = %s", str(lr_sweep_results))
        # print("SCHEDULE %d RESULTS = %s" % (total_classes, str(lr_sweep_results)))

        from scipy import stats

        best_lr = float(stats.mode(lr_all)[0][0])

        # logger.info("BEST LR %s= ", str(best_lr))
        print("BEST LR =%s " % str(best_lr))

        for current_run in range(0, runs):

            # classes_to_keep = np.random.choice(list(range(650)), total_classes, replace=False)
            classes_to_keep = np.random.choice(list(range(664, 964)),
                                               total_classes,
                                               replace=False)

            dataset = utils.remove_classes_omni(trainset, classes_to_keep)

            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset,
                                           False,
                                           classes=no_of_classes_schedule),
                batch_size=1,
                # shuffle=args["iid"],
                shuffle=False,
                num_workers=2,
            )

            dataset = utils.remove_classes_omni(valset, classes_to_keep)
            iterator_test = torch.utils.data.DataLoader(dataset,
                                                        batch_size=32,
                                                        shuffle=False,
                                                        num_workers=1)

            dataset = utils.remove_classes_omni(trainset, classes_to_keep)
            iterator_train = torch.utils.data.DataLoader(dataset,
                                                         batch_size=32,
                                                         shuffle=False,
                                                         num_workers=1)

            lr = best_lr

            learner.reset_vars(zero=reset_zero)

            if optimizer == "adam":
                opt = torch.optim.Adam(learner.get_adaptation_parameters(),
                                       lr=lr)
            elif optimizer == "sgd":
                opt = torch.optim.SGD(
                    learner.get_adaptation_parameters(),
                    lr=lr,
                    weight_decay=5e-4,
                    momentum=0.9,
                )

            train_iterator(iterator_sorted, device, learner, opt)

            # logger.info("Result after one epoch for LR = %f", lr)
            print("Result after one epoch for LR = %f" % lr)

            correct = eval_iterator(iterator_train, device, learner)

            correct_test = eval_iterator(iterator_test, device, learner)

            results_mem_size = (correct, best_lr, "train")
            # logger.info("Final Max Result train = %s", str(correct))
            print("Final Max Result train = %s" % str(correct))
            final_results_train.append((total_classes, results_mem_size))

            results_mem_size = (correct_test, best_lr, "test")
            # logger.info("Final Max Result test= %s", str(correct_test))
            print("Final Max Result test= %s" % str(correct_test))
            final_results_test.append((total_classes, results_mem_size))

            # my_experiment.results["Final Results"] = final_results_train
            # my_experiment.results["Final Results Test"] = final_results_test
            # my_experiment.store_json()
            # logger.debug("FINAL RESULTS = %s", str(final_results_train))
            # print("FINAL RESULTS = %s" % str(final_results_train))
            # logger.debug("FINAL RESULTS = %s", str(final_results_test))
            # print("FINAL RESULTS = %s" % str(final_results_test))
    print("Final results train")
    print(final_results_train)
    print("Final results test")
    print(final_results_test)

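    # NOTE: `args` is not a parameter of eval_loop (it is commented out in
    # the signature above); writing results here relies on a module-level
    # `args` with exp_dir set being in scope.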
    with open(
            osp.join(args.exp_dir,
                     f"{prefix}cl_results_{optimizer}_zero_{reset_zero}.json"),
            "w") as f:
        json.dump(
            {
                "final_results_train": final_results_train,
                "final_results_test": final_results_test,
            },
            f,
        )
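
train_iterator, called in examples #4 and #9, performs the single adaptation pass over the class-sorted stream. A plausible minimal version matching the call train_iterator(iterator_sorted, device, learner, opt) (again an assumption, not the verified helper):

import torch.nn.functional as F

def train_iterator_sketch(iterator_sorted, device, learner, opt):
    # One pass over the sorted (class-incremental) stream with ordinary
    # optimizer steps on the adaptation parameters.
    learner.train()
    for img, y in iterator_sorted:
        img, y = img.to(device), y.to(device)
        pred = learner(img)
        opt.zero_grad()
        F.cross_entropy(pred, y).backward()
        opt.step()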
Code example #10
File: srnn_omniglot.py Project: yxue3357/mrcl
def main(args):
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    np.random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/")

    dataset = df.DatasetFactory.get_dataset(args.dataset)

    if args.dataset == "CIFAR100":
        args.classes = list(range(50))

    if args.dataset == "omniglot":
        iterator = torch.utils.data.DataLoader(utils.remove_classes_omni(
            dataset, list(range(963))),
                                               batch_size=256,
                                               shuffle=True,
                                               num_workers=1)
    else:
        iterator = torch.utils.data.DataLoader(utils.remove_classes(
            dataset, args.classes),
                                               batch_size=256,
                                               shuffle=True,
                                               num_workers=1)

    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args.dataset)

    maml = learner.Learner(config).to(device)

    opt = torch.optim.Adam(maml.parameters(), lr=args.lr)

    for e in range(args.epoch):
        correct = 0
        for img, y in iterator:
            if e == 20:
                opt = torch.optim.Adam(maml.parameters(), lr=0.00001)
                logger.info("Changing LR from %f to %f", args.lr, 0.00001)
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)
            feature = F.relu(maml(img, feature=True))
            avg_feature = feature.mean(0)

            beta = args.beta
            beta_hat = avg_feature

            # KL-style sparsity penalty: push each unit's mean activation
            # beta_hat down toward the target rate beta, penalizing only
            # units whose mean activation exceeds the target.
            loss_rec = ((beta / (beta_hat + 0.0001)) -
                        torch.log(beta / (beta_hat + 0.0001)) - 1)
            loss_rec = loss_rec * (beta_hat > beta).float()

            loss_sparse = loss_rec

            if args.l1:
                loss_sparse = feature.mean(0)
            loss_sparse = loss_sparse.mean()

            opt.zero_grad()
            loss = F.cross_entropy(pred, y)
            loss_sparse.backward(retain_graph=True)
            loss.backward()
            opt.step()
            correct += (pred.argmax(1) == y).sum().float() / len(y)
        logger.info("Accuracy at epoch %d = %s", e,
                    str(correct / len(iterator)))
        torch.save(maml, my_experiment.path + "model.net")
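
The sparsity term above penalizes a unit only when its mean activation exceeds the target rate beta, and vanishes exactly when the two match. A standalone sketch of the same penalty (function name and shapes are illustrative):

import torch

def sparsity_penalty(features, beta, eps=0.0001):
    # features: (batch, units) tensor of non-negative activations.
    beta_hat = features.mean(0)               # mean activation per unit
    ratio = beta / (beta_hat + eps)
    penalty = ratio - torch.log(ratio) - 1    # >= 0, zero at ratio == 1
    penalty = penalty * (beta_hat > beta).float()
    return penalty.mean()

# Example: random ReLU features with a target activation rate of 0.05.
feats = torch.relu(torch.randn(256, 512))
print(sparsity_penalty(feats, beta=0.05))
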
Code example #11
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))
    #for v in range(6):
    #    frozen_layers.append("vars_bn.{0}".format(v))

    final_results_all = []
    temp_result = []
    total_clases = args.schedule
    for tot_class in total_clases:
        lr_list = [
            0.001, 0.0006, 0.0004, 0.00035, 0.0003, 0.00025, 0.0002, 0.00015,
            0.0001, 0.00009, 0.00008, 0.00006, 0.00003, 0.00001
        ]
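        # Stage 1: repeat the LR sweep over freshly sampled class subsets,
        # keeping each sweep's winning LR; the modal winner is committed to
        # for the final evaluation runs further below.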
        lr_all = []
        for lr_search in range(10):

            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=True,
                                              background=False,
                                              path=args.dataset_path), keep)
            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset,
                                           False,
                                           classes=total_clases),
                batch_size=1,
                shuffle=args.iid,
                num_workers=2)
            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=not args.test,
                                              background=False,
                                              path=args.dataset_path), keep)
            iterator = torch.utils.data.DataLoader(dataset,
                                                   batch_size=1,
                                                   shuffle=False,
                                                   num_workers=1)

            print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:

                    print(lr)
                    maml = torch.load(args.model, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model("OML", args.dataset)
                        maml = learner.Learner(config)
                        # maml = MetaLearingClassification(args, config).to(device).net

                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        # logger.info(name)
                        if name in frozen_layers:
                            param.learn = False

                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                # logger.info("W shape = %s", str(len(w.shape)))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                    # In this repo, Learner.parameters() returns the indexable
                    # ParameterList of fast weights: re-initialize the head
                    # weight and zero the head bias.
                    torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    maml.parameters()[-1].data = w

                    if args.neuromodulation:
                        weights2reset = ["vars_26"]
                        biases2reset = ["vars_27"]
                    else:
                        weights2reset = ["vars_14"]
                        biases2reset = ["vars_15"]

                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")

                        if n in weights2reset:

                            w = nn.Parameter(torch.ones_like(a)).to(device)
                            torch.nn.init.kaiming_normal_(w)
                            a.data = w

                        if n in biases2reset:

                            w = nn.Parameter(torch.zeros_like(a)).to(device)
                            a.data = w

                    filter_list = ["vars.{0}".format(v) for v in range(6)]

                    logger.info("Filter list = %s", ",".join(filter_list))

                    # Collect the parameters (and their names) still marked
                    # learnable.
                    list_of_params = list(
                        filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(
                        filter(lambda x: x[1].learn, maml.named_parameters()))

                    if args.scratch or args.no_freeze:
                        print("Empty filter list")
                        list_of_params = maml.parameters()

                    for x in list_of_names:
                        logger.info("Unfrozen layer = %s", str(x[0]))
                    opt = torch.optim.Adam(list_of_params, lr=lr)

                    for _ in range(0, args.epoch):
                        for img, y in iterator_sorted:
                            img = img.to(device)
                            y = y.to(device)

                            pred = maml(img)
                            opt.zero_grad()
                            loss = F.cross_entropy(pred, y)
                            loss.backward()
                            opt.step()

                    logger.info("Result after one epoch for LR = %f", lr)
                    correct = 0
                    for img, target in iterator:
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)

                        pred_q = (logits_q).argmax(dim=1)

                        correct += torch.eq(pred_q,
                                            target).sum().item() / len(img)

                    logger.info(str(correct / len(iterator)))
                    if (correct / len(iterator) > max_acc):
                        max_acc = correct / len(iterator)
                        max_lr = lr

                lr_all.append(max_lr)
                results_mem_size[mem_size] = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(lr_search), max_acc,
                                  tot_class)
            temp_result.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Temp Results = %s", str(results_mem_size))

            my_experiment.results["Temp Results"] = temp_result
            my_experiment.store_json()
            print("LR RESULTS = ", temp_result)

        from scipy import stats
        best_lr = float(stats.mode(lr_all)[0][0])
        logger.info("BEST LR %s= ", str(best_lr))

        for aoo in range(args.runs):

            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            if args.dataset == "omniglot":

                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot",
                                                  train=True,
                                                  background=False), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter_omni(dataset,
                                               False,
                                               classes=total_clases),
                    batch_size=1,
                    shuffle=args.iid,
                    num_workers=2)
                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot",
                                                  train=not args.test,
                                                  background=False), keep)
                iterator = torch.utils.data.DataLoader(dataset,
                                                       batch_size=1,
                                                       shuffle=False,
                                                       num_workers=1)
            elif args.dataset == "CIFAR100":
                keep = np.random.choice(list(range(50, 100)), tot_class)
                dataset = utils.remove_classes(
                    df.DatasetFactory.get_dataset(args.dataset, train=True),
                    keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter(dataset, False, classes=tot_class),
                    batch_size=16,
                    shuffle=args.iid,
                    num_workers=2)
                dataset = utils.remove_classes(
                    df.DatasetFactory.get_dataset(args.dataset, train=False),
                    keep)
                iterator = torch.utils.data.DataLoader(dataset,
                                                       batch_size=128,
                                                       shuffle=False,
                                                       num_workers=1)
            print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10

                lr = best_lr

                maml = torch.load(args.model, map_location='cpu')

                if args.scratch:
                    config = mf.ModelFactory.get_model("MRCL", args.dataset)
                    maml = learner.Learner(config)

                maml = maml.to(device)

                for name, param in maml.named_parameters():
                    param.learn = True

                for name, param in maml.named_parameters():
                    # logger.info(name)
                    if name in frozen_layers:
                        param.learn = False
                    else:
                        if args.reset:
                            w = nn.Parameter(torch.ones_like(param))
                            if len(w.shape) > 1:
                                torch.nn.init.kaiming_normal_(w)
                            else:
                                w = nn.Parameter(torch.zeros_like(param))
                            param.data = w
                            param.learn = True

                frozen_layers = []
                for temp in range(args.rln * 2):
                    frozen_layers.append("vars." + str(temp))

                torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                maml.parameters()[-1].data = w

                # Reset the classifier head (same layer indices as in the
                # LR-search loop above).
                if args.neuromodulation:
                    weights2reset = ["vars_26"]
                    biases2reset = ["vars_27"]
                else:
                    weights2reset = ["vars_14"]
                    biases2reset = ["vars_15"]

                for n, a in maml.named_parameters():
                    n = n.replace(".", "_")

                    if n in weights2reset:
                        w = nn.Parameter(torch.ones_like(a)).to(device)
                        torch.nn.init.kaiming_normal_(w)
                        a.data = w

                    if n in biases2reset:
                        w = nn.Parameter(torch.zeros_like(a)).to(device)
                        a.data = w

                correct = 0
                for img, target in iterator:
                    with torch.no_grad():

                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)
                        pred_q = (logits_q).argmax(dim=1)
                        correct += torch.eq(pred_q,
                                            target).sum().item() / len(img)

                logger.info("Pre-epoch accuracy %s",
                            str(correct / len(iterator)))

                filter_list = ["vars.{0}".format(v) for v in range(6)]

                logger.info("Filter list = %s", ",".join(filter_list))

                # Collect the parameters (and their names) still marked
                # learnable.
                list_of_params = list(
                    filter(lambda x: x.learn, maml.parameters()))
                list_of_names = list(
                    filter(lambda x: x[1].learn, maml.named_parameters()))
                if args.scratch or args.no_freeze:
                    print("Empty filter list")
                    list_of_params = maml.parameters()

                for x in list_of_names:
                    logger.info("Unfrozen layer = %s", str(x[0]))
                opt = torch.optim.Adam(list_of_params, lr=lr)

                for _ in range(0, args.epoch):
                    for img, y in iterator_sorted:
                        img = img.to(device)
                        y = y.to(device)
                        pred = maml(img)
                        opt.zero_grad()
                        loss = F.cross_entropy(pred, y)
                        loss.backward()
                        opt.step()

                logger.info("Result after one epoch for LR = %f", lr)

                correct = 0
                for img, target in iterator:
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml(img,
                                    vars=None,
                                    bn_training=False,
                                    feature=False)

                    pred_q = (logits_q).argmax(dim=1)

                    correct += torch.eq(pred_q, target).sum().item() / len(img)

                logger.info(str(correct / len(iterator)))
                if (correct / len(iterator) > max_acc):
                    max_acc = correct / len(iterator)
                    max_lr = lr

                lr_list = [max_lr]
                results_mem_size[mem_size] = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc,
                                  tot_class)
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            print("FINAL RESULTS = ", final_results_all)

    writer.close()
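
The first stage above repeats the LR sweep ten times and then commits to the most frequent winner. A dependency-free sketch of that selection step (using collections.Counter rather than scipy.stats.mode, whose return shape varies across SciPy versions):

from collections import Counter

def modal_lr(lr_all):
    """Return the learning rate chosen most often across sweeps."""
    return Counter(lr_all).most_common(1)[0][0]

# Example: winners from five independent sweeps.
print(modal_lr([0.0003, 0.0001, 0.0003, 0.0006, 0.0003]))  # -> 0.0003
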
Code example #12
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "./evals/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")


    # Select the most recent saved version of the model directory.
    ver = 0
    while os.path.exists(args.modelX + "_" + str(ver)):
        ver += 1

    args.modelX = args.modelX + "_" + str(ver - 1) + "/learner.model"

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)

    total_ff_vars = 2*(6 + 2 + args.num_extra_dense_layers)

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))

    for temp in range(args.rln_end * 2):
        frozen_layers.append("net.vars." + str(total_ff_vars - 1 - temp))
    # logger.info("Frozen layers = %s", " ".join(frozen_layers))

    final_results_all = []

    total_clases = [5]

    if args.twentyclass:
        total_clases = [20]
    if args.twotask:
        total_clases = [2, 10]
    if args.fiftyclass:
        total_clases = [50]
    if args.tenclass:
        total_clases = [10]
    if args.fiveclass:
        total_clases = [5]
    print("Class schedule:", total_clases)
    for tot_class in total_clases:
        avg_perf = 0.0
        print("Total classes:", tot_class)
        # The gradient optimizer below is constructed but never stepped;
        # adaptation happens through the Oja/Hebbian fast-weight updates,
        # so the LR sweep collapses to a single dummy value.
        lr_list = [0]  # [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        for aoo in range(0, args.runs):
            keep = np.random.choice(list(range(650)), tot_class, replace=False)
            if args.dataset == "imagenet":
                keep = np.random.choice(list(range(20)), tot_class, replace=False)
                
                dataset = imgnet.MiniImagenet(args.imagenet_path, mode='test', elem_per_class=30, classes=keep, seed=aoo)

                dataset_test = imgnet.MiniImagenet(args.imagenet_path, mode='test', elem_per_class=30, classes=keep, test=args.test, seed=aoo)


                iterator = torch.utils.data.DataLoader(dataset_test, batch_size=128,
                                                       shuffle=True, num_workers=1)
                iterator_sorted = torch.utils.data.DataLoader(dataset, batch_size=1,
                                                       shuffle=False, num_workers=1)
            if args.dataset == "omniglot":

                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot", train=True, background=False), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter_omni(dataset, False, classes=total_clases),
                    batch_size=1,
                    shuffle=False, num_workers=2)
                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot", train=not args.test, background=False), keep)
                iterator = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                       shuffle=False, num_workers=1)
            elif args.dataset == "CIFAR100":
                keep = np.random.choice(list(range(50, 100)), tot_class)
                dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=True), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter(dataset, False, classes=tot_class),
                    batch_size=16,
                    shuffle=False, num_workers=2)
                dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=False), keep)
                iterator = torch.utils.data.DataLoader(dataset, batch_size=128,
                                                       shuffle=False, num_workers=1)
            # sampler = ts.MNISTSampler(list(range(0, total_clases)), dataset)
            # print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            #print("LEN", len(iterator_sorted))
            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:
                    #torch.cuda.empty_cache()
                    #print(lr)
                    # for lr in [0.001, 0.0003, 0.0001, 0.00003, 0.00001]:
                    maml = torch.load(args.modelX, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model(args.model_type, args.dataset)
                        maml = learner.Learner(config, lr)
                        # maml = MetaLearingClassification(args, config).to(device).net

                    #maml.update_lr = lr
                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        # if name.find("feedback_strength_vars") != -1:
                        #     print(name, param)
                        if name in frozen_layers:
                            # logger.info("Freezing name %s", str(name))
                            param.learn = False
                            # logger.info(str(param.requires_grad))
                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                # logger.info("W shape = %s", str(len(w.shape)))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                    # Head re-initialization and feedback-variable resets are
                    # disabled in this version of the script; kept here for
                    # reference:
                    # torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    # w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    # maml.parameters()[-1].data = w
                    # vars_{14 + 2*num_extra_dense_layers} and
                    # vars_{15 + 2*num_extra_dense_layers} are the head weight
                    # and bias that would be re-initialized.
                    # for fv in maml.feedback_vars:
                    #     fv.data = torch.zeros_like(fv)
                    # for fv in maml.feedback_strength_vars:
                    #     fv.data = torch.ones_like(fv)

                    correct = 0

                    for img, target in iterator:
                        # Remap raw class ids to 0..tot_class-1, following the
                        # order of the sampled `keep` classes.
                        target = torch.tensor(np.array([list(keep).index(int(target.cpu().numpy()[i])) for i in range(target.size()[0])]))
                        with torch.no_grad():
                            img = img.to(device)
                            target = target.to(device)
                            logits_q = maml(img, vars=None, bn_training=False, feature=False)
                            pred_q = (logits_q).argmax(dim=1)
                            correct += torch.eq(pred_q, target).sum().item() / len(img)

                    #logger.info("Pre-epoch accuracy %s", str(correct / len(iterator)))

                    filter_list = ["vars.0", "vars.1", "vars.2", "vars.3", "vars.4", "vars.5"]

                    #logger.info("Filter list = %s", ",".join(filter_list))
                    # Collect the parameters (and their names) still marked
                    # learnable.
                    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(filter(lambda x: x[1].learn, maml.named_parameters()))
                    if args.scratch or args.no_freeze:
                        print("Empty filter list")
                        list_of_params = maml.parameters()
                    # for x in list_of_names:
                    #     logger.info("Unfrozen layer = %s", str(x[0]))
                    opt = torch.optim.Adam(list_of_params, lr=lr)

                    fast_weights = None
                    if args.randomize_plastic_weights:
                        maml.randomize_plastic_weights()
                    if args.zero_plastic_weights:
                        maml.zero_plastic_weights()
                    res_sampler = rep.ReservoirSampler(mem_size)
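                    # The reservoir keeps a fixed-size, approximately uniform
                    # sample of the stream seen so far; when args.memory > 0 a
                    # replay batch drawn from it is mixed into every update below.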
                    iterator_sorted_new = []
                    iter_count = 0
                    for img, y in iterator_sorted:
                        # Remap labels and keep only the first args.shots
                        # images of each class (the sorted stream yields 15
                        # images per class).
                        y = torch.tensor(np.array([list(keep).index(int(y.cpu().numpy()[i])) for i in range(y.size()[0])]))
                        if iter_count % 15 >= args.shots:
                            iter_count += 1
                            continue
                        iterator_sorted_new.append((img, y))
                        iter_count += 1
                    # Shuffle for the i.i.d. control; otherwise keep the
                    # class-sorted (continual) order.
                    iterator_sorted = []
                    perm = np.random.permutation(len(iterator_sorted_new))
                    for i in range(len(iterator_sorted_new)):
                        if args.iid:
                            iterator_sorted.append(iterator_sorted_new[perm[i]])
                        else:
                            iterator_sorted.append(iterator_sorted_new[i])

                    for _ in range(0, args.epoch):
                        imgs = []
                        ys = []

                        for img, y in iterator_sorted:
                            if args.memory == 0:
                                img = img.to(device)
                                y = y.to(device)
                            else:
                                # Mix the current batch with a replay batch
                                # drawn from the reservoir buffer.
                                res_sampler.update_buffer(zip(img, y))
                                res_sampler.update_observations(len(img))
                                img = img.to(device)
                                y = y.to(device)
                                img2, y2 = res_sampler.sample_buffer(8)
                                img2 = img2.to(device)
                                y2 = y2.to(device)
                                img = torch.cat([img, img2], dim=0)
                                y = torch.cat([y, y2], dim=0)

                            imgs.append(img)
                            ys.append(y)
                            if not args.batch_learning:
                                logits = maml(img, vars=fast_weights)
                                fast_weights = maml.getOjaUpdate(
                                    y, logits, fast_weights, hebbian=args.hebb)
                        if args.batch_learning:
                            y = torch.cat(ys, 0)
                            img = torch.cat(imgs, 0)
                            logits = maml(img, vars=fast_weights)
                            fast_weights = maml.getOjaUpdate(
                                y, logits, fast_weights, hebbian=args.hebb)


                    #logger.info("Result after one epoch for LR = %f", lr)
                    correct = 0
                    for img, target in iterator:
                        target = torch.tensor(np.array([list(keep).index(int(target.cpu().numpy()[i])) for i in range(target.size()[0])]))
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img, vars=fast_weights, bn_training=False, feature=False)

                        pred_q = (logits_q).argmax(dim=1)

                        correct += torch.eq(pred_q, target).sum().item() / len(img)

                    #logger.info(str(correct / len(iterator)))
                    if (correct / len(iterator) > max_acc):
                        max_acc = correct / len(iterator)
                        max_lr = lr
                        
                    del maml

                lr_list = [max_lr]
                #print('result', max_acc)
                results_mem_size[mem_size] = (max_acc, max_lr)
                #logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc, tot_class)
                avg_perf += max_acc / args.runs  # TODO: change this if/when I ever use memory -- can't choose max memory size differently for each run!
                print('Running average accuracy:', avg_perf * args.runs / (1 + aoo))
            final_results_all.append((tot_class, results_mem_size))
            #writer.add_scalar('performance', avg_perf, tot_class)
            #print("A=  ", results_mem_size)
            #logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            np.save('evals/final_results_'+args.orig_name+'.npy', final_results_all) 
            #print("FINAL RESULTS = ", final_results_all)
    writer.close()
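
getOjaUpdate above is specific to this repository's Learner. As a rough illustration of the underlying idea only, here is a minimal Oja-rule fast-weight update for a single linear layer (all names are illustrative, not the repo's API):

import torch

def oja_step(W, x, eta=0.01):
    # One Oja's-rule update for W of shape (out_dim, in_dim) on input x:
    # dW = eta * (y x^T - diag(y^2) W), a Hebbian term plus a decay term
    # that keeps the weight rows bounded.
    y = W @ x
    dW = eta * (torch.outer(y, x) - (y ** 2).unsqueeze(1) * W)
    return W + dW

# Example: repeated presentations of the same input gradually align and
# normalize the weight rows.
torch.manual_seed(0)
W = torch.randn(5, 10) * 0.1
x = torch.randn(10)
for _ in range(100):
    W = oja_step(W, x)
print(W.norm(dim=1))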