Code example #1
0
File: oml_omniglot.py Project: 0merjavaid/mrcl
def main():
    """Meta-train an OML classifier on the Omniglot background split.

    Parses hyper-parameters, builds the train/test datasets and task
    sampler, then runs the outer meta-training loop on the GPU selected
    by this worker's rank.
    """
    p = class_parser.Parser()
    # Parse the command line once and reuse the namespace; calling
    # parse_known_args() repeatedly re-parses sys.argv for the same result.
    parsed = p.parse_known_args()[0]
    rank = parsed.rank
    all_args = vars(parsed)
    print("All args = ", all_args)

    # Select the hyper-parameter combination for this worker rank.
    args = utils.get_run(all_args, rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'], args, "../results/", commit_changes=False, rank=0, seed=1)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args['classes'] = list(range(963))

    # Trajectory classes: the second half of the meta-training classes.
    args['traj_classes'] = list(range(int(963 / 2), 963))

    dataset = df.DatasetFactory.get_dataset(args['dataset'], background=True, train=True, path=args["path"], all=True)
    dataset_test = df.DatasetFactory.get_dataset(args['dataset'], background=True, train=False, path=args["path"], all=True)

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test, batch_size=5,
                                                shuffle=True, num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset, batch_size=5,
                                                 shuffle=True, num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(args['dataset'], args['classes'], dataset, dataset_test)

    config = mf.ModelFactory.get_model("na", args['dataset'], output_dimension=1000)

    # Spread workers across the available GPUs round-robin by rank.
    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    for step in range(args['steps']):
        # Sample `tasks` distinct trajectory classes for this meta-step.
        t1 = np.random.choice(args['traj_classes'], args['tasks'], replace=False)
        d_traj_iterators = [sampler.sample_task([t]) for t in t1]
        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(d_traj_iterators, d_rand_iterator,
                                                               steps=args['update_step'], reset=not args['no_reset'])
        # Move the batch to the model's device; .to() is a no-op on CPU, so
        # no cuda-availability guard is needed.
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)
Code example #2
0
def main(args):
    """Meta-train an OML classifier on Omniglot from a parsed args namespace.

    Builds the train/test datasets and task sampler, freezes the first
    `args.rln` layers, then runs the outer meta-training loop with periodic
    accuracy logging.
    """
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name,
                               args,
                               "results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    # Trajectory classes: the second half of the meta-training classes.
    args.traj_classes = list(range(int(963 / 2), 963))

    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True,
                                            all=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=True,
                                                 train=False,
                                                 all=True)

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model("na", args.dataset)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    # Freeze the first `rln` layers (the representation network).
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):
        # Sample `tasks` distinct trajectory classes for this meta-step.
        t1 = np.random.choice(args.traj_classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset)
        # Move the batch with .to(device) instead of bare .cuda(): identical
        # on GPU, a no-op on CPU (so the availability guard is unnecessary),
        # and consistent with the device selected for the model above.
        x_spt, y_spt, x_qry, y_qry = (x_spt.to(device), y_spt.to(device),
                                      x_qry.to(device), y_qry.to(device))

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 39:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 299:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)
Code example #3
0
def main():
    """Meta-train an OML classifier with separate support/query datasets.

    Parses hyper-parameters, builds independently-augmented support and
    query datasets (optionally prefetched to the GPU), paired samplers,
    and the meta-learner, then runs the outer meta-training loop with
    periodic logging and checkpointing.
    """
    p = class_parser.Parser()
    # Parse the command line once and reuse the namespace; calling
    # parse_known_args() repeatedly re-parses sys.argv for the same result.
    parsed = p.parse_known_args()[0]
    rank = parsed.rank
    all_args = vars(parsed)
    print("All args = ", all_args)

    # Select the hyper-parameter combination for this worker rank.
    args = utils.get_run(all_args, rank)

    utils.set_seed(args["seed"])

    # Resolve the experiment output directory.
    if args["log_root"]:
        log_root = osp.join("./results", args["log_root"]) + "/"
    else:
        log_root = osp.join("./results/")

    my_experiment = experiment(
        args["name"],
        args,
        log_root,
        commit_changes=False,
        rank=0,
        seed=args["seed"],
    )
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger("experiment")

    # Meta-training class pool is configurable (originally the first 963
    # Omniglot classes).
    args["classes"] = list(range(args["num_classes"]))
    print("Using classes:", args["num_classes"])

    if torch.cuda.is_available():
        device = torch.device("cuda")
        use_cuda = True
    else:
        device = torch.device("cpu")
        use_cuda = False

    # Separate support and query datasets over the same split so each can
    # apply its own augmentation.
    dataset_spt = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        all=args["all"],
        prefetch_gpu=args["prefetch_gpu"],
        device=device,
        resize=args["resize"],
        augment=args["augment_spt"],
    )
    dataset_qry = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        all=args["all"],
        prefetch_gpu=args["prefetch_gpu"],
        device=device,
        resize=args["resize"],
        augment=args["augment_qry"],
    )
    dataset_test = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=False,
        path=args["path"],
        all=args["all"],
        resize=args["resize"],
    )

    logger.info(
        f"Support size: {len(dataset_spt)}, Query size: {len(dataset_qry)}, test size: {len(dataset_test)}"
    )

    # GPU-prefetched tensors cannot be served by worker processes or pinned.
    pin_memory = use_cuda
    if args["prefetch_gpu"]:
        num_workers = 0
        pin_memory = False
    else:
        num_workers = args["num_workers"]

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=5,
        shuffle=True,
        num_workers=0,
    )

    iterator_train = torch.utils.data.DataLoader(
        dataset_spt,
        batch_size=5,
        shuffle=True,
        num_workers=0,
    )

    logger.info("Support sampler:")
    sampler_spt = ts.SamplerFactory.get_sampler(
        args["dataset"],
        args["classes"],
        dataset_spt,
        dataset_test,
        prefetch_gpu=args["prefetch_gpu"],
        use_cuda=use_cuda,
        num_workers=0,
    )
    logger.info("Query sampler:")
    sampler_qry = ts.SamplerFactory.get_sampler(
        args["dataset"],
        args["classes"],
        dataset_qry,
        dataset_test,
        prefetch_gpu=args["prefetch_gpu"],
        use_cuda=use_cuda,
        num_workers=0,
    )

    config = mf.ModelFactory.get_model(
        "na",
        args["dataset"],
        output_dimension=1000,
        resize=args["resize"],
    )

    # Re-select the device per worker: ranks are spread across GPUs
    # round-robin. NOTE(review): datasets above may have been prefetched to
    # the default cuda device while the model lands on cuda:<gpu_to_use> —
    # confirm these agree when gpus > 1.
    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(gpu_to_use))
        logger.info("Using gpu : %s", "cuda:" + str(gpu_to_use))
    else:
        device = torch.device("cpu")

    maml = MetaLearingClassification(args, config).to(device)

    for step in range(args["steps"]):
        # Sample `tasks` distinct classes for this meta-step.
        t1 = np.random.choice(args["classes"], args["tasks"], replace=False)

        d_traj_iterators_spt = []
        d_traj_iterators_qry = []
        for t in t1:
            d_traj_iterators_spt.append(sampler_spt.sample_task([t]))
            d_traj_iterators_qry.append(sampler_qry.sample_task([t]))

        d_rand_iterator = sampler_spt.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data_paper(
            d_traj_iterators_spt,
            d_traj_iterators_qry,
            d_rand_iterator,
            steps=args["update_step"],
            reset=not args["no_reset"],
        )
        # Move the batch to the model's device; .to() is a no-op on CPU, so
        # no cuda-availability guard is needed.
        x_spt, y_spt, x_qry, y_qry = (
            x_spt.to(device),
            y_spt.to(device),
            x_qry.to(device),
            y_qry.to(device),
        )

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Periodic logging of first/last inner-step accuracy and loss.
        if step % 40 == 5:
            writer.add_scalar("/metatrain/train/accuracy", accs[-1], step)
            writer.add_scalar("/metatrain/train/loss", loss[-1], step)
            writer.add_scalar("/metatrain/train/accuracy0", accs[0], step)
            writer.add_scalar("/metatrain/train/loss0", loss[0], step)
            logger.info("step: %d \t training acc %s", step, str(accs))
            logger.info("step: %d \t training loss %s", step, str(loss))
        # Periodic checkpoint, plus one at the final step.
        if (step % 300 == 3) or ((step + 1) == args["steps"]):
            torch.save(maml.net, my_experiment.path + "learner.model")
