Example #1
def main():
    p = class_parser.Parser()
    total_seeds = len(p.parse_known_args()[0].seed)
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])
    print("All args = ", all_args)

    args = utils.get_run(vars(p.parse_known_args()[0]), rank)


    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'], args, "../results/", commit_changes=False, rank=0, seed=1)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args['classes'] = list(range(963))

    args['traj_classes'] = list(range(int(963/2), 963))


    dataset = df.DatasetFactory.get_dataset(args['dataset'], background=True, train=True, path=args["path"], all=True)
    dataset_test = df.DatasetFactory.get_dataset(args['dataset'], background=True, train=False, path=args["path"], all=True)

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test, batch_size=5,
                                                shuffle=True, num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset, batch_size=5,
                                                 shuffle=True, num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(args['dataset'], args['classes'], dataset, dataset_test)

    config = mf.ModelFactory.get_model("na", args['dataset'], output_dimension=1000)

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)


    for step in range(args['steps']):
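        # Meta-training step (a hedged reading of this snippet): sample a few trajectory
        # classes for the inner-loop support data and a random batch for the meta/query
        # loss, then run one meta-update through maml(...).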

        t1 = np.random.choice(args['traj_classes'], args['tasks'], replace=False)
        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))
        d_rand_iterator = sampler.get_complete_iterator()
        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(d_traj_iterators, d_rand_iterator,
                                                               steps=args['update_step'], reset=not args['no_reset'])
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)
Example #2
def main():
    p = params.Parser()
    total_seeds = len(p.parse_known_args()[0].seed)
    _args = p.parse_args()
    # rank = p.parse_known_args()[0].rank
    rank = _args.rank
    # all_args = vars(p.parse_known_args()[0])
    print("All args = ", _args)

    args = utils.get_run(vars(_args), rank)

    utils.set_seed(args["seed"])

    if args["log_root"]:
        log_root = osp.join("./results", args["log_root"]) + "/"
    else:
        log_root = osp.join("./results/")

    my_experiment = experiment(
        args["name"],
        args,
        log_root,
        commit_changes=False,
        rank=0,
        seed=args["seed"],
    )
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(gpu_to_use))
        logger.info("Using gpu : %s", "cuda:" + str(gpu_to_use))
    else:
        device = torch.device("cpu")

    print("Train dataset")
    dataset = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        all=True,
        resize=args["resize"],
        augment=args["augment"],
        prefetch_gpu=args["prefetch_gpu"],
    )
    print("Val dataset")
    val_dataset = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        all=True,
        resize=args["resize"],
        prefetch_gpu=args["prefetch_gpu"],
        #  augment=args["augment"],
    )

    train_labels = np.arange(664)
    # class_labels = np.array(dataset.targets)
    class_labels = np.asarray(torch.as_tensor(dataset.targets, device="cpu"))
    labels_mapping = {
        tl: (class_labels == tl).astype(int).nonzero()[0] for tl in train_labels
    }
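    # Split each class: the first 15 samples go to training, the remainder to validation
    # (assuming the standard 20 samples per Omniglot character, that leaves 5 per class).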
    train_indices = [tl[:15] for tl in labels_mapping.values()]
    val_indices = [tl[15:] for tl in labels_mapping.values()]
    train_indices = [i for sublist in train_indices for i in sublist]
    val_indices = [i for sublist in val_indices for i in sublist]

    # indices = np.zeros_like(class_labels)
    # for a in train_labels:
    #     indices = indices + (class_labels == a).astype(int)
    # val_indices = (indices == 0).astype(int)
    # indices = np.nonzero(indices)[0]
    trainset = torch.utils.data.Subset(dataset, train_indices)

    # print(indices)
    print("Total samples:", len(class_labels))
    print("Train samples:", len(train_indices))
    print("Val samples:", len(val_indices))

    #  val_labels = np.arange(664)
    # class_labels = np.array(dataset.targets)
    # val_indices = np.zeros_like(class_labels)
    # for a in train_labels:
    #     val_indices = val_indices + (class_labels != a).astype(int)
    # val_indices = np.nonzero(val_indices)[0]
    valset = torch.utils.data.Subset(val_dataset, val_indices)

    train_iterator = torch.utils.data.DataLoader(
        trainset,
        batch_size=64,
        shuffle=True,
        num_workers=0,
        drop_last=True,
    )
    val_iterator = torch.utils.data.DataLoader(
        valset,
        batch_size=256,
        shuffle=True,
        num_workers=0,
        drop_last=False,
    )

    logger.info("Args:")
    logger.info(str(vars(_args)))
    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args["dataset"], resize=args["resize"])

    maml = learner.Learner(config).to(device)

    for k, v in maml.named_parameters():
        print(k, v.requires_grad)

    # opt = torch.optim.Adam(maml.parameters(), lr=args["lr"])
    opt = torch.optim.SGD(
        maml.parameters(),
        lr=args["lr"],
        momentum=0.9,
        weight_decay=5e-4,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        opt,
        milestones=_args.schedule,
        gamma=0.1,
    )

    best_val_acc = 0

    # print(learner)
    # print(learner.eval(False))

    histories = {
        "train": {"acc": [], "loss": [], "step": []},
        "val": {"acc": [], "loss": [], "step": []},
    }

    for e in range(args["epoch"]):
        correct = 0
        total_loss = 0.0
        maml.train()
        for img, y in tqdm(train_iterator):
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)

            opt.zero_grad()
            loss = F.cross_entropy(pred, y.long())
            loss.backward()
            opt.step()
            correct += (pred.argmax(1) == y).float().mean()
            total_loss += loss.detach()  # detach so the graph is not retained across the epoch
        correct = correct.item()
        total_loss = total_loss.item()
        scheduler.step()

        val_correct = 0
        val_total_loss = 0.0
        maml.eval()
        for img, y in tqdm(val_iterator):
            img = img.to(device)
            y = y.to(device)
            with torch.no_grad():
                pred = maml(img)

                opt.zero_grad()
                loss = F.cross_entropy(pred, y.long())
                # loss.backward()
                # opt.step()
                val_correct += (pred.argmax(1) == y).sum().float()
                val_total_loss += loss * y.size(0)
        val_correct = val_correct.item()
        val_total_loss = val_total_loss.item()
        val_acc = val_correct / len(val_indices)
        val_loss = val_total_loss / len(val_indices)

        train_correct = correct / len(train_iterator)
        train_loss = total_loss / len(train_iterator)

        logger.info("Accuracy at epoch %d = %s", e, str(train_correct))
        logger.info("Loss at epoch %d = %s", e, str(train_loss))
        logger.info("Val Accuracy at epoch %d = %s", e, str(val_acc))
        logger.info("Val Loss at epoch %d = %s", e, str(val_loss))

        histories["train"]["acc"].append(train_correct)
        histories["train"]["loss"].append(train_loss)
        histories["val"]["acc"].append(val_acc)
        histories["val"]["loss"].append(val_loss)
        histories["train"]["step"].append(e + 1)
        histories["val"]["step"].append(e + 1)

        writer.add_scalar(
            "/train/accuracy",
            train_correct,
            e + 1,
        )
        writer.add_scalar(
            "/train/loss",
            train_loss,
            e + 1,
        )
        writer.add_scalar(
            "/val/accuracy",
            val_acc,
            e + 1,
        )
        writer.add_scalar(
            "/train/loss",
            val_loss,
            e + 1,
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            logger.info(f"\nNew best validation accuracy: {str(best_val_acc)}\n")
            torch.save(maml, my_experiment.path + "model_best.net")

    with open(my_experiment.path + "results.json", "w") as f:
        json.dump(histories, f)
    torch.save(maml, my_experiment.path + "last_model.net")
Example #3
from copy import deepcopy

gamma = 0.9

logger = logging.getLogger('experiment')

p = reg_parser.Parser()
total_seeds = len(p.parse_known_args()[0].seed)
rank = p.parse_known_args()[0].run
all_args = vars(p.parse_known_args()[0])

args = utils.get_run(all_args, rank)

my_experiment = experiment(args["name"],
                           args,
                           args["output_dir"],
                           sql=True,
                           run=int(rank / total_seeds),
                           seed=total_seeds)

my_experiment.results["all_args"] = all_args
my_experiment.make_table("error_table", {
    "run": 0,
    "step": 0,
    "error": 0.0
}, ("run", "step"))
my_experiment.make_table(
    "predictions", {
        "run": 0,
        "step": 0,
        "x0": 0,
        "x1": 0,
Example #4
def main():
    p = reg_parser.Parser()
    total_seeds = len(p.parse_known_args()[0].seed)
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])

    args = utils.get_run(all_args, rank)

    my_experiment = experiment(args["name"], args, args["output_dir"], commit_changes=False,
                               rank=int(rank / total_seeds),
                               seed=total_seeds)

    my_experiment.results["all_args"] = all_args



    logger = logging.getLogger('experiment')

    gradient_error_list = []
    gradient_alignment_list = []

    for seed in range(args["runs"]):
        utils.set_seed(args["seed"] + seed + seed*args["seed"])
        n = Recurrent_Network(50, args['columns'], args["width"],
                              args["sparsity"])
        error_grad_mc = 0

        rnn_state = torch.zeros(args['columns'])
        n.reset_TH()
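        # Over a 50-step sequence, accumulate the recurrent network's own incremental
        # gradient estimate (update_TH / accumulate_gradients) while also summing the
        # true error so autograd can later compute the exact gradient for comparison.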

        for ind in range(50):

            x = torch.bernoulli(torch.zeros(1, 50) + 0.5)

            _, _, grads = n.forward(x, rnn_state, grad=True, retain_graph=False, bptt=False)

            value_prediction, rnn_state, _ = n.forward(x, rnn_state, grad=False,
                                                       retain_graph=False, bptt=True)

            n.update_TH(grads)

            target_random = random.random() * 100 - 50
            real_error = (0.5) * (target_random - value_prediction) ** 2
            error_grad_mc += real_error

            n.accumulate_gradients(target_random, value_prediction, hidden_state=rnn_state)

        grads = torch.autograd.grad(error_grad_mc, n.parameters())

        counter = 0
        total_sum = 0
        positive_sum = 0
        dif = 0
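        # Compare the hand-computed gradients (n.grads) with autograd's: `dif` is the
        # total absolute difference, and positive/total count how often the signs agree.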

        for named, param in n.named_parameters():
            # if "prediction" in named:
            #     counter+=1
            #     continue
            # print(named)
            # print(grads[counter], n.grads[named])
            dif += torch.abs(n.grads[named] - grads[counter]).sum()
            positive = ((n.grads[named] * grads[counter]) > 1e-10).float().sum()
            total = positive + ((n.grads[named] * grads[counter]) < - 1e-10).float().sum()
            total_sum += total
            positive_sum += positive

            counter += 1

        logger.error("Difference = %s", (float(dif) / total_sum).item())
        gradient_error_list.append((float(dif) / total_sum).item())
        gradient_alignment_list.append(str(float(positive_sum) / float(total_sum)))
        logger.error("Grad alignment %s", str(float(positive_sum) / float(total_sum)))




        my_experiment.add_result("abs_error", str(gradient_error_list))
        my_experiment.add_result("alignment", str(gradient_alignment_list))

        my_experiment.store_json()
Example #5
from experiment.experiment import experiment
from utils.g_and_t_utils import *

p = params.PwbParameters()
total_seeds = len(p.parse_known_args()[0].seed)
rank = p.parse_known_args()[0].rank
all_args = vars(p.parse_known_args()[0])
print("All hyperparameters = ", all_args)

flags = utils.get_run(vars(p.parse_known_args()[0]), rank)

utils.set_seed(flags["seed"])

my_experiment = experiment(flags["name"],
                           flags,
                           flags['output_dir'],
                           commit_changes=False,
                           rank=int(rank / total_seeds),
                           seed=total_seeds)

my_experiment.results["all_args"] = all_args

logger = logging.getLogger('experiment')

logger.info("Selected hyperparameters %s", str(flags))

device = torch.device('cpu')

final_train_accs = []
final_test_accs = []

mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True)
Example #6
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "/data5/jlindsey/continual/results", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)
    total_clases = 10

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))

    #
    final_results_all = []
    total_clases = [10, 50, 75, 100, 150, 200]
    if args.twentyclass:
        total_clases = [20, 50]
    if args.fiveclass:
        total_clases = [5]
    for tot_class in total_clases:
        avg_perf = 0.0
        lr_list = [0.03]
        for aoo in range(0, args.runs):

            keep = np.random.choice(list(range(200)), tot_class, replace=False)
            
            print('keep', keep)

            if args.dataset == "omniglot":

                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot", train=True, background=False), keep)
                print('lenbefore', len(dataset.data))
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter_omni(dataset, False, classes=total_clases),
                    batch_size=1,
                    shuffle=args.iid, num_workers=2)
                print("LEN", len(iterator_sorted), len(dataset.data))
                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot", train=not args.test, background=False), keep)
                iterator = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                       shuffle=False, num_workers=1)
            elif args.dataset == "CIFAR100":
                keep = np.random.choice(list(range(50, 100)), tot_class)
                dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=True), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter(dataset, False, classes=tot_class),
                    batch_size=16,
                    shuffle=args.iid, num_workers=2)
                dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=False), keep)
                iterator = torch.utils.data.DataLoader(dataset, batch_size=128,
                                                       shuffle=False, num_workers=1)
            # sampler = ts.MNISTSampler(list(range(0, total_clases)), dataset)
            #
            print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}
            
           

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:

                    print(lr)
                    # for lr in [0.001, 0.0003, 0.0001, 0.00003, 0.00001]:
                    maml = torch.load(args.model, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model("na", args.dataset)
                        maml = learner.Learner(config)
                        # maml = MetaLearingClassification(args, config).to(device).net

                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        # logger.info(name)
                        if name in frozen_layers:
                            # logger.info("Freeezing name %s", str(name))
                            param.learn = False
                            # logger.info(str(param.requires_grad))
                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                # logger.info("W shape = %s", str(len(w.shape)))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                        
                    '''
                    torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    maml.parameters()[-1].data = w

                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")
                        # logger.info("Name = %s", n)
                        if n == "vars_14":
                            w = nn.Parameter(torch.ones_like(a))
                            # logger.info("W shape = %s", str(w.shape))
                            torch.nn.init.kaiming_normal_(w)
                            a.data = w
                        if n == "vars_15":
                            w = nn.Parameter(torch.zeros_like(a))
                            a.data = w
                            
                    '''


                    correct = 0

                    for img, target in iterator:
                        with torch.no_grad():
                            img = img.to(device)
                            target = target.to(device)
                            logits_q = maml(img, vars=None, bn_training=False, feature=False)
                            pred_q = (logits_q).argmax(dim=1)
                            correct += torch.eq(pred_q, target).sum().item() / len(img)

                    logger.info("Pre-epoch accuracy %s", str(correct / len(iterator)))

                    filter_list = ["vars.0", "vars.1", "vars.2", "vars.3", "vars.4", "vars.5"]

                    logger.info("Filter list = %s", ",".join(filter_list))
                    list_of_names = list(
                        map(lambda x: x[1], list(filter(lambda x: x[0] not in filter_list, maml.named_parameters()))))

                    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(filter(lambda x: x[1].learn, maml.named_parameters()))
                    if args.scratch or args.no_freeze:
                        print("Empty filter list")
                        list_of_params = maml.parameters()
                    #
                    for x in list_of_names:
                        logger.info("Unfrozen layer = %s", str(x[0]))
                    opt = torch.optim.Adam(list_of_params, lr=lr)

                    fast_weights = maml.vars
                    if args.randomize_plastic_weights:
                        maml.randomize_plastic_weights()
                    if args.zero_plastic_weights:
                        maml.zero_plastic_weights()
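                    # Inner-loop adaptation on the sorted stream: fast_weights is updated
                    # with manual SGD steps (optionally scaled by per-parameter plasticity),
                    # while parameters marked learn=False are carried over unchanged.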
                    for iter in range(0, args.epoch):
                        iter_count = 0
                        imgs = []
                        ys = []
                        for img, y in iterator_sorted:
                            #print(iter_count, y)
                            if iter_count % 15 >= args.shots:
                                iter_count += 1
                                continue
                            iter_count += 1
                            img = img.to(device)
                            y = y.to(device)
                            
                            imgs.append(img)
                            ys.append(y)
                            

                            if not args.batch_learning:
                                pred = maml(img, vars=fast_weights)
                                opt.zero_grad()
                                loss = F.cross_entropy(pred, y)
                                grad = torch.autograd.grad(loss, fast_weights)
                                # fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

                                if args.plastic_update:
                                    fast_weights = list(
                                        map(lambda p: p[1] - p[0] * p[2] if p[1].learn else p[1], zip(grad, fast_weights, maml.vars_plasticity)))       
                                else:
                                    fast_weights = list(
                                        map(lambda p: p[1] - args.update_lr * p[0] if p[1].learn else p[1], zip(grad, fast_weights)))
                                for params_old, params_new in zip(maml.parameters(), fast_weights):
                                    params_new.learn = params_old.learn
                        if args.batch_learning:
                            y = torch.cat(ys, 0)
                            img = torch.cat(imgs, 0)
                            pred = maml(img, vars=fast_weights)
                            opt.zero_grad()
                            loss = F.cross_entropy(pred, y)
                            grad = torch.autograd.grad(loss, fast_weights)
                            # fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

                            if args.plastic_update:
                                fast_weights = list(
                                    map(lambda p: p[1] - p[0] * p[2] if p[1].learn else p[1], zip(grad, fast_weights, maml.vars_plasticity)))       
                            else:
                                fast_weights = list(
                                    map(lambda p: p[1] - args.update_lr * p[0] if p[1].learn else p[1], zip(grad, fast_weights)))
                            for params_old, params_new in zip(maml.parameters(), fast_weights):
                                params_new.learn = params_old.learn
                            #loss.backward()
                            #opt.step()

                    logger.info("Result after one epoch for LR = %f", lr)
                    correct = 0
                    for img, target in iterator:
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img, vars=fast_weights, bn_training=False, feature=False)

                        pred_q = (logits_q).argmax(dim=1)

                        correct += torch.eq(pred_q, target).sum().item() / len(img)

                    logger.info(str(correct / len(iterator)))
                    if (correct / len(iterator) > max_acc):
                        max_acc = correct / len(iterator)
                        max_lr = lr

                lr_list = [max_lr]
                results_mem_size[mem_size] = (max_acc, max_lr)
                avg_perf += max_acc / args.runs
                print('avg perf', avg_perf * args.runs / (1+aoo))
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc, tot_class)
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            np.save('evals/final_results_'+args.orig_name+'.npy', final_results_all) 
            print("FINAL RESULTS = ", final_results_all)
    writer.close()
Example #7
def main(args):
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    np.random.seed(args.seed)

    my_experiment = experiment(args.name, args, "./results/")

    args.classes = list(range(64))

    # args.traj_classes = list(range(int(64 / 2), 963))

    dataset = imgnet.MiniImagenet(args.dataset_path, mode='train')

    dataset_test = imgnet.MiniImagenet(args.dataset_path, mode='test')

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator = torch.utils.data.DataLoader(dataset,
                                           batch_size=128,
                                           shuffle=True,
                                           num_workers=1)

    #
    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args.dataset)

    maml = learner.Learner(config).to(device)

    opt = torch.optim.Adam(maml.parameters(), lr=args.lr)

    for e in range(args.epoch):
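        # Plain supervised pre-training of the representation; the learning rate is
        # dropped to 1e-5 once epoch 50 is reached (see the check below).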
        correct = 0
        for img, y in tqdm(iterator):
            if e == 50:
                opt = torch.optim.Adam(maml.parameters(), lr=0.00001)
                logger.info("Changing LR from %f to %f", 0.0001, 0.00001)
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)
            feature = maml(img, feature=True)
            loss_rep = torch.abs(feature).sum()

            opt.zero_grad()
            loss = F.cross_entropy(pred, y)
            # loss_rep.backward(retain_graph=True)
            # logger.info("L1 norm = %s", str(loss_rep.item()))
            loss.backward()
            opt.step()
            correct += (pred.argmax(1) == y).sum().float() / len(y)
        logger.info("Accuracy at epoch %d = %s", e,
                    str(correct / len(iterator)))

        # correct = 0
        # with torch.no_grad():
        #     for img, y in tqdm(iterator_test):
        #
        #         img = img.to(device)
        #         y = y.to(device)
        #         pred = maml(img)
        #         feature = maml(img, feature=True)
        #         loss_rep = torch.abs(feature).sum()
        #
        #         correct += (pred.argmax(1) == y).sum().float() / len(y)
        #     logger.info("Accuracy Test at epoch %d = %s", e, str(correct / len(iterator_test)))

        torch.save(maml,
                   my_experiment.path + "baseline_pretraining_imagenet.net")
Example #8
def main():
    p = class_parser_eval.Parser()
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])
    print("All args = ", all_args)

    args = utils.get_run(vars(p.parse_known_args()[0]), rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'], args, "../results/", commit_changes=False, rank=0, seed=1)

    data_train = df.DatasetFactory.get_dataset("omniglot", train=True, background=False, path=args['path'])
    data_test = df.DatasetFactory.get_dataset("omniglot", train=False, background=False, path=args['path'])
    final_results_train = []
    final_results_test = []
    lr_sweep_results = []

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    config = mf.ModelFactory.get_model("na", args['dataset'], output_dimension=1000)

    maml = load_model(args, config)
    maml = maml.to(device)

    args['schedule'] = [int(x) for x in args['schedule'].split(":")]
    no_of_classes_schedule = args['schedule']
    print(args["schedule"])
    for total_classes in no_of_classes_schedule:
        lr_sweep_range = [0.03, 0.01, 0.003,0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        lr_all = []
        for lr_search_runs in range(0, 5):

            classes_to_keep = np.random.choice(list(range(650)), total_classes, replace=False)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)

            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset, False, classes=no_of_classes_schedule),
                batch_size=1,
                shuffle=args['iid'], num_workers=2)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)
            iterator_train = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                         shuffle=False, num_workers=1)

            max_acc = -1000
            for lr in lr_sweep_range:

                maml.reset_vars()

                opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

                train_iterator(iterator_sorted, device, maml, opt)

                correct = eval_iterator(iterator_train, device, maml)
                if (correct > max_acc):
                    max_acc = correct
                    max_lr = lr

            lr_all.append(max_lr)
            results_mem_size = (max_acc, max_lr)
            lr_sweep_results.append((total_classes, results_mem_size))

            my_experiment.results["LR Search Results"] = lr_sweep_results
            my_experiment.store_json()
            logger.debug("LR RESULTS = %s", str(lr_sweep_results))

        from scipy import stats
        best_lr = float(stats.mode(lr_all)[0][0])

        logger.info("BEST LR %s= ", str(best_lr))

        for current_run in range(0, args['runs']):

            classes_to_keep = np.random.choice(list(range(650)), total_classes, replace=False)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)

            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset, False, classes=no_of_classes_schedule),
                batch_size=1,
                shuffle=args['iid'], num_workers=2)

            dataset = utils.remove_classes_omni(data_test, classes_to_keep)
            iterator_test = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                        shuffle=False, num_workers=1)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)
            iterator_train = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                         shuffle=False, num_workers=1)

            lr = best_lr

            maml.reset_vars()

            opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

            train_iterator(iterator_sorted, device,maml, opt)

            logger.info("Result after one epoch for LR = %f", lr)

            correct = eval_iterator(iterator_train, device, maml)

            correct_test = eval_iterator(iterator_test, device, maml)

            results_mem_size = (correct, best_lr, "train")
            logger.info("Final Max Result train = %s", str(correct))
            final_results_train.append((total_classes, results_mem_size))

            results_mem_size = (correct_test, best_lr, "test")
            logger.info("Final Max Result test= %s", str(correct_test))
            final_results_test.append((total_classes, results_mem_size))

            my_experiment.results["Final Results"] = final_results_train
            my_experiment.results["Final Results Test"] = final_results_test
            my_experiment.store_json()
            logger.debug("FINAL RESULTS = %s", str(final_results_train))
            logger.debug("FINAL RESULTS = %s", str(final_results_test))
Example #9
def main(args):
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger("experiment")

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    if torch.cuda.is_available():
        device = torch.device("cuda")
        use_cuda = True
    else:
        device = torch.device("cpu")
        use_cuda = False

    dataset = df.DatasetFactory.get_dataset(
        args.dataset,
        background=True,
        train=True,
        all=True,
        prefetch_gpu=args.prefetch_gpu,
        device=device,
    )
    dataset_test = dataset
    # dataset_test = df.DatasetFactory.get_dataset(
    #     args.dataset, background=True, train=False, all=True
    # )

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(
        args.dataset,
        args.classes,
        dataset,
        dataset_test,
        prefetch_gpu=args.prefetch_gpu,
        use_cuda=use_cuda,
    )

    config = mf.ModelFactory.get_model(args.treatment, args.dataset)

    maml = MetaLearingClassification(args, config, args.treatment).to(device)

    if args.checkpoint:
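        # Warm-start: copy weights tensor-by-tensor from the saved model (this assumes
        # the checkpoint exposes parameters() in the same order as maml.net).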
        checkpoint = torch.load(args.saved_model, map_location="cpu")

        for idx in range(len(checkpoint)):
            maml.net.parameters()[idx].data = checkpoint.parameters()[idx].data

    maml = maml.to(device)

    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):

        t1 = np.random.choice(args.classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset,
        )
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = (
                x_spt.cuda(),
                y_spt.cuda(),
                x_qry.cuda(),
                y_qry.cuda(),
            )

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)  # , args.tasks)

        # Evaluation during training for sanity checks
        if step % 40 == 0:
            # writer.add_scalar('/metatrain/train/accuracy', accs, step)
            logger.info("step: %d \t training acc %s", step, str(accs))
        if step % 100 == 0 or step == 19999:
            torch.save(maml.net, args.model_name)
        if step % 2000 == 0 and step != 0:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)
Example #10
def main():
    p = reg_parser.Parser()
    total_seeds = len(p.parse_known_args()[0].seed)
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])
    args = utils.get_run(vars(p.parse_known_args()[0]), rank)
    utils.set_seed(args["seed"])

    my_experiment = experiment(args["name"],
                               args,
                               "../results/",
                               commit_changes=False,
                               rank=int(rank / total_seeds),
                               seed=total_seeds)

    my_experiment.results["all_args"] = all_args
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')
    pprint(args)

    tasks = list(range(400))

    sampler = ts.SamplerFactory.get_sampler("Sin",
                                            tasks,
                                            None,
                                            capacity=args["capacity"] + 1)

    model_config = mf.ModelFactory.get_model(args["model"],
                                             "Sin",
                                             input_dimension=args["capacity"] +
                                             1,
                                             output_dimension=1,
                                             width=args["width"])
    context_backbone_config = None
    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    metalearner = MetaLearnerRegression(args, model_config,
                                        context_backbone_config).to(device)
    tmp = filter(lambda x: x.requires_grad, metalearner.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logger.info('Total trainable tensors: %d', num)
    #
    running_meta_loss = 0
    adaptation_loss = 0
    loss_history = []
    adaptation_loss_history = []
    adaptation_running_loss_history = []
    meta_steps_counter = 0
    LOG_INTERVAL = 50
    for step in range(args["epoch"]):
        if step % LOG_INTERVAL == 0:
            logger.debug("####\t STEP %d \t####", step)
        net = metalearner.net
        meta_steps_counter += 1
        t1 = np.random.choice(tasks, args["tasks"], replace=False)
        iterators = []

        for t in t1:
            iterators.append(sampler.sample_task([t]))

        x_traj_meta, y_traj_meta, x_rand_meta, y_rand_meta = utils.construct_set(
            iterators, sampler, steps=args["update_step"])
        x_traj_meta = x_traj_meta.view(-1, 1, 51)
        y_traj_meta = y_traj_meta.view(-1, 1, 2)

        if torch.cuda.is_available():
            x_traj_meta, y_traj_meta, x_rand_meta, y_rand_meta = x_traj_meta.to(
                device), y_traj_meta.to(device), x_rand_meta.to(
                    device), y_rand_meta.to(device)

        meta_loss = metalearner(x_traj_meta, y_traj_meta, x_rand_meta,
                                y_rand_meta)
        loss_history.append(meta_loss[-1].detach().cpu().item())
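        # Exponential moving average of the meta loss; dividing by (1 - 0.97**t) below
        # is the usual bias correction for the first few steps.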

        running_meta_loss = running_meta_loss * 0.97 + 0.03 * meta_loss[-1].detach().cpu()
        running_meta_loss_fixed = running_meta_loss / (1 - 0.97 ** meta_steps_counter)
        writer.add_scalar('/metatrain/train/accuracy',
                          meta_loss[-1].detach().cpu(), meta_steps_counter)
        writer.add_scalar('/metatrain/train/runningaccuracy',
                          running_meta_loss_fixed, meta_steps_counter)

        if step % LOG_INTERVAL == 0:
            if running_meta_loss > 0:
                logger.info("Running meta loss = %f",
                            running_meta_loss_fixed.item())

            with torch.no_grad():
                t1 = np.random.choice(tasks, args["tasks"], replace=False)

                iterators = []
                for t in t1:
                    iterators.append(sampler.sample_task([t]))

                x_traj, y_traj, x_rand, y_rand = utils.construct_set(
                    iterators, sampler, steps=args["update_step"])
                x_traj, y_traj = x_traj.view(-1, 1, 51), y_traj.view(-1, 1, 2)
                if torch.cuda.is_available():
                    x_traj, y_traj, x_rand, y_rand = x_traj.to(
                        device), y_traj.to(device), x_rand.to(
                            device), y_rand.to(device)

                logits = net(x_rand[0], vars=None)
                logits_select = []
                assert y_rand[0, :, 1].sum() == 0
                for no, val in enumerate(y_rand[0, :, 1].long()):
                    logits_select.append(logits[no, val])
                logits = torch.stack(logits_select).unsqueeze(1)

                current_adaptation_loss = F.mse_loss(
                    logits, y_rand[0, :, 0].unsqueeze(1))
                adaptation_loss_history.append(
                    current_adaptation_loss.detach().item())
                adaptation_loss = adaptation_loss * 0.97 + current_adaptation_loss.detach().cpu().item() * 0.03
                adaptation_loss_fixed = adaptation_loss / (1 - 0.97 ** (step + 1))
                adaptation_running_loss_history.append(adaptation_loss_fixed)

                logger.info("Adaptation loss = %f", current_adaptation_loss)

                if step % LOG_INTERVAL == 0:
                    logger.info("Running adaptation loss = %f",
                                adaptation_loss_fixed)
                writer.add_scalar('/learn/test/adaptation_loss',
                                  current_adaptation_loss, step)

        if (step + 1) % (LOG_INTERVAL * 500) == 0:
            if not args["no_save"]:
                torch.save(metalearner.net, my_experiment.path + "net.model")
            dict_names = {}
            for (name, param) in metalearner.net.named_parameters():
                dict_names[name] = param.adaptation

            my_experiment.add_result("Layers meta values", dict_names)
            my_experiment.add_result("Meta loss", loss_history)
            my_experiment.add_result("Adaptation loss",
                                     adaptation_loss_history)
            my_experiment.add_result("Running adaption loss",
                                     adaptation_running_loss_history)
            my_experiment.store_json()
Example #11
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    args.classes = list(range(963))

    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=True,
                                                 train=False)

    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    logger.info("Train set length = %d", len(iterator_train) * 5)
    logger.info("Test set length = %d", len(iterator_test) * 5)
    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model("na", args.dataset)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("net.vars." + str(temp))

    for name, param in maml.named_parameters():
        logger.info(name)
        if name in frozen_layers:
            logger.info("Freeezing name %s", str(name))
            param.learn = False

    # Update the classifier
    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
    list_of_names = list(filter(lambda x: x[1].learn, maml.named_parameters()))

    for a in list_of_names:
        logger.info("Unfrozen layers for rep learning = %s", a[0])

    for step in range(args.epoch):

        t1 = np.random.choice(args.classes,
                              np.random.randint(1, args.tasks + 1),
                              replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators, d_rand_iterator, steps=args.update_step)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(), x_qry.cuda(), y_qry.cuda()

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 40 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 0:
            correct = 0
            torch.save(maml.net, my_experiment.path + "learner.model")
            for img, target in iterator_test:
                with torch.no_grad():
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct += torch.eq(pred_q, target).sum().item() / len(img)
            writer.add_scalar('/metatrain/test/classifier/accuracy',
                              correct / len(iterator_test), step)
            logger.info("Test Accuracy = %s",
                        str(correct / len(iterator_test)))
            correct = 0
            for img, target in iterator_train:
                with torch.no_grad():
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)
                    pred_q = (logits_q).argmax(dim=1)
                    correct += torch.eq(pred_q, target).sum().item() / len(img)

            logger.info("Train Accuracy = %s",
                        str(correct / len(iterator_train)))
            writer.add_scalar('/metatrain/train/classifier/accuracy',
                              correct / len(iterator_train), step)
Example #12
def main(args):
    # Placeholder variables
    old_accs = [0]
    old_meta_losses = [2.**30, 0]

    utils.set_seed(args.seed)

    my_experiment = experiment(args.name,
                               args,
                               "./results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))
    args.traj_classes = list(range(int(963 / 2), 963))

    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True,
                                            all=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=True,
                                                 train=False,
                                                 all=True)
    # print("ONE ITEM", len(dataset.__getitem__(0)),dataset.__getitem__(0)[0].shape,dataset.__getitem__(0)[1])
    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)
    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)
    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)
    # print("NUM CLASSES",args.classes)
    config = mf.ModelFactory.get_model("na", args.dataset)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # maml = MetaLearingClassification(args, config).to(device)
    maml = MetaLearingClassification(args, config).to(device)
    utils.freeze_layers(args.rln, maml)  # freeze layers

    for step in range(args.steps):  #epoch
        print("STEP: ", step)
        t1 = np.random.choice(args.traj_classes, args.tasks,
                              replace=False)  #sample sine waves
        # print("TRAJ CLASSES<",args.tasks)
        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))
        # print("ANNOYINGNESS",d_traj_iterators)
        d_rand_iterator = sampler.get_complete_iterator()

        # Sample trajectory and random batch (support and query)
        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(), x_qry.cuda(), y_qry.cuda()

        # One training loop
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry, step, old_accs,
                          old_meta_losses, args, config)

        # if loss[-2] >= old_meta_losses[-2]: #if training improves it,
        #     maml.set_self(other.get_self_state_dict())
        #     old_meta_losses = loss

        # else: #if not improved
        #     other.set_self(maml.get_self_state_dict())

        # Evaluation during training for sanity checks
        if step % 40 == 39:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 299:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)

        torch.save(maml.net, my_experiment.path + "omniglot_classifier.model")
Example #13
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    my_experiment = experiment(args.name, args, "../results/", commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    print(args)

    tasks = list(range(400))
    logger = logging.getLogger('experiment')

    sampler = ts.SamplerFactory.get_sampler("Sin2", tasks, None, None, capacity=401)

    config = mf.ModelFactory.get_model("na", "Sin", in_channels=11, num_actions=30, width=args.width)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearnerRegression(args, config).to(device)

    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logger.info(maml)
    logger.info('Total trainable tensors: %d', num)
    #
    accuracy = 0

    frozen_layers = []
    for temp in range(args.total_frozen * 2):
        frozen_layers.append("net.vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))

    opt = torch.optim.Adam(maml.parameters(), lr=args.lr)
    meta_optim = torch.optim.lr_scheduler.MultiStepLR(opt, [5000, 8000], 0.2)

    for step in range(args.epoch):
        if step %300 == 0:
            print(step)
        for heads in range(30):
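            # Each of the 30 output heads is trained on its own block of 10 tasks
            # (heads*10 .. heads*10+9); when args.baseline is set, a single shared
            # iterator is used instead.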
            t1 = tasks
            # print(tasks)
            iterators = []
            if not args.baseline:
                for t in range(heads*10, heads*10+10):
                    # print(sampler.sample_task([t]))
                    # print(t)
                    iterators.append(sampler.sample_task([t]))

            else:
                iterators.append(sampler.get_another_complete_iterator())

            x_spt, y_spt, x_qry, y_qry = construct_set(iterators, sampler, steps=args.update_step, offset =heads*10)

            if torch.cuda.is_available():
                x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(), x_qry.cuda(), y_qry.cuda()
            # print(x_spt, y_spt)
            net = maml.net
            logits = net(x_qry[0], None, bn_training=False)

            logits_select = []
            for no, val in enumerate(y_qry[0, :, 1].long()):
                # print(y_qry[0, :, 1].long())
                logits_select.append(logits[no, val])

            logits = torch.stack(logits_select).unsqueeze(1)

            loss = F.mse_loss(logits, y_qry[0, :, 0].unsqueeze(1))
            opt.zero_grad()
            loss.backward()
            opt.step()
            meta_optim.step()
            # print(loss)
            # Note: despite the name, this is an exponential moving average of the loss.
            accuracy = accuracy * 0.95 + 0.05 * loss.detach()
            if step % 500 == 0:
                writer.add_scalar('/metatrain/train/accuracy', loss, step)
                writer.add_scalar('/metatrain/train/runningaccuracy', accuracy, step)
                logger.info("Running average of accuracy = %s", str(accuracy.item()))

            if step%500 == 0:
                torch.save(maml.net, my_experiment.path + "learner.model")
Example #14
def main(args):
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    np.random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/")

    dataset = df.DatasetFactory.get_dataset(args.dataset)

    if args.dataset == "CIFAR100":
        args.classes = list(range(50))

    if args.dataset == "omniglot":
        iterator = torch.utils.data.DataLoader(utils.remove_classes_omni(
            dataset, list(range(963))),
                                               batch_size=256,
                                               shuffle=True,
                                               num_workers=1)
    else:
        iterator = torch.utils.data.DataLoader(utils.remove_classes(
            dataset, args.classes),
                                               batch_size=256,
                                               shuffle=True,
                                               num_workers=1)

    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args.dataset)

    maml = learner.Learner(config).to(device)

    opt = torch.optim.Adam(maml.parameters(), lr=args.lr)

    for e in range(args.epoch):
        correct = 0
        if e == 20:
            # Lower the learning rate once at epoch 20; rebuilding the optimizer inside the batch loop would reset Adam state every batch
            opt = torch.optim.Adam(maml.parameters(), lr=0.00001)
            logger.info("Changing LR from %f to %f", args.lr, 0.00001)
        for img, y in iterator:
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)
            feature = F.relu(maml(img, feature=True))
            avg_feature = feature.mean(0)

            beta = args.beta
            beta_hat = avg_feature
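            # KL-style sparsity penalty between the target rate beta and the mean activation beta_hat,
            # applied only where beta_hat exceeds the target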

            loss_rec = ((beta / (beta_hat + 0.0001)) -
                        torch.log(beta / (beta_hat + 0.0001)) - 1)
            # loss_rec = (beta / (beta_hat)
            loss_rec = loss_rec * (beta_hat > beta).float()

            loss_sparse = loss_rec

            if args.l1:
                loss_sparse = feature.mean(0)
            loss_sparse = loss_sparse.mean()

            opt.zero_grad()
            loss = F.cross_entropy(pred, y)
            loss_sparse.backward(retain_graph=True)
            loss.backward()
            opt.step()
            correct += (pred.argmax(1) == y).sum().float() / len(y)
        logger.info("Accuracy at epoch %d = %s", e,
                    str(correct / len(iterator)))
        torch.save(maml, my_experiment.path + "model.net")
Exemplo n.º 15
0
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)
    total_clases = 10

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))
    #for v in range(6):
    #    frozen_layers.append("vars_bn.{0}".format(v))

    final_results_all = []
    temp_result = []
    total_clases = args.schedule
    for tot_class in total_clases:
        lr_list = [
            0.001, 0.0006, 0.0004, 0.00035, 0.0003, 0.00025, 0.0002, 0.00015,
            0.0001, 0.00009, 0.00008, 0.00006, 0.00003, 0.00001
        ]
        lr_all = []
        for lr_search in range(10):
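            # Each search round draws a fresh random subset of tot_class Omniglot classes and sweeps lr_list for the best fine-tuning LR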

            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=True,
                                              background=False,
                                              path=args.dataset_path), keep)
            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset,
                                           False,
                                           classes=total_clases),
                batch_size=1,
                shuffle=args.iid,
                num_workers=2)
            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=not args.test,
                                              background=False,
                                              path=args.dataset_path), keep)
            iterator = torch.utils.data.DataLoader(dataset,
                                                   batch_size=1,
                                                   shuffle=False,
                                                   num_workers=1)

            print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:

                    print(lr)
                    maml = torch.load(args.model, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model("OML", args.dataset)
                        maml = learner.Learner(config)
                        # maml = MetaLearingClassification(args, config).to(device).net

                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        # logger.info(name)
                        if name in frozen_layers:
                            param.learn = False

                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                # logger.info("W shape = %s", str(len(w.shape)))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                    torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    maml.parameters()[-1].data = w

                    if args.neuromodulation:
                        weights2reset = ["vars_26"]
                        biases2reset = ["vars_27"]
                    else:
                        weights2reset = ["vars_14"]
                        biases2reset = ["vars_15"]

                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")

                        if n in weights2reset:

                            w = nn.Parameter(torch.ones_like(a)).to(device)
                            torch.nn.init.kaiming_normal_(w)
                            a.data = w

                        if n in biases2reset:

                            w = nn.Parameter(torch.zeros_like(a)).to(device)
                            a.data = w

                    filter_list = ["vars.{0}".format(v) for v in range(6)]

                    logger.info("Filter list = %s", ",".join(filter_list))

                    list_of_names = list(
                        map(
                            lambda x: x[1],
                            list(
                                filter(lambda x: x[0] not in filter_list,
                                       maml.named_parameters()))))

                    list_of_params = list(
                        filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(
                        filter(lambda x: x[1].learn, maml.named_parameters()))

                    if args.scratch or args.no_freeze:
                        print("Empty filter list")
                        list_of_params = maml.parameters()

                    for x in list_of_names:
                        logger.info("Unfrozen layer = %s", str(x[0]))
                    opt = torch.optim.Adam(list_of_params, lr=lr)

                    for _ in range(0, args.epoch):
                        for img, y in iterator_sorted:
                            img = img.to(device)
                            y = y.to(device)

                            pred = maml(img)
                            opt.zero_grad()
                            loss = F.cross_entropy(pred, y)
                            loss.backward()
                            opt.step()

                    logger.info("Result after one epoch for LR = %f", lr)
                    correct = 0
                    for img, target in iterator:
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)

                        pred_q = (logits_q).argmax(dim=1)

                        correct += torch.eq(pred_q,
                                            target).sum().item() / len(img)

                    logger.info(str(correct / len(iterator)))
                    if (correct / len(iterator) > max_acc):
                        max_acc = correct / len(iterator)
                        max_lr = lr

                lr_all.append(max_lr)
                results_mem_size[mem_size] = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(lr_search), max_acc,
                                  tot_class)
            temp_result.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Temp Results = %s", str(results_mem_size))

            my_experiment.results["Temp Results"] = temp_result
            my_experiment.store_json()
            print("LR RESULTS = ", temp_result)

        from scipy import stats
        best_lr = float(stats.mode(lr_all)[0][0])
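        # The final evaluation runs reuse the learning rate selected most often across the search rounds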
        logger.info("BEST LR %s= ", str(best_lr))

        for aoo in range(args.runs):

            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            if args.dataset == "omniglot":

                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot",
                                                  train=True,
                                                  background=False), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter_omni(dataset,
                                               False,
                                               classes=total_clases),
                    batch_size=1,
                    shuffle=args.iid,
                    num_workers=2)
                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot",
                                                  train=not args.test,
                                                  background=False), keep)
                iterator = torch.utils.data.DataLoader(dataset,
                                                       batch_size=1,
                                                       shuffle=False,
                                                       num_workers=1)
            elif args.dataset == "CIFAR100":
                keep = np.random.choice(list(range(50, 100)), tot_class)
                dataset = utils.remove_classes(
                    df.DatasetFactory.get_dataset(args.dataset, train=True),
                    keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter(dataset, False, classes=tot_class),
                    batch_size=16,
                    shuffle=args.iid,
                    num_workers=2)
                dataset = utils.remove_classes(
                    df.DatasetFactory.get_dataset(args.dataset, train=False),
                    keep)
                iterator = torch.utils.data.DataLoader(dataset,
                                                       batch_size=128,
                                                       shuffle=False,
                                                       num_workers=1)
            print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10

                lr = best_lr

                maml = torch.load(args.model, map_location='cpu')

                if args.scratch:
                    config = mf.ModelFactory.get_model("MRCL", args.dataset)
                    maml = learner.Learner(config)

                maml = maml.to(device)

                for name, param in maml.named_parameters():
                    param.learn = True

                for name, param in maml.named_parameters():
                    # logger.info(name)
                    if name in frozen_layers:
                        param.learn = False
                    else:
                        if args.reset:
                            w = nn.Parameter(torch.ones_like(param))
                            if len(w.shape) > 1:
                                torch.nn.init.kaiming_normal_(w)
                            else:
                                w = nn.Parameter(torch.zeros_like(param))
                            param.data = w
                            param.learn = True

                frozen_layers = []
                for temp in range(args.rln * 2):
                    frozen_layers.append("vars." + str(temp))

                torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                maml.parameters()[-1].data = w

                # The neuromodulated network keeps its output head in different parameter slots
                if args.neuromodulation:
                    weights2reset = ["vars_26"]
                    biases2reset = ["vars_27"]
                else:
                    weights2reset = ["vars_14"]
                    biases2reset = ["vars_15"]

                # Reinitialise the output-head weight and bias before evaluation
                for n, a in maml.named_parameters():
                    n = n.replace(".", "_")

                    if n in weights2reset:
                        w = nn.Parameter(torch.ones_like(a)).to(device)
                        torch.nn.init.kaiming_normal_(w)
                        a.data = w

                    if n in biases2reset:
                        w = nn.Parameter(torch.zeros_like(a)).to(device)
                        a.data = w

                correct = 0
                for img, target in iterator:
                    with torch.no_grad():

                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)
                        pred_q = (logits_q).argmax(dim=1)
                        correct += torch.eq(pred_q,
                                            target).sum().item() / len(img)

                logger.info("Pre-epoch accuracy %s",
                            str(correct / len(iterator)))

                filter_list = ["vars.{0}".format(v) for v in range(6)]

                logger.info("Filter list = %s", ",".join(filter_list))

                list_of_names = list(
                    map(
                        lambda x: x[1],
                        list(
                            filter(lambda x: x[0] not in filter_list,
                                   maml.named_parameters()))))

                list_of_params = list(
                    filter(lambda x: x.learn, maml.parameters()))
                list_of_names = list(
                    filter(lambda x: x[1].learn, maml.named_parameters()))
                if args.scratch or args.no_freeze:
                    print("Empty filter list")
                    list_of_params = maml.parameters()

                for x in list_of_names:
                    logger.info("Unfrozen layer = %s", str(x[0]))
                opt = torch.optim.Adam(list_of_params, lr=lr)

                for _ in range(0, args.epoch):
                    for img, y in iterator_sorted:
                        img = img.to(device)
                        y = y.to(device)
                        pred = maml(img)
                        opt.zero_grad()
                        loss = F.cross_entropy(pred, y)
                        loss.backward()
                        opt.step()

                logger.info("Result after one epoch for LR = %f", lr)

                correct = 0
                for img, target in iterator:
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml(img,
                                    vars=None,
                                    bn_training=False,
                                    feature=False)

                    pred_q = (logits_q).argmax(dim=1)

                    correct += torch.eq(pred_q, target).sum().item() / len(img)

                logger.info(str(correct / len(iterator)))
                if (correct / len(iterator) > max_acc):
                    max_acc = correct / len(iterator)
                    max_lr = lr

                lr_list = [max_lr]
                results_mem_size[mem_size] = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc,
                                  tot_class)
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            print("FINAL RESULTS = ", final_results_all)

    writer.close()
Exemplo n.º 16
0
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)
    total_clases = 10

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))

    #
    final_results_all = []
    total_clases = [10, 50, 75, 100, 150, 200]
    for tot_class in total_clases:
        lr_list = [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        for aoo in range(0, 20):

            keep = np.random.choice(list(range(200)), tot_class)

            if args.dataset == "omniglot":

                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot",
                                                  train=True,
                                                  background=False), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter_omni(dataset,
                                               False,
                                               classes=total_clases),
                    batch_size=1,
                    shuffle=args.iid,
                    num_workers=2)
                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot",
                                                  train=not args.test,
                                                  background=False), keep)
                iterator = torch.utils.data.DataLoader(dataset,
                                                       batch_size=32,
                                                       shuffle=False,
                                                       num_workers=1)
            elif args.dataset == "CIFAR100":
                keep = np.random.choice(list(range(50, 100)), tot_class)
                dataset = utils.remove_classes(
                    df.DatasetFactory.get_dataset(args.dataset, train=True),
                    keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter(dataset, False, classes=tot_class),
                    batch_size=16,
                    shuffle=args.iid,
                    num_workers=2)
                dataset = utils.remove_classes(
                    df.DatasetFactory.get_dataset(args.dataset, train=False),
                    keep)
                iterator = torch.utils.data.DataLoader(dataset,
                                                       batch_size=128,
                                                       shuffle=False,
                                                       num_workers=1)
            # sampler = ts.MNISTSampler(list(range(0, total_clases)), dataset)
            #
            print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:

                    print(lr)
                    # for lr in [0.001, 0.0003, 0.0001, 0.00003, 0.00001]:
                    maml = torch.load(args.model, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model("na", args.dataset)
                        maml = learner.Learner(config)
                        # maml = MetaLearingClassification(args, config).to(device).net

                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        # logger.info(name)
                        if name in frozen_layers:
                            # logger.info("Freeezing name %s", str(name))
                            param.learn = False
                            # logger.info(str(param.requires_grad))
                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                # logger.info("W shape = %s", str(len(w.shape)))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                    torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    maml.parameters()[-1].data = w

                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")
                        # logger.info("Name = %s", n)
                        if n == "vars_14":
                            w = nn.Parameter(torch.ones_like(a))
                            # logger.info("W shape = %s", str(w.shape))
                            torch.nn.init.kaiming_normal_(w)
                            a.data = w
                        if n == "vars_15":
                            w = nn.Parameter(torch.zeros_like(a))
                            a.data = w

                    correct = 0

                    for img, target in iterator:
                        with torch.no_grad():
                            img = img.to(device)
                            target = target.to(device)
                            logits_q = maml(img,
                                            vars=None,
                                            bn_training=False,
                                            feature=False)
                            pred_q = (logits_q).argmax(dim=1)
                            correct += torch.eq(pred_q,
                                                target).sum().item() / len(img)

                    logger.info("Pre-epoch accuracy %s",
                                str(correct / len(iterator)))

                    filter_list = [
                        "vars.0", "vars.1", "vars.2", "vars.3", "vars.4",
                        "vars.5"
                    ]

                    logger.info("Filter list = %s", ",".join(filter_list))
                    list_of_names = list(
                        map(
                            lambda x: x[1],
                            list(
                                filter(lambda x: x[0] not in filter_list,
                                       maml.named_parameters()))))

                    list_of_params = list(
                        filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(
                        filter(lambda x: x[1].learn, maml.named_parameters()))
                    if args.scratch or args.no_freeze:
                        print("Empty filter list")
                        list_of_params = maml.parameters()
                    #
                    for x in list_of_names:
                        logger.info("Unfrozen layer = %s", str(x[0]))
                    opt = torch.optim.Adam(list_of_params, lr=lr)
                    import module.replay as rep
                    res_sampler = rep.ReservoirSampler(mem_size)
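                    # Reservoir of previously seen (img, y) pairs, replayed alongside new data when mem_size > 0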
                    for _ in range(0, args.epoch):
                        for img, y in iterator_sorted:

                            if mem_size > 0:
                                res_sampler.update_buffer(zip(img, y))
                                res_sampler.update_observations(len(img))
                                img = img.to(device)
                                y = y.to(device)
                                img2, y2 = res_sampler.sample_buffer(16)
                                img2 = img2.to(device)
                                y2 = y2.to(device)

                                img = torch.cat([img, img2], dim=0)
                                y = torch.cat([y, y2], dim=0)
                            else:
                                img = img.to(device)
                                y = y.to(device)

                            pred = maml(img)
                            opt.zero_grad()
                            loss = F.cross_entropy(pred, y)
                            loss.backward()
                            opt.step()

                    logger.info("Result after one epoch for LR = %f", lr)
                    correct = 0
                    for img, target in iterator:
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)

                        pred_q = (logits_q).argmax(dim=1)

                        correct += torch.eq(pred_q,
                                            target).sum().item() / len(img)

                    logger.info(str(correct / len(iterator)))
                    if (correct / len(iterator) > max_acc):
                        max_acc = correct / len(iterator)
                        max_lr = lr

                lr_list = [max_lr]
                results_mem_size[mem_size] = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc,
                                  tot_class)
                # quit()
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            print("FINAL RESULTS = ", final_results_all)
    writer.close()
Exemplo n.º 17
0
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    my_experiment = experiment(args.name,
                               args,
                               "/data5/jlindsey/continual/results",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    args.classes = list(range(963))

    print('dataset', args.dataset, args.dataset == "imagenet")

    if args.dataset != "imagenet":

        dataset = df.DatasetFactory.get_dataset(args.dataset,
                                                background=True,
                                                train=True,
                                                all=True)
        dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                     background=True,
                                                     train=False,
                                                     all=True)

    else:
        args.classes = list(range(64))
        dataset = imgnet.MiniImagenet(args.imagenet_path, mode='train')
        dataset_test = imgnet.MiniImagenet(args.imagenet_path, mode='test')

    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    logger.info("Train set length = %d", len(iterator_train) * 5)
    logger.info("Test set length = %d", len(iterator_test) * 5)
    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model(
        args.model_type,
        args.dataset,
        width=args.width,
        num_extra_dense_layers=args.num_extra_dense_layers)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    if args.oja or args.hebb:
        maml = OjaMetaLearingClassification(args, config).to(device)
    else:
        print('starting up')
        maml = MetaLearingClassification(args, config).to(device)

    import sys
    if args.from_saved:
        maml.net = torch.load(args.model)
        if args.use_derivative:
            maml.net.use_derivative = True
        maml.net.optimize_out = args.optimize_out
        if maml.net.optimize_out:
            maml.net.feedback_strength_vars.append(
                torch.nn.Parameter(maml.net.init_feedback_strength *
                                   torch.ones(1).cuda()))

        if args.reset_feedback_strength:
            for fv in maml.net.feedback_strength_vars:
                w = nn.Parameter(torch.ones_like(fv) * args.feedback_strength)
                fv.data = w

        if args.reset_feedback_vars:

            print('howdy', maml.net.num_feedback_layers)

            maml.net.feedback_vars = nn.ParameterList()
            maml.net.feedback_vars_bundled = []

            maml.net.vars_plasticity = nn.ParameterList()
            maml.net.plasticity = nn.ParameterList()

            maml.net.neuron_plasticity = nn.ParameterList()

            maml.net.layer_plasticity = nn.ParameterList()

            starting_width = 84
            cur_width = starting_width
            num_outputs = maml.net.config[-1][1][0]
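            # Walk the network config, tracking the spatial width after each conv, and rebuild plasticity and feedback parameters layer by layer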
            for i, (name, param) in enumerate(maml.net.config):
                print('yo', i, name, param)
                if name == 'conv2d':
                    print('in conv2d')
                    stride = param[4]
                    padding = param[5]

                    #print('cur_width', cur_width, param[3])
                    cur_width = (cur_width + 2 * padding - param[3] +
                                 stride) // stride
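                    # Standard conv output-size formula: floor((W + 2*pad - kernel) / stride) + 1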

                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(*param[:4]).cuda()))
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(param[0]).cuda()))
                    #self.activations_list.append([])
                    maml.net.plasticity.append(
                        nn.Parameter(
                            maml.net.init_plasticity *
                            torch.ones(param[0], param[1] * param[2] *
                                       param[3]).cuda()))  #not implemented
                    maml.net.neuron_plasticity.append(
                        nn.Parameter(torch.zeros(1).cuda()))  #not implemented

                    maml.net.layer_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(1).cuda()))  #not implemented

                    feedback_var = []

                    for fl in range(maml.net.num_feedback_layers):
                        print('doing fl')
                        in_dim = maml.net.width
                        out_dim = maml.net.width
                        if fl == maml.net.num_feedback_layers - 1:
                            out_dim = param[0] * cur_width * cur_width
                        if fl == 0:
                            in_dim = num_outputs
                        feedback_w_shape = [out_dim, in_dim]
                        feedback_w = nn.Parameter(
                            torch.ones(feedback_w_shape).cuda())
                        feedback_b = nn.Parameter(torch.zeros(out_dim).cuda())
                        torch.nn.init.kaiming_normal_(feedback_w)
                        feedback_var.append((feedback_w, feedback_b))
                        print('adding')
                        maml.net.feedback_vars.append(feedback_w)
                        maml.net.feedback_vars.append(feedback_b)

                    #maml.net.feedback_vars_bundled.append(feedback_var)
                    #maml.net.feedback_vars_bundled.append(None)#bias feedback -- not implemented

                    #'''

                    maml.net.feedback_vars_bundled.append(
                        nn.Parameter(torch.zeros(
                            1)))  #weight feedback -- not implemented
                    maml.net.feedback_vars_bundled.append(
                        nn.Parameter(
                            torch.zeros(1)))  #bias feedback -- not implemented

                elif name == 'linear':
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(*param).cuda()))
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(param[0]).cuda()))
                    #self.activations_list.append([])
                    maml.net.plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(*param).cuda()))
                    maml.net.neuron_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(param[0]).cuda()))
                    maml.net.layer_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(1).cuda()))

                    feedback_var = []

                    for fl in range(maml.net.num_feedback_layers):
                        in_dim = maml.net.width
                        out_dim = maml.net.width
                        if fl == maml.net.num_feedback_layers - 1:
                            out_dim = param[0]
                        if fl == 0:
                            in_dim = num_outputs
                        feedback_w_shape = [out_dim, in_dim]
                        feedback_w = nn.Parameter(
                            torch.ones(feedback_w_shape).cuda())
                        feedback_b = nn.Parameter(torch.zeros(out_dim).cuda())
                        torch.nn.init.kaiming_normal_(feedback_w)
                        feedback_var.append((feedback_w, feedback_b))
                        maml.net.feedback_vars.append(feedback_w)
                        maml.net.feedback_vars.append(feedback_b)
                    maml.net.feedback_vars_bundled.append(feedback_var)
                    maml.net.feedback_vars_bundled.append(
                        None)  #bias feedback -- not implemented

        maml.init_stuff(args)

    maml.net.optimize_out = args.optimize_out
    if maml.net.optimize_out:
        maml.net.feedback_strength_vars.append(
            torch.nn.Parameter(maml.net.init_feedback_strength *
                               torch.ones(1).cuda()))
    #I recently un-indented this until the maml.init_opt() line.  If stuff stops working, try re-indenting this block
    if args.zero_non_output_plasticity:
        for index in range(len(maml.net.vars_plasticity) - 2):
            maml.net.vars_plasticity[index] = torch.nn.Parameter(
                maml.net.vars_plasticity[index] * 0)
        if args.oja or args.hebb:
            for index in range(len(maml.net.plasticity) - 1):
                if args.plasticity_rank1:
                    maml.net.plasticity[index] = torch.nn.Parameter(
                        torch.zeros(1).cuda())
                else:
                    maml.net.plasticity[index] = torch.nn.Parameter(
                        maml.net.plasticity[index] * 0)
                    maml.net.layer_plasticity[index] = torch.nn.Parameter(
                        maml.net.layer_plasticity[index] * 0)
                    maml.net.neuron_plasticity[index] = torch.nn.Parameter(
                        maml.net.neuron_plasticity[index] * 0)

        if args.oja or args.hebb:
            for index in range(len(maml.net.vars_plasticity) - 2):
                maml.net.vars_plasticity[index] = torch.nn.Parameter(
                    maml.net.vars_plasticity[index] * 0)
    if args.zero_all_plasticity:
        print('zeroing plasticity')
        for index in range(len(maml.net.vars_plasticity)):
            maml.net.vars_plasticity[index] = torch.nn.Parameter(
                maml.net.vars_plasticity[index] * 0)
        for index in range(len(maml.net.plasticity)):
            if args.plasticity_rank1:
                maml.net.plasticity[index] = torch.nn.Parameter(
                    torch.zeros(1).cuda())
            else:

                maml.net.plasticity[index] = torch.nn.Parameter(
                    maml.net.plasticity[index] * 0)
                maml.net.layer_plasticity[index] = torch.nn.Parameter(
                    maml.net.layer_plasticity[index] * 0)
                maml.net.neuron_plasticity[index] = torch.nn.Parameter(
                    maml.net.neuron_plasticity[index] * 0)

    print('heyy', maml.net.feedback_vars)
    maml.init_opt()
    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True

    if args.freeze_out_plasticity:
        maml.net.plasticity[-1].requires_grad = False
    total_ff_vars = 2 * (6 + 2 + args.num_extra_dense_layers)
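    # One weight and one bias tensor per layer, assuming the default backbone of 6 conv layers and 2 dense layers plus any extra dense layers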
    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("net.vars." + str(temp))

    for temp in range(args.rln_end * 2):
        frozen_layers.append("net.vars." + str(total_ff_vars - 1 - temp))
    for name, param in maml.named_parameters():
        # logger.info(name)
        if name in frozen_layers:
            logger.info("RLN layer %s", str(name))
            param.learn = False

    # Update the classifier
    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
    list_of_names = list(filter(lambda x: x[1].learn, maml.named_parameters()))

    for a in list_of_names:
        logger.info("TLN layer = %s", a[0])

    for step in range(args.steps):
        '''
        print('plasticity')
        for p in maml.net.plasticity:
            print(p.size(), torch.sum(p), p)
        '''
        t1 = np.random.choice(
            args.classes, args.tasks, replace=False
        )  #np.random.randint(1, args.tasks + 1), replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            iid=args.iid)

        perm = np.random.permutation(args.tasks)
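        # Remap the sampled class labels onto a random permutation of 0..tasks-1 so the output-head indices differ every meta-step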

        old = []
        for i in range(y_spt.size()[0]):
            num = int(y_spt[i].cpu().numpy())
            if num not in old:
                old.append(num)
            y_spt[i] = torch.tensor(perm[old.index(num)])

        for i in range(y_qry.size()[1]):
            num = int(y_qry[0][i].cpu().numpy())
            y_qry[0][i] = torch.tensor(perm[old.index(num)])
        #print('hi', y_qry.size())
        #print('y_spt', y_spt)
        #print('y_qry', y_qry)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        #print('heyyyy', x_spt.size(), y_spt.size(), x_qry.size(), y_qry.size())
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 1 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 0:
            correct = 0
            torch.save(maml.net, my_experiment.path + "learner.model")
            for img, target in iterator_test:
                with torch.no_grad():
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct += torch.eq(pred_q, target).sum().item() / len(img)
            writer.add_scalar('/metatrain/test/classifier/accuracy',
                              correct / len(iterator_test), step)
            logger.info("Test Accuracy = %s",
                        str(correct / len(iterator_test)))
            correct = 0
            for img, target in iterator_train:
                with torch.no_grad():

                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)
                    pred_q = (logits_q).argmax(dim=1)
                    correct += torch.eq(pred_q, target).sum().item() / len(img)

            logger.info("Train Accuracy = %s",
                        str(correct / len(iterator_train)))
            writer.add_scalar('/metatrain/train/classifier/accuracy',
                              correct / len(iterator_train), step)
Exemplo n.º 18
0
def main():
    p = params.Parser()
    total_seeds = len(p.parse_known_args()[0].seed)
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])
    print("All args = ", all_args)

    args = utils.get_run(vars(p.parse_known_args()[0]), rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'],
                               args,
                               "../results/",
                               commit_changes=False,
                               rank=0,
                               seed=1)

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    dataset = df.DatasetFactory.get_dataset(args['dataset'],
                                            background=True,
                                            train=True,
                                            path=args["path"],
                                            all=True)

    iterator = torch.utils.data.DataLoader(dataset,
                                           batch_size=256,
                                           shuffle=True,
                                           num_workers=0)

    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args["dataset"])

    maml = learner.Learner(config).to(device)

    for k, v in maml.named_parameters():
        print(k, v.requires_grad)

    opt = torch.optim.Adam(maml.parameters(), lr=args["lr"])

    for e in range(args["epoch"]):
        correct = 0
        for img, y in tqdm(iterator):
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)

            opt.zero_grad()
            loss = F.cross_entropy(pred, y.long())
            loss.backward()
            opt.step()
            correct += (pred.argmax(1) == y).sum().float() / len(y)
        logger.info("Accuracy at epoch %d = %s", e,
                    str(correct / len(iterator)))
        torch.save(maml, my_experiment.path + "model.net")
Exemplo n.º 19
0
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    print(args)

    tasks = list(range(400))
    logger = logging.getLogger('experiment')

    sampler = ts.SamplerFactory.get_sampler("Sin",
                                            tasks,
                                            None,
                                            capacity=args.capacity + 1)

    config = mf.ModelFactory.get_model("na",
                                       "Sin",
                                       in_channels=args.capacity + 1,
                                       num_actions=1,
                                       width=args.width)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearnerRegression(args, config).to(device)

    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logger.info(maml)
    logger.info('Total trainable tensors: %d', num)
    #
    accuracy = 0

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("net.vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))
    for step in range(args.epoch):

        if step == 0 and not args.no_freeze:
            for name, param in maml.named_parameters():
                logger.info(name)
                if name in frozen_layers:
                    logger.info("Freezing name %s", str(name))
                    param.learn = False
                    logger.info(str(param.requires_grad))

            for name, param in maml.net.named_parameters():
                logger.info(name)
                if name in frozen_layers:
                    logger.info("Freezing name %s", str(name))
                    param.learn = False
                    logger.info(str(param.requires_grad))

        t1 = np.random.choice(tasks, args.tasks, replace=False)

        iterators = []
        for t in t1:
            # print(sampler.sample_task([t]))
            iterators.append(sampler.sample_task([t]))

        x_traj, y_traj, x_rand, y_rand = construct_set(iterators,
                                                       sampler,
                                                       steps=args.update_step)

        if torch.cuda.is_available():
            x_traj, y_traj, x_rand, y_rand = x_traj.cuda(), y_traj.cuda(
            ), x_rand.cuda(), y_rand.cuda()
        # print(x_spt, y_spt)
        accs = maml(x_traj, y_traj, x_rand, y_rand)
        maml.meta_optim.step()

        if step in [0, 2000, 3000, 4000]:
            for param_group in maml.optimizer.param_groups:
                logger.info("Learning Rate at step %d = %s", step,
                            str(param_group['lr']))

        accuracy = accuracy * 0.95 + 0.05 * accs[-1]
        if step % 5 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            writer.add_scalar('/metatrain/train/runningaccuracy', accuracy,
                              step)
            logger.info("Running average of accuracy = %s", str(accuracy))
            logger.info('step: %d \t training acc (first, last) %s', step,
                        str(accs[0]) + "," + str(accs[-1]))

        if step % 100 == 0:
            counter = 0
            for name, _ in maml.net.named_parameters():
                counter += 1

            for lrs in [args.update_lr]:
                lr_results = {}
                lr_results[lrs] = []
                for temp in range(0, 20):
                    t1 = np.random.choice(tasks, args.tasks, replace=False)
                    iterators = []

                    for t in t1:
                        iterators.append(sampler.sample_task([t]))
                    x_traj, y_traj, x_rand, y_rand = construct_set(
                        iterators, sampler, steps=40, no_rand=args.no_rand)
                    if torch.cuda.is_available():
                        x_traj, y_traj, x_rand, y_rand = x_traj.cuda(
                        ), y_traj.cuda(), x_rand.cuda(), y_rand.cuda()

                    net = copy.deepcopy(maml.net)
                    net = net.to(device)
                    for params_old, params_new in zip(maml.net.parameters(),
                                                      net.parameters()):
                        params_new.learn = params_old.learn

                    list_of_params = list(
                        filter(lambda x: x.learn, net.parameters()))

                    optimizer = optim.SGD(list_of_params, lr=lrs)
                    for k in range(len(x_traj)):
                        logits = net(x_traj[k], None, bn_training=False)

                        logits_select = []
                        for no, val in enumerate(y_traj[k, :, 1].long()):
                            logits_select.append(logits[no, val])

                        logits = torch.stack(logits_select).unsqueeze(1)

                        loss = F.mse_loss(logits, y_traj[k, :, 0].unsqueeze(1))
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                    #
                    with torch.no_grad():
                        logits = net(x_rand[0], vars=None, bn_training=False)

                        logits_select = []
                        for no, val in enumerate(y_rand[0, :, 1].long()):
                            logits_select.append(logits[no, val])
                        logits = torch.stack(logits_select).unsqueeze(1)
                        loss_q = F.mse_loss(logits, y_rand[0, :,
                                                           0].unsqueeze(1))
                        lr_results[lrs].append(loss_q.item())

                logger.info("Avg MSE LOSS  for lr %s = %s", str(lrs),
                            str(np.mean(lr_results[lrs])))

            torch.save(maml.net, my_experiment.path + "learner.model")
Exemplo n.º 20
0
def main(args):
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    args.traj_classes = list(range(963))
    #
    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True,
                                            all=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=False,
                                                 train=True,
                                                 all=True)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset)

    sampler_test = ts.SamplerFactory.get_sampler(args.dataset,
                                                 list(range(600)),
                                                 dataset_test, dataset_test)

    config = mf.ModelFactory.get_model("na", "omniglot-fc")

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    utils.freeze_layers(args.rln, maml)
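    # Meta-training loop: each step samples a trajectory of classes, builds support/query sets, and performs one meta-update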

    for step in range(args.steps):

        t1 = np.random.choice(args.traj_classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_few_shot_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)
        #
        # Evaluation during training for sanity checks
        if step % 20 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
            logger.info("Loss = %s", str(loss[-1].item()))
        if step % 600 == 599:
            torch.save(maml.net, my_experiment.path + "learner.model")
            accs_avg = None
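            # Meta-test: average the fine-tuning accuracy over 40 task sets drawn from the 600 held-out classes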
            for temp_temp in range(0, 40):
                t1_test = np.random.choice(list(range(600)),
                                           args.tasks,
                                           replace=False)

                d_traj_test_iterators = []
                for t in t1_test:
                    d_traj_test_iterators.append(sampler_test.sample_task([t]))

                x_spt, y_spt, x_qry, y_qry = maml.sample_few_shot_training_data(
                    d_traj_test_iterators,
                    None,
                    steps=args.update_step,
                    reset=not args.no_reset)
                if torch.cuda.is_available():
                    x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
                    ), x_qry.cuda(), y_qry.cuda()

                accs, loss = maml.finetune(x_spt, y_spt, x_qry, y_qry)
                if accs_avg is None:
                    accs_avg = accs
                else:
                    accs_avg += accs
            logger.info("Loss = %s", str(loss[-1].item()))
            writer.add_scalar('/metatest/train/accuracy', accs_avg[-1] / 40,
                              step)
            logger.info('TEST: step: %d \t testing acc %s', step,
                        str(accs_avg / 40))
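# utils.freeze_layers(args.rln, maml) above freezes the RLN portion of the network before
# meta-training. A minimal sketch of what such a helper might do, assuming the
# "net.vars.<i>" parameter naming used by the explicit freezing loops in the later
# examples in this file (an illustration only, not the repository's actual implementation):
def freeze_layers_sketch(rln, maml):
    # The first rln layers contribute a weight and a bias tensor each, hence rln * 2 names.
    frozen = {"net.vars." + str(i) for i in range(rln * 2)}
    for name, param in maml.named_parameters():
        if name in frozen:
            param.learn = False  # flag consumed by the inner-loop update
    for name, param in maml.net.named_parameters():
        if name in frozen:
            param.learn = False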
Example no. 21
0
def main():
    p = class_parser.Parser()
    total_seeds = len(p.parse_known_args()[0].seed)
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])
    print("All args = ", all_args)

    args = utils.get_run(vars(p.parse_known_args()[0]), rank)

    utils.set_seed(args["seed"])

    if args["log_root"]:
        log_root = osp.join("./results", args["log_root"]) + "/"
    else:
        log_root = osp.join("./results/")

    my_experiment = experiment(
        args["name"],
        args,
        log_root,
        commit_changes=False,
        rank=0,
        seed=args["seed"],
    )
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger("experiment")

    # Using first 963 classes of the omniglot as the meta-training set
    # args["classes"] = list(range(963))
    args["classes"] = list(range(args["num_classes"]))
    print("Using classes:", args["num_classes"])
    # logger.info("Using classes:", str(args["num_classes"]))

    # args["traj_classes"] = list(range(int(963 / 2), 963))

    if torch.cuda.is_available():
        device = torch.device("cuda")
        use_cuda = True
    else:
        device = torch.device("cpu")
        use_cuda = False
    dataset_spt = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        # all=True,
        # all=False,
        all=args["all"],
        prefetch_gpu=args["prefetch_gpu"],
        device=device,
        resize=args["resize"],
        augment=args["augment_spt"],
    )
    dataset_qry = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        # all=True,
        # all=False,
        all=args["all"],
        prefetch_gpu=args["prefetch_gpu"],
        device=device,
        resize=args["resize"],
        augment=args["augment_qry"],
    )
    dataset_test = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=False,
        path=args["path"],
        # all=True,
        # all=False,
        all=args["all"],
        resize=args["resize"],
        # augment=args["augment"],
    )

    logger.info(
        f"Support size: {len(dataset_spt)}, Query size: {len(dataset_qry)}, test size: {len(dataset_test)}"
    )
    # print(f"Support size: {len(dataset_spt)}, Query size: {len(dataset_qry)}, test size: {len(dataset_test)}")

    pin_memory = use_cuda
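    # When samples are prefetched onto the GPU, DataLoader worker processes and pinned
    # host memory are disabled below, since CUDA tensors do not combine well with either.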
    if args["prefetch_gpu"]:
        num_workers = 0
        pin_memory = False
    else:
        num_workers = args["num_workers"]
    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=5,
        shuffle=True,
        num_workers=0,
        # pin_memory=pin_memory,
    )

    iterator_train = torch.utils.data.DataLoader(
        dataset_spt,
        batch_size=5,
        shuffle=True,
        num_workers=0,
        # pin_memory=pin_memory,
    )

    logger.info("Support sampler:")
    sampler_spt = ts.SamplerFactory.get_sampler(
        args["dataset"],
        args["classes"],
        dataset_spt,
        dataset_test,
        prefetch_gpu=args["prefetch_gpu"],
        use_cuda=use_cuda,
        num_workers=0,
    )
    logger.info("Query sampler:")
    sampler_qry = ts.SamplerFactory.get_sampler(
        args["dataset"],
        args["classes"],
        dataset_qry,
        dataset_test,
        prefetch_gpu=args["prefetch_gpu"],
        use_cuda=use_cuda,
        num_workers=0,
    )

    config = mf.ModelFactory.get_model(
        "na",
        args["dataset"],
        output_dimension=1000,
        resize=args["resize"],
    )

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(gpu_to_use))
        logger.info("Using gpu : %s", "cuda:" + str(gpu_to_use))
    else:
        device = torch.device("cpu")

    maml = MetaLearingClassification(args, config).to(device)

    for step in range(args["steps"]):

        t1 = np.random.choice(args["classes"], args["tasks"], replace=False)

        d_traj_iterators_spt = []
        d_traj_iterators_qry = []
        for t in t1:
            d_traj_iterators_spt.append(sampler_spt.sample_task([t]))
            d_traj_iterators_qry.append(sampler_qry.sample_task([t]))

        d_rand_iterator = sampler_spt.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data_paper(
            d_traj_iterators_spt,
            d_traj_iterators_qry,
            d_rand_iterator,
            steps=args["update_step"],
            reset=not args["no_reset"],
        )
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = (
                x_spt.to(device),
                y_spt.to(device),
                x_qry.to(device),
                y_qry.to(device),
            )

        #
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 5:
            writer.add_scalar("/metatrain/train/accuracy", accs[-1], step)
            writer.add_scalar("/metatrain/train/loss", loss[-1], step)
            writer.add_scalar("/metatrain/train/accuracy0", accs[0], step)
            writer.add_scalar("/metatrain/train/loss0", loss[0], step)
            logger.info("step: %d \t training acc %s", step, str(accs))
            logger.info("step: %d \t training loss %s", step, str(loss))
        # Currently useless
        if (step % 300 == 3) or ((step + 1) == args["steps"]):
            torch.save(maml.net, my_experiment.path + "learner.model")
Example no. 22
0
def main():
    p = class_parser_eval.Parser()
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])
    print("All args = ", all_args)

    args = utils.get_run(vars(p.parse_known_args()[0]), rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'],
                               args,
                               "../results/",
                               commit_changes=False,
                               rank=0,
                               seed=1)

    final_results_all = []
    temp_result = []
    args['schedule'] = [int(x) for x in args['schedule'].split(":")]
    total_clases = args['schedule']
    print(args["schedule"])
    for tot_class in total_clases:
        print("Classes current step = ", tot_class)
        lr_list = [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        lr_all = []
        for lr_search in range(0, 5):

            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=True,
                                              background=False,
                                              path=args['path']), keep)
            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset,
                                           False,
                                           classes=total_clases),
                batch_size=1,
                shuffle=args['iid'],
                num_workers=2)
            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=not args['test'],
                                              background=False,
                                              path=args['path']), keep)
            iterator = torch.utils.data.DataLoader(dataset,
                                                   batch_size=32,
                                                   shuffle=False,
                                                   num_workers=1)

            gpu_to_use = rank % args["gpus"]
            if torch.cuda.is_available():
                device = torch.device('cuda:' + str(gpu_to_use))
                logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
            else:
                device = torch.device('cpu')

            config = mf.ModelFactory.get_model("na",
                                               args['dataset'],
                                               output_dimension=1000)
            max_acc = -1000
            for lr in lr_list:

                print(lr)
                maml = load_model(args, config)
                maml = maml.to(device)

                opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

                for _ in range(0, 1):
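                    # A single pass over the (possibly class-sorted, non-iid) training stream.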
                    for img, y in iterator_sorted:
                        img = img.to(device)
                        y = y.to(device)

                        pred = maml(img)
                        opt.zero_grad()
                        loss = F.cross_entropy(pred, y)
                        loss.backward()
                        opt.step()

                logger.info("Result after one epoch for LR = %f", lr)
                correct = 0
                for img, target in iterator:
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml(img)

                    pred_q = (logits_q).argmax(dim=1)

                    correct += torch.eq(pred_q, target).sum().item() / len(img)

                logger.info(str(correct / len(iterator)))
                if (correct / len(iterator) > max_acc):
                    max_acc = correct / len(iterator)
                    max_lr = lr

            lr_all.append(max_lr)
            logger.info("Final Max Result = %s", str(max_acc))
            results_mem_size = (max_acc, max_lr)
            temp_result.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Temp Results = %s", str(results_mem_size))

            my_experiment.results["Temp Results"] = temp_result
            my_experiment.store_json()
            print("LR RESULTS = ", temp_result)

        from scipy import stats
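        # Across the search repetitions, keep the learning rate that was chosen most often (the mode).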
        best_lr = float(stats.mode(lr_all)[0][0])

        logger.info("BEST LR %s= ", str(best_lr))

        for aoo in range(0, args['runs']):

            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=True,
                                              background=False), keep)
            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset,
                                           False,
                                           classes=total_clases),
                batch_size=1,
                shuffle=args['iid'],
                num_workers=2)
            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=not args['test'],
                                              background=False), keep)
            iterator = torch.utils.data.DataLoader(dataset,
                                                   batch_size=32,
                                                   shuffle=False,
                                                   num_workers=1)

            for mem_size in [args['memory']]:
                max_acc = -10
                max_lr = -10

                lr = best_lr

                # for lr in [0.001, 0.0003, 0.0001, 0.00003, 0.00001]:
                maml = load_model(args, config)
                maml = maml.to(device)

                opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

                for _ in range(0, 1):
                    for img, y in iterator_sorted:
                        img = img.to(device)
                        y = y.to(device)

                        pred = maml(img)
                        opt.zero_grad()
                        loss = F.cross_entropy(pred, y)
                        loss.backward()
                        opt.step()

                logger.info("Result after one epoch for LR = %f", lr)
                correct = 0
                for img, target in iterator:
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml(img,
                                    vars=None,
                                    bn_training=False,
                                    feature=False)

                    pred_q = (logits_q).argmax(dim=1)

                    correct += torch.eq(pred_q, target).sum().item() / len(img)

                logger.info(str(correct / len(iterator)))
                if (correct / len(iterator) > max_acc):
                    max_acc = correct / len(iterator)
                    max_lr = lr

                lr_list = [max_lr]
                results_mem_size = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            print("FINAL RESULTS = ", final_results_all)
Example no. 23
0
def main(args):
    # Set random seeds
    old_losses = [2.**30,2**30]
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    my_experiment = experiment(args.name, args, "./results/", commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    print(args)

    tasks = list(range(400))
    logger = logging.getLogger('experiment')

    sampler = ts.SamplerFactory.get_sampler("Sin", tasks, None, capacity=args.capacity + 1)

    config = mf.ModelFactory.get_model(args.model, "Sin", in_channels=args.capacity + 1, num_actions=1,
                                       width=args.width)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearnerRegression(args, config).to(device)

    # Mark every parameter of the meta-learner and of the inner network as learnable;
    # the .learn flag controls which tensors get updated during adaptation.
    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True

    # Count the trainable tensors (parameters with requires_grad=True)
    tmp = filter(lambda x: x.requires_grad, maml.parameters()) 
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logger.info(maml)
    logger.info('Total trainable tensors: %d', num)
    #
    accuracy = 0

    # Layers to freeze: the RLN (representation learning network) layers.
    # Each layer contributes a weight and a bias tensor, named net.vars.0, net.vars.1, ...
    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("net.vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))
    
    for step in range(args.epoch):  # outer-loop (meta-training) steps
        if step == 0:  # on the first step, mark the frozen layers so they are excluded from adaptation
            for name, param in maml.named_parameters():
                logger.info(name)
                if name in frozen_layers:
                    logger.info("Freeezing name %s", str(name))
                    param.learn = False
                    logger.info(str(param.requires_grad))

            for name, param in maml.net.named_parameters():  # repeat for the inner network's parameter list
                logger.info(name)
                if name in frozen_layers:
                    logger.info("Freeezing name %s", str(name))
                    param.learn = False
                    logger.info(str(param.requires_grad))

        # Randomly select args.tasks tasks (without replacement) from the task pool
        t1 = np.random.choice(tasks, args.tasks, replace=False)

        # For each selected task, build a data iterator
        iterators = []
        for t in t1: 
            # print(sampler.sample_task([t]))
            iterators.append(sampler.sample_task([t]))

        x_traj, y_traj, x_rand, y_rand = construct_set(iterators, sampler, steps=args.update_step)
        if torch.cuda.is_available():
            x_traj, y_traj, x_rand, y_rand = x_traj.cuda(), y_traj.cuda(), x_rand.cuda(), y_rand.cuda()
        
        # Forward pass of MetaLearnerRegression: performs the inner-loop adaptation on the
        # trajectory and computes the meta-loss on the random batch for the outer update below.
        accs = maml(x_traj, y_traj, x_rand, y_rand, old_losses, args)

        
        # Outer (meta) update: step the meta-optimizer on the initial parameters using the
        # gradients computed in the forward pass above.
        maml.meta_optim.step()

        # Monitoring
        if step in [0, 2000, 3000, 4000]:
            for param_group in maml.optimizer.param_groups:
                logger.info("Learning Rate at step %d = %s", step, str(param_group['lr']))

        accuracy = accuracy * 0.95 + 0.05 * accs[-1]  # exponential moving average of accuracy for smoother logging
        if step % 5 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            writer.add_scalar('/metatrain/train/runningaccuracy', accuracy, step)
            logger.info("Running average of accuracy = %s", str(accuracy))
            logger.info('step: %d \t training acc (first, last) %s', step, str(accs[0]) + "," + str(accs[-1]))

        if step % 100 == 0:  # periodic evaluation of the adapted network
            counter = 0
            for name, _ in maml.net.named_parameters():
                counter += 1

            for lrs in [args.update_lr]:  # inner-loop learning rate(s) to evaluate (a single value here)
                lr_results = {}
                lr_results[lrs] = []
                for temp in range(0, 20):
                    t1 = np.random.choice(tasks, args.tasks, replace=False)
                    iterators = []
                    #
                    for t in t1:
                        iterators.append(sampler.sample_task([t]))

                    # Step 1 on flowchart: Sample trajectory (X_traj, Y_traj) and random batch D_rand = (X_rand, Y_rand)
                    x_traj, y_traj, x_rand, y_rand = construct_set(iterators, sampler, steps=40)
                    if torch.cuda.is_available():
                        x_traj, y_traj, x_rand, y_rand = x_traj.cuda(), y_traj.cuda(), x_rand.cuda(), y_rand.cuda()

                    net = copy.deepcopy(maml.net)  # copy the network so adaptation does not touch the meta-learned weights
                    net = net.to(device)
                    for params_old, params_new in zip(maml.net.parameters(), net.parameters()):
                        params_new.learn = params_old.learn  # carry the .learn flags over to the copy

                    list_of_params = list(filter(lambda x: x.learn, net.parameters()))  # only the unfrozen parameters are adapted

                    optimizer = optim.SGD(list_of_params, lr=lrs)
                    
                    # Step 2 of the flowchart (figure 6, page 12): the inner update. Take a gradient
                    # step on the adaptable (TLN) weights for each point along the trajectory,
                    # using the MSE loss and learning rate lrs.
                    for k in range(len(x_traj)):
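                        # y_traj[k, :, 0] holds the regression targets; y_traj[k, :, 1] holds the
                        # task index, used below to pick the matching output head from the logits.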
                        logits = net(x_traj[k], None, bn_training=False)

                        logits_select = []
                        for no, val in enumerate(y_traj[k, :, 1].long()):
                            logits_select.append(logits[no, val])

                        logits = torch.stack(logits_select).unsqueeze(1)

                        loss = F.mse_loss(logits, y_traj[k, :, 0].unsqueeze(1))
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()  # update the TLN (adaptable) weights
                    
                    # Step 3 of the flowchart: evaluate the adapted network on the random batch
                    # and record the resulting loss.
                    with torch.no_grad():
                        #Use the updated network to predict on the random batch of data
                        logits = net(x_rand[0], vars=None, bn_training=False) 

                        logits_select = []
                        for no, val in enumerate(y_rand[0, :, 1].long()):
                            logits_select.append(logits[no, val])
                        logits = torch.stack(logits_select).unsqueeze(1)
                        loss_q = F.mse_loss(logits, y_rand[0, :, 0].unsqueeze(1))
                        lr_results[lrs].append(loss_q.item())

                logger.info("Avg MSE LOSS  for lr %s = %s", str(lrs), str(np.mean(lr_results[lrs])))

        if args.smart and args.use_mini:
            torch.save(maml.net, my_experiment.path + "regression_smart_little.model")
        elif args.smart and not args.use_mini:
            torch.save(maml.net, my_experiment.path + "regression_smart_big.model")
        else:
            torch.save(maml.net, my_experiment.path + "regression_model.model")
Example no. 24
0
def main():
    p = reg_parser.Parser()
    total_seeds = len(p.parse_known_args()[0].seed)
    run = p.parse_known_args()[0].run
    all_args = vars(p.parse_known_args()[0])

    args = utils.get_run(all_args, run)

    my_experiment = experiment(args["name"],
                               args,
                               args["output_dir"],
                               sql=True,
                               run=int(run / total_seeds),
                               seed=total_seeds)

    my_experiment.results["all_args"] = all_args

    my_experiment.make_table("metrics", {
        "run": 0,
        "meta_loss": 0.0,
        "step": 0
    }, ("run", "step"))

    metrics_keys = ["run", "meta_loss", "step"]

    logger = logging.getLogger('experiment')
    tasks = list(range(400))

    sampler = ts.SamplerFactory.get_sampler("Sin",
                                            tasks,
                                            None,
                                            capacity=args["capacity"] + 1)
    model_config = mf.ModelFactory.get_model(args["model"],
                                             "Sin",
                                             input_dimension=args["capacity"] +
                                             1,
                                             output_dimension=1,
                                             width=args["width"],
                                             cols=args["cols"])
    gpu_to_use = run % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')
    if args.get('update_rule') == "RTRL":
        logger.info("Columnar Net based gradient approximation...")
        metalearner = MetaLearnerRegressionCol(args,
                                               model_config,
                                               device=device).to(device)
    else:
        logger.info("BPTT update rule...")
        metalearner = MetaLearnerRegression(args, model_config,
                                            device=device).to(device)
    tmp = filter(lambda x: x.requires_grad, metalearner.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logger.info('Total trainable tensors: %d', num)

    running_meta_loss = 0
    adaptation_loss = 0
    loss_history = []
    metrics_list = []
    metrics_keys = ["run", "meta_loss", "step"]
    adaptation_loss_history = []
    adaptation_running_loss_history = []
    meta_steps_counter = 0
    LOG_INTERVAL = 2
    for step in range(args["epoch"]):
        if step % LOG_INTERVAL == 0:
            logger.debug("####\t STEP %d \t####", step)
        net = metalearner.net
        meta_steps_counter += 1
        t1 = np.random.choice(tasks, args["tasks"], replace=False)
        iterators = []

        for t in t1:
            iterators.append(sampler.sample_task([t]))

        x_traj_meta, y_traj_meta, x_rand_meta, y_rand_meta = utils.construct_set(
            iterators, sampler, steps=1)
        x_traj_meta, x_rand_meta, y_traj_meta, y_rand_meta = x_traj_meta.view(
            -1, 51), x_rand_meta.view(-1, 51), y_traj_meta.view(
                -1, 2), y_rand_meta.view(-1, 2)
        if torch.cuda.is_available():
            x_traj_meta, y_traj_meta, x_rand_meta, y_rand_meta = x_traj_meta.to(
                device), y_traj_meta.to(device), x_rand_meta.to(
                    device), y_rand_meta.to(device)

        meta_loss = metalearner(x_traj_meta, y_traj_meta, x_rand_meta,
                                y_rand_meta)
        loss_history.append(meta_loss[-1].detach().cpu().item())

        running_meta_loss = running_meta_loss * 0.97 + 0.03 * meta_loss[
            -1].detach().cpu()
        running_meta_loss_fixed = running_meta_loss / (1 -
                                                       (0.97**
                                                        (meta_steps_counter)))
        metrics_list.append((run, running_meta_loss_fixed.item(), step))

        if step % LOG_INTERVAL == 0:
            if running_meta_loss > 0:
                logger.info("Running meta loss = %f",
                            running_meta_loss_fixed.item())

            with torch.no_grad():
                t1 = np.random.choice(tasks, args["tasks"], replace=False)

                iterators = []
                for t in t1:
                    iterators.append(sampler.sample_task([t]))

                x_traj, y_traj, x_rand, y_rand = utils.construct_set(iterators,
                                                                     sampler,
                                                                     steps=1)
                x_traj, x_rand, y_traj, y_rand = x_traj.view(
                    -1,
                    51), x_rand.view(-1,
                                     51), y_traj.view(-1,
                                                      2), y_rand.view(-1, 2)
                if torch.cuda.is_available():
                    x_traj, y_traj, x_rand, y_rand = x_traj.to(
                        device), y_traj.to(device), x_rand.to(
                            device), y_rand.to(device)
                logits_select = []
                for i in range(len(x_rand)):
                    l, _, _ = net.forward_col(x_rand[i], vars=None, grad=False)
                    logits_select.append(l)

                logits = torch.stack(logits_select).unsqueeze(1)

                current_adaptation_loss = F.mse_loss(logits,
                                                     y_rand[:, 0].unsqueeze(1))
                adaptation_loss_history.append(
                    current_adaptation_loss.detach().item())
                adaptation_loss = adaptation_loss * 0.97 + current_adaptation_loss.detach(
                ).cpu().item() * 0.03
                adaptation_loss_fixed = adaptation_loss / (1 -
                                                           (0.97**(step + 1)))
                adaptation_running_loss_history.append(adaptation_loss_fixed)

                logger.info("Adaptation loss = %f", current_adaptation_loss)

                if step % LOG_INTERVAL == 0:
                    logger.info("Running adaptation loss = %f",
                                adaptation_loss_fixed)

        if (step + 1) % (LOG_INTERVAL * 500) == 0:
            if not args["no_save"]:
                torch.save(metalearner.net, my_experiment.path + "net.model")
            dict_names = {}
            for (name, param) in metalearner.net.named_parameters():
                dict_names[name] = param.adaptation

            my_experiment.insert_values("metrics", metrics_keys, metrics_list)
            metrics_list = []

            my_experiment.add_result("Layers meta values", dict_names)
            my_experiment.add_result("Meta loss", loss_history)
            my_experiment.add_result("Adaptation loss",
                                     adaptation_loss_history)
            my_experiment.add_result("Running adaption loss",
                                     adaptation_running_loss_history)
            my_experiment.store_json()
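# The running_meta_loss / adaptation_loss smoothing above is an exponentially weighted moving
# average with an Adam-style bias correction: dividing by 1 - beta**t compensates for the
# average being initialized at zero. A small self-contained sketch of that pattern
# (an illustration, not part of the repository):
def ema_with_bias_correction(running, value, step, beta=0.97):
    # 'step' counts from 1; the corrected value is unbiased even for small step counts.
    running = beta * running + (1.0 - beta) * value
    corrected = running / (1.0 - beta ** step)
    return running, corrected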
Example no. 25
0
def main(args):
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    np.random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/")

    dataset = df.DatasetFactory.get_dataset(args.dataset)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset, train=False)

    if args.dataset == "CUB":
        args.classes = list(range(100))
    if args.dataset == "CIFAR100":
        args.classes = list(range(50))

    if args.dataset == "omniglot":
        iterator = torch.utils.data.DataLoader(utils.remove_classes_omni(
            dataset, list(range(963))),
                                               batch_size=256,
                                               shuffle=True,
                                               num_workers=1)
        iterator_test = torch.utils.data.DataLoader(utils.remove_classes_omni(
            dataset_test, list(range(963))),
                                                    batch_size=256,
                                                    shuffle=True,
                                                    num_workers=1)
    else:
        iterator = torch.utils.data.DataLoader(utils.remove_classes(
            dataset, args.classes),
                                               batch_size=12,
                                               shuffle=True,
                                               num_workers=1)
        iterator_test = torch.utils.data.DataLoader(utils.remove_classes(
            dataset_test, args.classes),
                                                    batch_size=12,
                                                    shuffle=True,
                                                    num_workers=1)

    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args.dataset)

    maml = learner.Learner(config).to(device)

    opt = torch.optim.Adam(maml.parameters(), lr=args.lr)

    for e in range(args.epoch):
        correct = 0
        for img, y in tqdm(iterator):
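            # From epoch 50 onward the learning rate is dropped to 1e-5 (note that the optimizer
            # is re-created on every batch of epoch 50, resetting Adam's state each time).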
            if e == 50:
                opt = torch.optim.Adam(maml.parameters(), lr=0.00001)
                logger.info("Changing LR from %f to %f", 0.0001, 0.00001)
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)
            feature = maml(img, feature=True)
            loss_rep = torch.abs(feature).sum()

            opt.zero_grad()
            loss = F.cross_entropy(pred, y)
            # loss_rep.backward(retain_graph=True)
            # logger.info("L1 norm = %s", str(loss_rep.item()))
            loss.backward()
            opt.step()
            correct += (pred.argmax(1) == y).sum().float() / len(y)
        logger.info("Accuracy at epoch %d = %s", e,
                    str(correct / len(iterator)))

        correct = 0
        with torch.no_grad():
            for img, y in tqdm(iterator_test):

                img = img.to(device)
                y = y.to(device)
                pred = maml(img)
                feature = maml(img, feature=True)
                loss_rep = torch.abs(feature).sum()

                correct += (pred.argmax(1) == y).sum().float() / len(y)
            logger.info("Accuracy Test at epoch %d = %s", e,
                        str(correct / len(iterator_test)))

        torch.save(maml, my_experiment.path + "model.net")
Example no. 26
0
def main(args):
    # Seed random number generators
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    my_experiment = experiment(args.name,
                               args,
                               "/data5/jlindsey/continual/results",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    print(args)

    # Initialize tasks; 1000 tasks are available for evaluation
    tasks = list(range(1000))
    logger = logging.getLogger('experiment')

    sampler = ts.SamplerFactory.get_sampler("Sin",
                                            tasks,
                                            None,
                                            None,
                                            capacity=args.capacity + 1)

    #config = mf.ModelFactory.get_model("na", "Sin", in_channels=args.capacity + 1, num_actions=args.tasks)

    config = mf.ModelFactory.get_model(args.modeltype,
                                       "Sin",
                                       in_channels=args.capacity + 1,
                                       num_actions=1,
                                       width=args.width)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Load the model
    maml = MetaLearnerRegression(args, config).to(device)
    maml.net = torch.load(args.model, map_location='cpu').to(device)

    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logger.info(maml)
    logger.info('Total trainable tensors: %d', num)

    ##### Setting up parameters for freezing RLN layers
    #### Also resets TLN layers with random initialization if args.reset is true
    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("net.vars." + str(temp))

    for name, param in maml.named_parameters():
        logger.info(name)
        if name in frozen_layers:
            logger.info("Freeezing name %s", str(name))
            param.learn = False
            logger.info(str(param.requires_grad))
        else:
            if args.reset:
                w = nn.Parameter(torch.ones_like(param))
                if len(w.shape) > 1:
                    logger.info("Resseting layer %s", str(name))
                    torch.nn.init.kaiming_normal_(w)
                else:
                    w = nn.Parameter(torch.zeros_like(param))
                param.data = w
                param.learn = True

    for name, param in maml.net.named_parameters():
        logger.info(name)
        if name in frozen_layers:
            logger.info("Freeezing name %s", str(name))
            param.learn = False
            logger.info(str(param.requires_grad))

    correct = 0
    counter = 0
    for name, _ in maml.net.named_parameters():
        # logger.info("LRs of layer %s = %s", str(name), str(torch.mean(maml.lrs[counter])))
        counter += 1

    for lrs in [0.003]:
        loss_vector = np.zeros(args.tasks)
        loss_vector_results = []
        lr_results = {}
        incremental_results = {}
        lr_results[lrs] = []

        runs = args.runs

        loss_hist = []

        for temp in range(0, runs):
            loss_vector = np.zeros(args.tasks)
            t1 = np.random.choice(tasks, args.tasks, replace=False)
            print(t1)

            loss_hist.append([])

            iterators = []
            for t in t1:
                iterators.append(sampler.sample_task([t]))
            if args.vary_length:
                num_steps = np.random.randint(args.update_step // 10,
                                              args.update_step + 1)
                x_spt, y_spt, x_qry, y_qry = construct_set(iterators,
                                                           sampler,
                                                           steps=num_steps,
                                                           iid=args.iid)
            else:
                num_steps = args.update_step
                x_spt, y_spt, x_qry, y_qry = construct_set(
                    iterators, sampler, steps=args.update_step, iid=args.iid)
            if torch.cuda.is_available():
                x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
                ), x_qry.cuda(), y_qry.cuda()

            net = copy.deepcopy(maml.net)
            net = net.to(device)
            for params_old, params_new in zip(maml.net.parameters(),
                                              net.parameters()):
                params_new.learn = params_old.learn

            list_of_params = list(filter(lambda x: x.learn, net.parameters()))

            optimizer = optim.SGD(list_of_params, lr=lrs)

            counter = 0
            x_spt_test, y_spt_test, x_qry_test, y_qry_test = construct_set(
                iterators, sampler, steps=300)
            if args.train_performance:
                x_spt_test, y_spt_test, x_qry_test, y_qry_test = x_spt, y_spt, x_qry, y_qry
                x_qry_test, y_qry_test = x_spt_test, y_spt_test
            if torch.cuda.is_available():
                x_spt_test, y_spt_test, x_qry_test, y_qry_test = x_spt_test.cuda(
                ), y_spt_test.cuda(), x_qry_test.cuda(), y_qry_test.cuda()

            fast_weights = net.vars
            if args.randomize_plastic_weights:
                net.randomize_plastic_weights()
            if args.zero_plastic_weights:
                net.zero_plastic_weights()
            for k in range(len(x_spt)):
                #print('hey', k, torch.sum(fast_weights[0]), torch.sum(fast_weights[14]))
                if k % num_steps == 0 and k > 0:
                    counter += 1
                    loss_temp = 0
                    if not counter in incremental_results:
                        incremental_results[counter] = []
                    with torch.no_grad():
                        if args.train_performance:
                            for update_upto in range(0, k):
                                logits = net(x_spt_test[update_upto],
                                             vars=fast_weights,
                                             bn_training=False)

                                logits_select = []
                                for no, val in enumerate(
                                        y_spt_test[update_upto, :, 1].long()):
                                    logits_select.append(logits[no, val])
                                logits = torch.stack(logits_select).unsqueeze(
                                    1)
                                loss_temp += F.mse_loss(
                                    logits, y_spt_test[update_upto, :,
                                                       0].unsqueeze(1))

                            loss_temp = loss_temp / (k)
                        else:
                            for update_upto in range(0, counter * 300):
                                logits = net(x_spt_test[update_upto],
                                             vars=fast_weights,
                                             bn_training=False)

                                logits_select = []
                                for no, val in enumerate(
                                        y_spt_test[update_upto, :, 1].long()):
                                    logits_select.append(logits[no, val])
                                logits = torch.stack(logits_select).unsqueeze(
                                    1)
                                loss_temp += F.mse_loss(
                                    logits, y_spt_test[update_upto, :,
                                                       0].unsqueeze(1))

                            loss_temp = loss_temp / (counter * 300)
                        incremental_results[counter].append(loss_temp.item())
                        my_experiment.results[
                            "incremental"] = incremental_results

                logits = net(x_spt[k], vars=fast_weights, bn_training=False)

                logits_select = []
                for no, val in enumerate(y_spt[k, :, 1].long()):
                    logits_select.append(logits[no, val])

                logits = torch.stack(logits_select).unsqueeze(1)
                loss = F.mse_loss(logits, y_spt[k, :, 0].unsqueeze(1))

                loss_hist[temp].append(loss.cpu().detach().numpy())

                grad = torch.autograd.grad(loss, fast_weights)
                # fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

                #print('heyyy', args.plastic_update, args.update_lr, np.array(net.vars_plasticity[-2].cpu().detach()),  np.array(net.vars_plasticity[-1].cpu().detach()))
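                # Manual SGD step on the fast weights: per-parameter plasticity coefficients,
                # per-layer plasticity coefficients, or a fixed update_lr; parameters whose
                # .learn flag is False (the frozen RLN) are carried over unchanged.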
                if args.plastic_update:
                    fast_weights = list(
                        map(
                            lambda p: p[1] - p[0] * p[2]
                            if p[1].learn else p[1],
                            zip(grad, fast_weights, net.vars_plasticity)))
                elif args.layer_level_plastic_update:
                    fast_weights = list(
                        map(
                            lambda p: p[1] - p[0] * p[2]
                            if p[1].learn else p[1],
                            zip(grad, fast_weights,
                                net.layer_level_vars_plasticity)))
                else:
                    fast_weights = list(
                        map(
                            lambda p: p[1] - args.update_lr * p[0]
                            if p[1].learn else p[1], zip(grad, fast_weights)))

                for params_old, params_new in zip(net.vars, fast_weights):
                    params_new.learn = params_old.learn
                    #print('param', params_new.learn)

                optimizer.zero_grad()
                #loss.backward()
                #optimizer.step()

            counter += 1
            loss_temp = 0
            if not counter in incremental_results:
                incremental_results[counter] = []
            with torch.no_grad():
                if args.train_performance:
                    for update_upto in range(0, k):
                        logits = net(x_spt_test[update_upto],
                                     vars=fast_weights,
                                     bn_training=False)

                        logits_select = []
                        for no, val in enumerate(y_spt_test[update_upto, :,
                                                            1].long()):
                            logits_select.append(logits[no, val])
                        logits = torch.stack(logits_select).unsqueeze(1)
                        loss_temp += F.mse_loss(
                            logits, y_spt_test[update_upto, :, 0].unsqueeze(1))
                        # lr_results[lrs].append(loss_q.item())
                    loss_temp = loss_temp / (k)
                else:
                    for update_upto in range(0, counter * 300):
                        logits = net(x_spt_test[update_upto],
                                     vars=fast_weights,
                                     bn_training=False)

                        logits_select = []
                        for no, val in enumerate(y_spt_test[update_upto, :,
                                                            1].long()):
                            logits_select.append(logits[no, val])
                        logits = torch.stack(logits_select).unsqueeze(1)
                        loss_temp += F.mse_loss(
                            logits, y_spt_test[update_upto, :, 0].unsqueeze(1))
                        # lr_results[lrs].append(loss_q.item())
                    loss_temp = loss_temp / (counter * 300)
                incremental_results[counter].append(loss_temp.item())
                my_experiment.results["incremental"] = incremental_results
            #
            x_spt, y_spt, x_qry, y_qry = x_spt_test, y_spt_test, x_qry_test, y_qry_test
            if torch.cuda.is_available():
                x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
                ), x_qry.cuda(), y_qry.cuda()
            with torch.no_grad():
                logits = net(x_qry[0], vars=fast_weights, bn_training=False)

                logits_select = []
                for no, val in enumerate(y_qry[0, :, 1].long()):
                    logits_select.append(logits[no, val])
                logits = torch.stack(logits_select).unsqueeze(1)
                loss_q = F.mse_loss(logits, y_qry[0, :, 0].unsqueeze(1))
                print('loss', loss_q.item())
                lr_results[lrs].append(loss_q.item())

            counter = 0
            loss = 0

            for k in range(len(x_spt)):
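                # Accumulate a per-task MSE: each consecutive chunk of 300 samples in x_spt
                # appears to correspond to one task (construct_set was called with steps=300
                # above), so dividing by 300 gives the average loss for that task.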

                logits = net(x_spt[k], vars=fast_weights, bn_training=False)

                logits_select = []
                for no, val in enumerate(y_spt[k, :, 1].long()):
                    logits_select.append(logits[no, val])
                logits = torch.stack(logits_select).unsqueeze(1)

                loss_vector[int(counter / (300))] += F.mse_loss(
                    logits, y_spt[k, :, 0].unsqueeze(1)) / 300

                counter += 1
            loss_vector_results.append(loss_vector.tolist())

        np.save("loss_hist_" + args.orig_name + ".npy", loss_hist)
        logger.info("Loss vector all %s", str(loss_vector_results))
        logger.info("Avg MSE LOSS  for lr %s = %s", str(lrs),
                    str(np.mean(lr_results[lrs])))
        logger.info("Std MSE LOSS  for lr %s = %s", str(lrs),
                    str(np.std(lr_results[lrs])))
        loss_vector = loss_vector / runs
        print("Loss vector = ", loss_vector)
        my_experiment.results[str(lrs)] = str(loss_vector_results)
        my_experiment.store_json()
        np.save('evals/loss_vector_results_' + args.orig_name + '.npy',
                loss_vector_results)
        np.save('evals/final_results_' + args.orig_name + '.npy', lr_results)
        np.save('evals/incremental_results_' + args.orig_name + '.npy',
                incremental_results)
Example no. 27
0
def main(args):
    # Seed random number generators
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    my_experiment = experiment(args.name, args, "../results/", commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    print(args)

    # Initialize tasks; 1000 tasks are available for evaluation
    tasks = list(range(1000))
    logger = logging.getLogger('experiment')

    sampler = ts.SamplerFactory.get_sampler("Sin", tasks, None, None, capacity=args.capacity + 1)

    config = mf.ModelFactory.get_model("na", "Sin", in_channels=args.capacity + 1, num_actions=args.tasks)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Load the model
    maml = MetaLearnerRegression(args, config).to(device)
    maml.net = torch.load(args.model, map_location='cpu').to(device)

    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logger.info(maml)
    logger.info('Total trainable tensors: %d', num)

    ##### Setting up parameters for freezing RLN layers
    #### Also resets TLN layers with random initialization if args.reset is true
    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("net.vars." + str(temp))

    for name, param in maml.named_parameters():
        logger.info(name)
        if name in frozen_layers:
            logger.info("Freeezing name %s", str(name))
            param.learn = False
            logger.info(str(param.requires_grad))
        else:
            if args.reset:
                w = nn.Parameter(torch.ones_like(param))
                if len(w.shape) > 1:
                    logger.info("Resseting layer %s", str(name))
                    torch.nn.init.kaiming_normal_(w)
                else:
                    w = nn.Parameter(torch.zeros_like(param))
                param.data = w
                param.learn = True

    for name, param in maml.net.named_parameters():
        logger.info(name)
        if name in frozen_layers:
            logger.info("Freeezing name %s", str(name))
            param.learn = False
            logger.info(str(param.requires_grad))

    correct = 0
    counter = 0
    for name, _ in maml.net.named_parameters():
        # logger.info("LRs of layer %s = %s", str(name), str(torch.mean(maml.lrs[counter])))
        counter += 1

    for lrs in [0.003]:
        loss_vector = np.zeros(args.tasks)
        loss_vector_results = []
        lr_results = {}
        incremental_results = {}
        lr_results[lrs] = []

        runs = args.runs
        for temp in range(0, runs):
            loss_vector = np.zeros(args.tasks)
            t1 = np.random.choice(tasks, args.tasks, replace=False)
            print(t1)

            iterators = []
            for t in t1:
                iterators.append(sampler.sample_task([t]))
            x_spt, y_spt, x_qry, y_qry = construct_set(iterators, sampler, steps=args.update_step, iid=args.iid)
            if torch.cuda.is_available():
                x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(), x_qry.cuda(), y_qry.cuda()

            net = copy.deepcopy(maml.net)
            net = net.to(device)
            for params_old, params_new in zip(maml.net.parameters(), net.parameters()):
                params_new.learn = params_old.learn

            list_of_params = list(filter(lambda x: x.learn, net.parameters()))

            optimizer = optim.SGD(list_of_params, lr=lrs)

            counter = 0
            x_spt_test, y_spt_test, x_qry_test, y_qry_test = construct_set(iterators, sampler, steps=300)
            if torch.cuda.is_available():
                x_spt_test, y_spt_test, x_qry_test, y_qry_test = x_spt_test.cuda(), y_spt_test.cuda(), x_qry_test.cuda(), y_qry_test.cuda()
            for k in range(len(x_spt)):
                if k % args.update_step == 0 and k > 0:
                    counter += 1
                    loss_temp = 0
                    if not counter in incremental_results:
                        incremental_results[counter] = []
                    with torch.no_grad():
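                        # Evaluate the adapted network on all test samples seen so far
                        # (counter * 300 of them), without tracking gradients.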
                        for update_upto in range(0, counter * 300):
                            logits = net(x_spt_test[update_upto], vars=None, bn_training=False)

                            logits_select = []
                            for no, val in enumerate(y_spt_test[update_upto, :, 1].long()):
                                logits_select.append(logits[no, val])
                            logits = torch.stack(logits_select).unsqueeze(1)
                            loss_temp += F.mse_loss(logits, y_spt_test[update_upto, :, 0].unsqueeze(1))

                        loss_temp = loss_temp / (counter * 300)
                        incremental_results[counter].append(loss_temp.item())
                        my_experiment.results["incremental"] = incremental_results

                logits = net(x_spt[k], None, bn_training=False)

                logits_select = []
                for no, val in enumerate(y_spt[k, :, 1].long()):
                    logits_select.append(logits[no, val])

                logits = torch.stack(logits_select).unsqueeze(1)
                loss = F.mse_loss(logits, y_spt[k, :, 0].unsqueeze(1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            counter += 1
            loss_temp = 0
            if not counter in incremental_results:
                incremental_results[counter] = []
            with torch.no_grad():
                for update_upto in range(0, counter * 300):
                    logits = net(x_spt_test[update_upto], vars=None, bn_training=False)

                    logits_select = []
                    for no, val in enumerate(y_spt_test[update_upto, :, 1].long()):
                        logits_select.append(logits[no, val])
                    logits = torch.stack(logits_select).unsqueeze(1)
                    loss_temp += F.mse_loss(logits, y_spt_test[update_upto, :, 0].unsqueeze(1))
                    # lr_results[lrs].append(loss_q.item())
                loss_temp = loss_temp / (counter * 300)
                incremental_results[counter].append(loss_temp.item())
                my_experiment.results["incremental"] = incremental_results
            #
            x_spt, y_spt, x_qry, y_qry = x_spt_test, y_spt_test, x_qry_test, y_qry_test
            if torch.cuda.is_available():
                x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(), x_qry.cuda(), y_qry.cuda()
            with torch.no_grad():
                logits = net(x_qry[0], vars=None, bn_training=False)

                logits_select = []
                for no, val in enumerate(y_qry[0, :, 1].long()):
                    logits_select.append(logits[no, val])
                logits = torch.stack(logits_select).unsqueeze(1)
                loss_q = F.mse_loss(logits, y_qry[0, :, 0].unsqueeze(1))
                lr_results[lrs].append(loss_q.item())

            counter = 0
            loss = 0

            for k in range(len(x_spt)):

                logits = net(x_spt[k], None, bn_training=False)

                logits_select = []
                for no, val in enumerate(y_spt[k, :, 1].long()):
                    logits_select.append(logits[no, val])
                logits = torch.stack(logits_select).unsqueeze(1)

                loss_vector[int(counter / (300))] += F.mse_loss(logits, y_spt[k, :, 0].unsqueeze(1)) / 300

                counter += 1
            loss_vector_results.append(loss_vector.tolist())

        logger.info("Loss vector all %s", str(loss_vector_results))
        logger.info("Avg MSE LOSS  for lr %s = %s", str(lrs), str(np.mean(lr_results[lrs])))
        logger.info("Std MSE LOSS  for lr %s = %s", str(lrs), str(np.std(lr_results[lrs])))
        loss_vector = loss_vector / runs
        print("Loss vector = ", loss_vector)
        my_experiment.results[str(lrs)] = str(loss_vector_results)
        my_experiment.store_json()
    torch.save(maml.net, my_experiment.path + "learner.model")
Example no. 28
0
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/", args.commit)

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)

    total_clases = [900]

    keep = list(range(total_clases[0]))

    dataset = utils.remove_classes_omni(
        df.DatasetFactory.get_dataset("omniglot",
                                      train=True,
                                      path=args.data_path,
                                      all=True), keep)
    iterator_sorted = torch.utils.data.DataLoader(utils.iterator_sorter_omni(
        dataset, False, classes=total_clases),
                                                  batch_size=128,
                                                  shuffle=True,
                                                  num_workers=2)

    iterator = iterator_sorted

    print(args)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = torch.load(args.model, map_location='cpu')

    if args.scratch:
        config = mf.ModelFactory.get_model("na", args.dataset)
        maml = learner.Learner(config)

    maml = maml.to(device)

    reps = []
    counter = 0

    fig, axes = plt.subplots(9, 4)
    with torch.no_grad():
        for img, target in iterator:
            print(counter)

            img = img.to(device)
            target = target.to(device)
            # print(target)
            rep = maml(img, vars=None, bn_training=False, feature=True)
            rep = rep.view((-1, 32, 72)).detach().cpu().numpy()
            rep_instance = rep[0]
            if args.binary:
                rep_instance = (rep_instance > 0).astype(int)
            if args.max:
                rep = rep / np.max(rep)
            else:
                rep = (rep > 0).astype(int)
            if counter < 36:
                print("Adding plot")
                ax = axes[counter // 4, counter % 4]
                ax.imshow(rep_instance, cmap=args.color)
                ax.set_yticklabels([])
                ax.set_xticklabels([])
                ax.set_aspect('equal')

            counter += 1
            reps.append(rep)

    plt.subplots_adjust(wspace=0.0, hspace=0.0)

    plt.savefig(my_experiment.path + "instance_" + str(counter) + ".pdf",
                format="pdf")
    plt.clf()
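    # Aggregate representations from all batches to compute the average activation map and sparsity statistics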

    rep = np.concatenate(reps)
    average_activation = np.mean(rep, 0)
    plt.imshow(average_activation, cmap=args.color)
    plt.colorbar()
    plt.clim(0, np.max(average_activation))
    plt.savefig(my_experiment.path + "average_activation.pdf", format="pdf")
    plt.clf()
    instance_sparsity = np.mean((np.sum(np.sum(rep, 1), 1)) / (64 * 36))
    print("Instance sparisty = ", instance_sparsity)
    my_experiment.results["instance_sparisty"] = str(instance_sparsity)
    lifetime_sparsity = (np.sum(rep, 0) / len(rep)).flatten()
    mean_lifetime = np.mean(lifetime_sparsity)
    print("Lifetime sparsity = ", mean_lifetime)
    my_experiment.results["lifetime_sparisty"] = str(mean_lifetime)
    dead_neuros = float(np.sum(
        (lifetime_sparsity == 0).astype(int))) / len(lifetime_sparsity)
    print("Dead neurons percentange = ", dead_neuros)
    my_experiment.results["dead_neuros"] = str(dead_neuros)
    plt.hist(lifetime_sparsity)

    plt.savefig(my_experiment.path + "histogram.pdf", format="pdf")
    my_experiment.store_json()
Example No. 29
def main(args):
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name,
                               args,
                               "results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    args.traj_classes = list(range(int(963 / 2), 963))

    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True,
                                            all=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=True,
                                                 train=False,
                                                 all=True)

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model("na", args.dataset)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    utils.freeze_layers(args.rln, maml)
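    # Meta-training loop: each step samples args.tasks trajectory classes plus a random batch, then runs one meta-update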

    for step in range(args.steps):

        t1 = np.random.choice(args.traj_classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = (x_spt.cuda(), y_spt.cuda(),
                                          x_qry.cuda(), y_qry.cuda())

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 39:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 299:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "./evals/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")


    ver = 0
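    # Point modelX at the most recent versioned run directory (assumes at least one such directory exists)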

    while os.path.exists(args.modelX + "_" + str(ver)):
        ver += 1
        
    args.modelX = args.modelX + "_" + str(ver-1) + "/learner.model"
    
    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)
    total_clases = 10

    total_ff_vars = 2*(6 + 2 + args.num_extra_dense_layers)
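    # Freeze the first rln*2 parameter tensors (the representation network) and the last rln_end*2 output-side tensors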

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))
        

    for temp in range(args.rln_end * 2):
        frozen_layers.append("net.vars." + str(total_ff_vars - 1 - temp))
    #logger.info("Frozen layers = %s", " ".join(frozen_layers))

    #
    final_results_all = []
    
    total_clases = [5]

    if args.twentyclass:
        total_clases = [20]

    if args.twotask:
        total_clases = [2, 10]
    if args.fiftyclass:
        total_clases = [50]
    if args.tenclass:
        total_clases = [10]
    if args.fiveclass:
        total_clases = [5]       
    print('Evaluating class splits:', total_clases)
    for tot_class in total_clases:
        
        avg_perf = 0.0
        print('TOT_CLASS', tot_class)
        lr_list = [0]  # [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        for aoo in range(0, args.runs):
            #print('run', aoo)
            keep = np.random.choice(list(range(650)), tot_class, replace=False)
            if args.dataset == "imagenet":
                keep = np.random.choice(list(range(20)), tot_class, replace=False)
                
                dataset = imgnet.MiniImagenet(args.imagenet_path, mode='test', elem_per_class=30, classes=keep, seed=aoo)

                dataset_test = imgnet.MiniImagenet(args.imagenet_path, mode='test', elem_per_class=30, classes=keep, test=args.test, seed=aoo)


                iterator = torch.utils.data.DataLoader(dataset_test, batch_size=128,
                                                       shuffle=True, num_workers=1)
                iterator_sorted = torch.utils.data.DataLoader(dataset, batch_size=1,
                                                       shuffle=False, num_workers=1)
            if args.dataset == "omniglot":

                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot", train=True, background=False), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter_omni(dataset, False, classes=total_clases),
                    batch_size=1,
                    shuffle=False, num_workers=2)
                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot", train=not args.test, background=False), keep)
                iterator = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                       shuffle=False, num_workers=1)
            elif args.dataset == "CIFAR100":
                keep = np.random.choice(list(range(50, 100)), tot_class)
                dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=True), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter(dataset, False, classes=tot_class),
                    batch_size=16,
                    shuffle=False, num_workers=2)
                dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=False), keep)
                iterator = torch.utils.data.DataLoader(dataset, batch_size=128,
                                                       shuffle=False, num_workers=1)
            # sampler = ts.MNISTSampler(list(range(0, total_clases)), dataset)
            #
            #print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            #print("LEN", len(iterator_sorted))
            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:
                    #torch.cuda.empty_cache()
                    #print(lr)
                    # for lr in [0.001, 0.0003, 0.0001, 0.00003, 0.00001]:
                    maml = torch.load(args.modelX, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model(args.model_type, args.dataset)
                        maml = learner.Learner(config, lr)
                        # maml = MetaLearingClassification(args, config).to(device).net

                    #maml.update_lr = lr
                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        #if name.find("feedback_strength_vars") != -1:
                        #    print(name, param)
                        if name in frozen_layers:
                            # logger.info("Freeezing name %s", str(name))
                            param.learn = False
                            # logger.info(str(param.requires_grad))
                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                # logger.info("W shape = %s", str(len(w.shape)))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                    #torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    #w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    #maml.parameters()[-1].data = w

                    
                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")
                        # logger.info("Name = %s", n)
                        if n == "vars_"+str(14+2*args.num_extra_dense_layers):
                            pass
                            #w = nn.Parameter(torch.ones_like(a))
                            # logger.info("W shape = %s", str(w.shape))
                            #torch.nn.init.kaiming_normal_(w)
                            #a.data = w
                        if n == "vars_"+str(15+2*args.num_extra_dense_layers):
                            pass
                            #w = nn.Parameter(torch.zeros_like(a))
                            #a.data = w
                            
                    #for fv in maml.feedback_vars:
                    #    w = nn.Parameter(torch.zeros_like(fv))
                    #    fv.data = w   
                        
                    #for fv in maml.feedback_strength_vars:
                    #    w = nn.Parameter(torch.ones_like(fv))
                    #    fv.data = w                        

                    correct = 0

                    for img, target in iterator:
                        #print('size', target.size())
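                        # Remap raw dataset labels to 0..tot_class-1 positions within the sampled `keep` classes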
                        target = torch.tensor(np.array([list(keep).index(int(target.cpu().numpy()[i])) for i in range(target.size()[0])]))
                        with torch.no_grad():
                            img = img.to(device)
                            target = target.to(device)
                            logits_q = maml(img, vars=None, bn_training=False, feature=False)
                            pred_q = (logits_q).argmax(dim=1)
                            correct += torch.eq(pred_q, target).sum().item() / len(img)

                    #logger.info("Pre-epoch accuracy %s", str(correct / len(iterator)))

                    filter_list = ["vars.0", "vars.1", "vars.2", "vars.3", "vars.4", "vars.5"]

                    #logger.info("Filter list = %s", ",".join(filter_list))
                    list_of_names = list(
                        map(lambda x: x[1], list(filter(lambda x: x[0] not in filter_list, maml.named_parameters()))))

                    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(filter(lambda x: x[1].learn, maml.named_parameters()))
                    if args.scratch or args.no_freeze:
                        print("Empty filter list")
                        list_of_params = maml.parameters()
                    #
                    #for x in list_of_names:
                    #    logger.info("Unfrozen layer = %s", str(x[0]))
                    opt = torch.optim.Adam(list_of_params, lr=lr)

                    fast_weights = None
                    if args.randomize_plastic_weights:
                        maml.randomize_plastic_weights()
                    if args.zero_plastic_weights:
                        maml.zero_plastic_weights()
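                    # Reservoir buffer of size mem_size; replayed samples are mixed with incoming data when args.memory > 0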
                    res_sampler = rep.ReservoirSampler(mem_size)
                    iterator_sorted_new = []
                    iter_count = 0
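                    # Subsample the sorted stream: keep the first args.shots examples from each block of 15, then optionally shuffle (iid)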
                    for img, y in iterator_sorted:
                        
                        y = torch.tensor(np.array([list(keep).index(int(y.cpu().numpy()[i])) for i in range(y.size()[0])]))
                        if iter_count % 15 >= args.shots:
                            iter_count += 1
                            continue       
                        iterator_sorted_new.append((img, y))
                        iter_count += 1
                    iterator_sorted = []
                    perm = np.random.permutation(len(iterator_sorted_new))
                    for i in range(len(iterator_sorted_new)):
                        if args.iid:
                            iterator_sorted.append(iterator_sorted_new[perm[i]])
                        else:
                            iterator_sorted.append(iterator_sorted_new[i])
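                    # Online adaptation: each example (plus an optional replayed batch) updates fast_weights via maml.getOjaUpdate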

                    for iter in range(0, args.epoch):
                        iter_count = 0
                        imgs = []
                        ys = []
                            
                        for img, y in iterator_sorted:
                            
                            #print('iter count', iter_count)
                            #print('y is', y)
                            


                            #if iter_count % 15 >= args.shots:
                            #    iter_count += 1
                            #    continue
                            iter_count += 1
                            #with torch.no_grad():
                            if args.memory == 0:
                                img = img.to(device)
                                y = y.to(device)
                            else:
                                res_sampler.update_buffer(zip(img, y))
                                res_sampler.update_observations(len(img))
                                img = img.to(device)
                                y = y.to(device)
                                img2, y2 = res_sampler.sample_buffer(8)
                                img2 = img2.to(device)
                                y2 = y2.to(device)
                                img = torch.cat([img, img2], dim=0)
                                y = torch.cat([y, y2], dim=0)
                                #print('img size', img.size())

                            imgs.append(img)
                            ys.append(y)
                            if not args.batch_learning:
                                logits = maml(img, vars=fast_weights)
                                fast_weights = maml.getOjaUpdate(y, logits, fast_weights, hebbian=args.hebb)
                        if args.batch_learning:
                            y = torch.cat(ys, 0)
                            img = torch.cat(imgs, 0)
                            logits = maml(img, vars=fast_weights)
                            fast_weights = maml.getOjaUpdate(y, logits, fast_weights, hebbian=args.hebb)


                    #logger.info("Result after one epoch for LR = %f", lr)
                    correct = 0
                    for img, target in iterator:
                        target = torch.tensor(np.array([list(keep).index(int(target.cpu().numpy()[i])) for i in range(target.size()[0])]))
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img, vars=fast_weights, bn_training=False, feature=False)

                        pred_q = (logits_q).argmax(dim=1)

                        correct += torch.eq(pred_q, target).sum().item() / len(img)

                    #logger.info(str(correct / len(iterator)))
                    if (correct / len(iterator) > max_acc):
                        max_acc = correct / len(iterator)
                        max_lr = lr
                        
                    del maml
                    #del maml
                    #del fast_weights

                lr_list = [max_lr]
                #print('result', max_acc)
                results_mem_size[mem_size] = (max_acc, max_lr)
                #logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc, tot_class)
                avg_perf += max_acc / args.runs #TODO: change this if/when I ever use memory -- can't choose max memory size differently for each run!
                print('avg perf', avg_perf * args.runs / (1+aoo))
            final_results_all.append((tot_class, results_mem_size))
            #writer.add_scalar('performance', avg_perf, tot_class)
            #print("A=  ", results_mem_size)
            #logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            np.save('evals/final_results_'+args.orig_name+'.npy', final_results_all) 
            #print("FINAL RESULTS = ", final_results_all)
    writer.close()