def main(args):
    """Meta-train a classifier with optional Oja/Hebbian plasticity and
    per-episode label permutation.

    Seeds RNGs, builds datasets (Omniglot factory or MiniImagenet),
    optionally restores a saved learner and rebuilds its plasticity and
    feedback parameters, selectively zeroes plasticity, freezes RLN layers,
    then runs the outer meta-training loop with periodic train/test
    evaluation and checkpointing.
    """
    # Seed all RNG sources for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    my_experiment = experiment(args.name,
                               args,
                               "/data5/jlindsey/continual/results",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Meta-training class pool (first 963 Omniglot classes by default).
    args.classes = list(range(963))

    print('dataset', args.dataset, args.dataset == "imagenet")

    if args.dataset != "imagenet":

        dataset = df.DatasetFactory.get_dataset(args.dataset,
                                                background=True,
                                                train=True,
                                                all=True)
        dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                     background=True,
                                                     train=False,
                                                     all=True)

    else:
        # MiniImagenet meta-train split has 64 classes.
        args.classes = list(range(64))
        dataset = imgnet.MiniImagenet(args.imagenet_path, mode='train')
        dataset_test = imgnet.MiniImagenet(args.imagenet_path, mode='test')

    # Iterators used only for the periodic evaluation passes below.
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    logger.info("Train set length = %d", len(iterator_train) * 5)
    logger.info("Test set length = %d", len(iterator_test) * 5)
    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model(
        args.model_type,
        args.dataset,
        width=args.width,
        num_extra_dense_layers=args.num_extra_dense_layers)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Oja/Hebbian plasticity variant vs. the plain meta-learner.
    if args.oja or args.hebb:
        maml = OjaMetaLearingClassification(args, config).to(device)
    else:
        print('starting up')
        maml = MetaLearingClassification(args, config).to(device)

    import sys
    if args.from_saved:
        # Restore a previously trained learner network.
        maml.net = torch.load(args.model)
        if args.use_derivative:
            maml.net.use_derivative = True
        maml.net.optimize_out = args.optimize_out
        if maml.net.optimize_out:
            maml.net.feedback_strength_vars.append(
                torch.nn.Parameter(maml.net.init_feedback_strength *
                                   torch.ones(1).cuda()))

        if args.reset_feedback_strength:
            # Overwrite every feedback-strength parameter with the CLI value.
            for fv in maml.net.feedback_strength_vars:
                w = nn.Parameter(torch.ones_like(fv) * args.feedback_strength)
                fv.data = w

        if args.reset_feedback_vars:
            # Rebuild all plasticity and feedback parameters from scratch,
            # mirroring the layer structure recorded in maml.net.config.

            print('howdy', maml.net.num_feedback_layers)

            maml.net.feedback_vars = nn.ParameterList()
            maml.net.feedback_vars_bundled = []

            maml.net.vars_plasticity = nn.ParameterList()
            maml.net.plasticity = nn.ParameterList()

            maml.net.neuron_plasticity = nn.ParameterList()

            maml.net.layer_plasticity = nn.ParameterList()

            # assumes 84x84 input images — TODO confirm against the dataset
            starting_width = 84
            cur_width = starting_width
            # Output dimension, read from the last layer's config entry.
            num_outputs = maml.net.config[-1][1][0]
            for i, (name, param) in enumerate(maml.net.config):
                print('yo', i, name, param)
                if name == 'conv2d':
                    print('in conv2d')
                    stride = param[4]
                    padding = param[5]

                    #print('cur_width', cur_width, param[3])
                    # Conv output width: (W + 2*pad - kernel + stride) // stride.
                    cur_width = (cur_width + 2 * padding - param[3] +
                                 stride) // stride

                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(*param[:4]).cuda()))
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(param[0]).cuda()))
                    #self.activations_list.append([])
                    maml.net.plasticity.append(
                        nn.Parameter(
                            maml.net.init_plasticity *
                            torch.ones(param[0], param[1] * param[2] *
                                       param[3]).cuda()))  #not implemented
                    maml.net.neuron_plasticity.append(
                        nn.Parameter(torch.zeros(1).cuda()))  #not implemented

                    maml.net.layer_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(1).cuda()))  #not implemented

                    feedback_var = []

                    # Build an MLP of feedback weights from the network output
                    # back to this conv layer's activation volume.
                    for fl in range(maml.net.num_feedback_layers):
                        print('doing fl')
                        in_dim = maml.net.width
                        out_dim = maml.net.width
                        if fl == maml.net.num_feedback_layers - 1:
                            out_dim = param[0] * cur_width * cur_width
                        if fl == 0:
                            in_dim = num_outputs
                        feedback_w_shape = [out_dim, in_dim]
                        feedback_w = nn.Parameter(
                            torch.ones(feedback_w_shape).cuda())
                        feedback_b = nn.Parameter(torch.zeros(out_dim).cuda())
                        torch.nn.init.kaiming_normal_(feedback_w)
                        feedback_var.append((feedback_w, feedback_b))
                        print('adding')
                        maml.net.feedback_vars.append(feedback_w)
                        maml.net.feedback_vars.append(feedback_b)

                    #maml.net.feedback_vars_bundled.append(feedback_var)
                    #maml.net.feedback_vars_bundled.append(None)#bias feedback -- not implemented

                    #'''

                    maml.net.feedback_vars_bundled.append(
                        nn.Parameter(torch.zeros(
                            1)))  #weight feedback -- not implemented
                    maml.net.feedback_vars_bundled.append(
                        nn.Parameter(
                            torch.zeros(1)))  #bias feedback -- not implemented

                elif name == 'linear':
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(*param).cuda()))
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(param[0]).cuda()))
                    #self.activations_list.append([])
                    maml.net.plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(*param).cuda()))
                    maml.net.neuron_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(param[0]).cuda()))
                    maml.net.layer_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(1).cuda()))

                    feedback_var = []

                    # Same feedback MLP construction for linear layers; here
                    # the bundled list keeps the actual (w, b) pairs.
                    for fl in range(maml.net.num_feedback_layers):
                        in_dim = maml.net.width
                        out_dim = maml.net.width
                        if fl == maml.net.num_feedback_layers - 1:
                            out_dim = param[0]
                        if fl == 0:
                            in_dim = num_outputs
                        feedback_w_shape = [out_dim, in_dim]
                        feedback_w = nn.Parameter(
                            torch.ones(feedback_w_shape).cuda())
                        feedback_b = nn.Parameter(torch.zeros(out_dim).cuda())
                        torch.nn.init.kaiming_normal_(feedback_w)
                        feedback_var.append((feedback_w, feedback_b))
                        maml.net.feedback_vars.append(feedback_w)
                        maml.net.feedback_vars.append(feedback_b)
                    maml.net.feedback_vars_bundled.append(feedback_var)
                    maml.net.feedback_vars_bundled.append(
                        None)  #bias feedback -- not implemented

        maml.init_stuff(args)

    maml.net.optimize_out = args.optimize_out
    if maml.net.optimize_out:
        maml.net.feedback_strength_vars.append(
            torch.nn.Parameter(maml.net.init_feedback_strength *
                               torch.ones(1).cuda()))
    #I recently un-indented this until the maml.init_opt() line.  If stuff stops working, try re-indenting this block
    # Zero plasticity everywhere except the output layer(s).
    if args.zero_non_output_plasticity:
        for index in range(len(maml.net.vars_plasticity) - 2):
            maml.net.vars_plasticity[index] = torch.nn.Parameter(
                maml.net.vars_plasticity[index] * 0)
        if args.oja or args.hebb:
            for index in range(len(maml.net.plasticity) - 1):
                if args.plasticity_rank1:
                    maml.net.plasticity[index] = torch.nn.Parameter(
                        torch.zeros(1).cuda())
                else:
                    maml.net.plasticity[index] = torch.nn.Parameter(
                        maml.net.plasticity[index] * 0)
                    maml.net.layer_plasticity[index] = torch.nn.Parameter(
                        maml.net.layer_plasticity[index] * 0)
                    maml.net.neuron_plasticity[index] = torch.nn.Parameter(
                        maml.net.neuron_plasticity[index] * 0)

        if args.oja or args.hebb:
            for index in range(len(maml.net.vars_plasticity) - 2):
                maml.net.vars_plasticity[index] = torch.nn.Parameter(
                    maml.net.vars_plasticity[index] * 0)
    # Zero ALL plasticity parameters (overrides the partial zeroing above).
    if args.zero_all_plasticity:
        print('zeroing plasticity')
        for index in range(len(maml.net.vars_plasticity)):
            maml.net.vars_plasticity[index] = torch.nn.Parameter(
                maml.net.vars_plasticity[index] * 0)
        for index in range(len(maml.net.plasticity)):
            if args.plasticity_rank1:
                maml.net.plasticity[index] = torch.nn.Parameter(
                    torch.zeros(1).cuda())
            else:

                maml.net.plasticity[index] = torch.nn.Parameter(
                    maml.net.plasticity[index] * 0)
                maml.net.layer_plasticity[index] = torch.nn.Parameter(
                    maml.net.layer_plasticity[index] * 0)
                maml.net.neuron_plasticity[index] = torch.nn.Parameter(
                    maml.net.neuron_plasticity[index] * 0)

    print('heyy', maml.net.feedback_vars)
    # Rebuild the meta-optimizer now that parameters may have been replaced.
    maml.init_opt()
    # Mark every parameter learnable, then freeze the RLN layers below.
    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True

    if args.freeze_out_plasticity:
        maml.net.plasticity[-1].requires_grad = False
    # Total feed-forward vars: 2 (w, b) per layer for 6 conv + 2 dense
    # layers plus any extra dense layers — TODO confirm against the model.
    total_ff_vars = 2 * (6 + 2 + args.num_extra_dense_layers)
    frozen_layers = []
    # Freeze the first `rln` layers and the last `rln_end` layers.
    for temp in range(args.rln * 2):
        frozen_layers.append("net.vars." + str(temp))

    for temp in range(args.rln_end * 2):
        frozen_layers.append("net.vars." + str(total_ff_vars - 1 - temp))
    for name, param in maml.named_parameters():
        # logger.info(name)
        if name in frozen_layers:
            logger.info("RLN layer %s", str(name))
            param.learn = False

    # Update the classifier
    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
    list_of_names = list(filter(lambda x: x[1].learn, maml.named_parameters()))

    for a in list_of_names:
        logger.info("TLN layer = %s", a[0])

    for step in range(args.steps):
        '''
        print('plasticity')
        for p in maml.net.plasticity:
            print(p.size(), torch.sum(p), p)
        '''
        # Sample `tasks` distinct classes for this meta-step.
        t1 = np.random.choice(
            args.classes, args.tasks, replace=False
        )  #np.random.randint(1, args.tasks + 1), replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            iid=args.iid)

        # Relabel the episode: map each class (in order of first appearance
        # in the support set) to a randomly permuted task index, applied
        # consistently to support and query labels.
        perm = np.random.permutation(args.tasks)

        old = []
        for i in range(y_spt.size()[0]):
            num = int(y_spt[i].cpu().numpy())
            if num not in old:
                old.append(num)
            y_spt[i] = torch.tensor(perm[old.index(num)])

        for i in range(y_qry.size()[1]):
            num = int(y_qry[0][i].cpu().numpy())
            y_qry[0][i] = torch.tensor(perm[old.index(num)])
        #print('hi', y_qry.size())
        #print('y_spt', y_spt)
        #print('y_qry', y_qry)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        #print('heyyyy', x_spt.size(), y_spt.size(), x_qry.size(), y_qry.size())
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Log every step; evaluate and checkpoint every 300 steps.
        if step % 1 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 0:
            correct = 0
            torch.save(maml.net, my_experiment.path + "learner.model")
            # Test-set accuracy of the current (non-adapted) network.
            for img, target in iterator_test:
                with torch.no_grad():
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct += torch.eq(pred_q, target).sum().item() / len(img)
            writer.add_scalar('/metatrain/test/classifier/accuracy',
                              correct / len(iterator_test), step)
            logger.info("Test Accuracy = %s",
                        str(correct / len(iterator_test)))
            correct = 0
            # Same evaluation on the training iterator.
            for img, target in iterator_train:
                with torch.no_grad():

                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)
                    pred_q = (logits_q).argmax(dim=1)
                    correct += torch.eq(pred_q, target).sum().item() / len(img)

            logger.info("Train Accuracy = %s",
                        str(correct / len(iterator_train)))
            writer.add_scalar('/metatrain/train/classifier/accuracy',
                              correct / len(iterator_train), step)
Code example #5
0
File: maml-rep_omniglot.py Project: zoq/mrcl
def main(args):
    """Meta-train a MAML-style few-shot classifier on Omniglot.

    Trains on the background split, then every 600 steps checkpoints the
    learner and meta-tests by fine-tuning on tasks drawn from 600 held-out
    evaluation classes, averaging accuracy over 40 sampled episodes.
    """
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    args.traj_classes = list(range(963))

    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True,
                                            all=True)
    # Held-out split: the non-background ("evaluation") Omniglot alphabets.
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=False,
                                                 train=True,
                                                 all=True)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset)

    sampler_test = ts.SamplerFactory.get_sampler(args.dataset,
                                                 list(range(600)),
                                                 dataset_test, dataset_test)

    config = mf.ModelFactory.get_model("na", "omniglot-fc")

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    # Freeze the first `rln` layers (the representation network).
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):
        # Sample `tasks` distinct trajectory classes for this meta-step.
        t1 = np.random.choice(args.traj_classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_few_shot_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset)
        # Move the batch with .to(device) instead of bare .cuda(): identical
        # on GPU, a no-op on CPU (so the availability guard is unnecessary),
        # and consistent with the device selected for the model above.
        x_spt, y_spt, x_qry, y_qry = (x_spt.to(device), y_spt.to(device),
                                      x_qry.to(device), y_qry.to(device))

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 20 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
            logger.info("Loss = %s", str(loss[-1].item()))
        if step % 600 == 599:
            torch.save(maml.net, my_experiment.path + "learner.model")
            # Meta-test: average fine-tuning accuracy over 40 sampled episodes.
            accs_avg = None
            for temp_temp in range(0, 40):
                t1_test = np.random.choice(list(range(600)),
                                           args.tasks,
                                           replace=False)

                d_traj_test_iterators = []
                for t in t1_test:
                    d_traj_test_iterators.append(sampler_test.sample_task([t]))

                x_spt, y_spt, x_qry, y_qry = maml.sample_few_shot_training_data(
                    d_traj_test_iterators,
                    None,
                    steps=args.update_step,
                    reset=not args.no_reset)
                x_spt, y_spt, x_qry, y_qry = (x_spt.to(device),
                                              y_spt.to(device),
                                              x_qry.to(device),
                                              y_qry.to(device))

                accs, loss = maml.finetune(x_spt, y_spt, x_qry, y_qry)
                if accs_avg is None:
                    accs_avg = accs
                else:
                    accs_avg += accs
            logger.info("Loss = %s", str(loss[-1].item()))
            writer.add_scalar('/metatest/train/accuracy', accs_avg[-1] / 40,
                              step)
            logger.info('TEST: step: %d \t testing acc %s', step,
                        str(accs_avg / 40))
Code example #6
0
def main(args):
    """Meta-train an OML/MRCL classifier on Omniglot.

    Builds the dataset, sampler, and model from ``args``, optionally
    restores weights from a saved checkpoint, freezes the RLN layers,
    and runs the meta-training loop with periodic logging, model
    checkpointing, and accuracy evaluation.

    Args:
        args: Parsed command-line namespace. Fields read here:
            seed, name, commit, dataset, prefetch_gpu, treatment,
            checkpoint, saved_model, rln, steps, tasks, update_step,
            no_reset, model_name.
    """
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger("experiment")

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    if torch.cuda.is_available():
        device = torch.device("cuda")
        use_cuda = True
    else:
        device = torch.device("cpu")
        use_cuda = False

    dataset = df.DatasetFactory.get_dataset(
        args.dataset,
        background=True,
        train=True,
        all=True,
        prefetch_gpu=args.prefetch_gpu,
        device=device,
    )
    # The training split doubles as the "test" split for the sampler
    # and the periodic sanity-check evaluation below.
    dataset_test = dataset

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(
        args.dataset,
        args.classes,
        dataset,
        dataset_test,
        prefetch_gpu=args.prefetch_gpu,
        use_cuda=use_cuda,
    )

    config = mf.ModelFactory.get_model(args.treatment, args.dataset)

    maml = MetaLearingClassification(args, config, args.treatment).to(device)

    if args.checkpoint:
        # Restore learner weights parameter-by-parameter from a saved model.
        # NOTE(review): this assumes both `len(checkpoint)` and indexed
        # `parameters()` work, i.e. the learner keeps its parameters in an
        # indexable container (e.g. a ParameterList) — confirm against the
        # project's learner implementation.
        checkpoint = torch.load(args.saved_model, map_location="cpu")

        for idx in range(len(checkpoint)):
            maml.net.parameters()[idx].data = checkpoint.parameters()[idx].data

    maml = maml.to(device)

    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):

        # Sample `tasks` distinct classes to form this step's trajectory.
        t1 = np.random.choice(args.classes, args.tasks, replace=False)

        d_traj_iterators = [sampler.sample_task([t]) for t in t1]

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators,
            d_rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset,
        )
        # `.to(device)` is a no-op on CPU, so no cuda-availability branch
        # is needed (the original re-tested torch.cuda.is_available() here).
        x_spt, y_spt = x_spt.to(device), y_spt.to(device)
        x_qry, y_qry = x_qry.to(device), y_qry.to(device)

        accs, _ = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 0:
            logger.info("step: %d \t training acc %s", step, str(accs))
        if step % 100 == 0 or step == 19999:
            torch.save(maml.net, args.model_name)
        if step % 2000 == 0 and step != 0:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)
コード例 #7
0
def _classifier_accuracy(net, iterator, device):
    """Return the mean per-batch classification accuracy of ``net``
    over every batch yielded by ``iterator`` (no gradients computed).
    """
    correct = 0
    for img, target in iterator:
        with torch.no_grad():
            img = img.to(device)
            target = target.to(device)
            logits = net(img, vars=None, bn_training=False, feature=False)
            # argmax over raw logits; the softmax the original test loop
            # applied first is monotonic and never changes the prediction.
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, target).sum().item() / len(img)
    return correct / len(iterator)


def main(args):
    """Meta-train a classifier on Omniglot with frozen RLN layers.

    Seeds all RNGs, builds train/test loaders and the task sampler,
    marks the first ``args.rln`` representation layers as frozen, then
    runs the meta-training loop, periodically logging train/test
    classifier accuracy and saving the learner.

    Args:
        args: Parsed command-line namespace. Fields read here:
            seed, name, commit, dataset, rln, epoch, tasks, update_step.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    my_experiment = experiment(args.name,
                               args,
                               "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')

    # First 963 Omniglot background classes form the meta-training set.
    args.classes = list(range(963))

    dataset = df.DatasetFactory.get_dataset(args.dataset,
                                            background=True,
                                            train=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=True,
                                                 train=False)

    iterator_test = torch.utils.data.DataLoader(dataset_test,
                                                batch_size=5,
                                                shuffle=True,
                                                num_workers=1)

    iterator_train = torch.utils.data.DataLoader(dataset,
                                                 batch_size=5,
                                                 shuffle=True,
                                                 num_workers=1)

    logger.info("Train set length = %d", len(iterator_train) * 5)
    logger.info("Test set length = %d", len(iterator_test) * 5)
    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model("na", args.dataset)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    maml = MetaLearingClassification(args, config).to(device)

    # Every parameter starts learnable; the first `rln` layers (two
    # tensors each: weight and bias) are then frozen by name.
    for _, param in maml.named_parameters():
        param.learn = True
    for _, param in maml.net.named_parameters():
        param.learn = True

    frozen_layers = ["net.vars." + str(i) for i in range(args.rln * 2)]

    for name, param in maml.named_parameters():
        logger.info(name)
        if name in frozen_layers:
            logger.info("Freezing name %s", str(name))
            param.learn = False

    # Log which layers remain trainable for representation learning.
    for layer_name, _ in filter(lambda nv: nv[1].learn,
                                maml.named_parameters()):
        logger.info("Unfrozen layers for rep learning = %s", layer_name)

    for step in range(args.epoch):

        # Sample a trajectory over a random number (1..tasks) of classes.
        t1 = np.random.choice(args.classes,
                              np.random.randint(1, args.tasks + 1),
                              replace=False)

        d_traj_iterators = [sampler.sample_task([t]) for t in t1]

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators, d_rand_iterator, steps=args.update_step)
        # `.to(device)` is a no-op on CPU, so no cuda-availability branch
        # is needed (the original re-tested torch.cuda.is_available() here).
        x_spt, y_spt = x_spt.to(device), y_spt.to(device)
        x_qry, y_qry = x_qry.to(device), y_qry.to(device)

        accs, _ = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 40 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 0:
            torch.save(maml.net, my_experiment.path + "learner.model")
            test_acc = _classifier_accuracy(maml.net, iterator_test, device)
            writer.add_scalar('/metatrain/test/classifier/accuracy',
                              test_acc, step)
            logger.info("Test Accuracy = %s", str(test_acc))
            train_acc = _classifier_accuracy(maml.net, iterator_train, device)
            logger.info("Train Accuracy = %s", str(train_acc))
            writer.add_scalar('/metatrain/train/classifier/accuracy',
                              train_acc, step)