Example #1
def test(args):
    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    exp_dir = utils.get_path_from_args(
        args) if not args.output_dir else args.output_dir
    path = "{}/results/{}".format(code_root, exp_dir)
    assert os.path.isdir(path)
    task_family_test = tasks_sine.RegressionTasksSinusoidal(
        "test", args.skew_task_distribution)
    best_valid_model = utils.load_obj(os.path.join(path,
                                                   "logs")).best_valid_model
    k_shots = [5, 10, 20, 40]
    df = []
    for k_shot in k_shots:
        losses = np.array(
            eval(
                args,
                copy.copy(best_valid_model),
                task_family=task_family_test,
                num_updates=10,
                lr_inner=0.01,
                n_tasks=1000,
                k_shot=k_shot,
            ))
        for grad_step, task_losses in enumerate(losses.T, 1):
            new_rows = [[k_shot, grad_step, tl] for tl in task_losses]
            df.extend(new_rows)

    df = pd.DataFrame(df, columns=["k_shot", "grad_steps", "loss"])
    df.to_pickle(os.path.join(path, "res.pkl"))
    utils.plot_df(df, path)
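The loop above flattens a (n_tasks, num_gradient_steps) loss array into a long-format DataFrame with one row per (k_shot, grad_step, task). A minimal self-contained sketch of that reshaping on synthetic data (all values below are illustrative, not from the repository):

import numpy as np
import pandas as pd

# synthetic losses: 4 tasks, each evaluated after 3 gradient steps
losses = np.random.rand(4, 3)

rows = []
for grad_step, task_losses in enumerate(losses.T, 1):  # column g = losses after g steps
    rows.extend([[5, grad_step, tl] for tl in task_losses])  # k_shot fixed to 5 here

df = pd.DataFrame(rows, columns=["k_shot", "grad_steps", "loss"])
print(df.groupby("grad_steps")["loss"].mean())  # mean loss per gradient step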
Example #2
    args.context_in_type = ['mix', 'mix', 'mix', 'mix']
    args.k_shot = 5
    args.lr_inner = 0.1
    args.num_grad_steps_inner = 2
    args.num_grad_steps_eval = 2
    args.num_context_params = 100

    if args.k_shot == 1:
        args.tasks_per_metaupdate = 2
        in_channel = 5
    else:
        args.tasks_per_metaupdate = 2
        in_channel = 8

    path = os.path.join(utils.get_base_path(), 'result_files',
                        utils.get_path_from_args(args))

    try:
        training_stats, validation_stats = np.load(path + '.npy',
                                                   allow_pickle=True)
        print('load path:[{}]'.format(path))
    except FileNotFoundError:
        print(
            'You need to run the experiments first and make sure the results are saved at {}'
            .format(path))
        raise

    Logger.print_header()

    for num_grad_steps in [2]:
Example #3
def run(args, log_interval=5000, rerun=False):

    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    exp_dir = utils.get_path_from_args(
        args) if not args.output_dir else args.output_dir
    path = "{}/results/{}".format(code_root, exp_dir)
    if not os.path.isdir(path):
        os.makedirs(path)

    if os.path.exists(os.path.join(path, "logs.pkl")) and not rerun:
        return utils.load_obj(os.path.join(path, "logs"))

    start_time = time.time()

    # correctly seed everything
    utils.set_seed(args.seed)

    # --- initialise everything ---
    task_family_train = tasks_sine.RegressionTasksSinusoidal(
        "train", args.skew_task_distribution)
    task_family_valid = tasks_sine.RegressionTasksSinusoidal(
        "valid", args.skew_task_distribution)

    # initialise network
    model_inner = MamlModel(
        task_family_train.num_inputs,
        task_family_train.num_outputs,
        n_weights=args.num_hidden_layers,
        device=args.device,
    ).to(args.device)
    model_outer = copy.deepcopy(model_inner)
    if args.detector == "minimax":
        task_sampler = TaskSampler(
            task_family_train.atoms //
            (2 if args.skew_task_distribution else 1)).to(args.device)
    elif args.detector == "neyman-pearson":
        constrainer = Constrainer(
            task_family_train.atoms //
            (2 if args.skew_task_distribution else 1)).to(args.device)

    # initialise meta-optimiser
    meta_optimiser = optim.Adam(model_outer.weights + model_outer.biases,
                                args.lr_meta)

    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model_outer)

    for i_iter in range(args.n_iter):

        # copy weights of network
        copy_weights = [w.clone() for w in model_outer.weights]
        copy_biases = [b.clone() for b in model_outer.biases]

        # get all shared parameters and initialise cumulative gradient
        meta_gradient = [
            0 for _ in range(
                len(copy_weights + copy_biases) +
                (2 if args.detector != "bayes" else 0))
        ]

        # sample tasks
        if args.detector == "minimax":
            task_idxs, task_probs = task_sampler(args.tasks_per_metaupdate)
        elif args.detector == "neyman-pearson":
            amplitude_idxs = torch.randint(
                task_family_train.atoms //
                (2 if args.skew_task_distribution else 1),
                (args.tasks_per_metaupdate, ),
            )
            phase_idxs = torch.randint(
                task_family_train.atoms //
                (2 if args.skew_task_distribution else 1),
                (args.tasks_per_metaupdate, ),
            )
            task_idxs = amplitude_idxs, phase_idxs
        else:
            task_idxs = None

        target_functions = task_family_train.sample_tasks(
            args.tasks_per_metaupdate, task_idxs=task_idxs)

        for t in range(args.tasks_per_metaupdate):

            # reset network weights
            model_inner.weights = [w.clone() for w in copy_weights]
            model_inner.biases = [b.clone() for b in copy_biases]

            # get data for current task
            train_inputs = task_family_train.sample_inputs(
                args.k_meta_train).to(args.device)

            for _ in range(args.num_inner_updates):

                # make prediction using the current model
                outputs = model_inner(train_inputs)

                # get targets
                targets = target_functions[t](train_inputs)

                # ------------ update on current task ------------

                # compute loss for current task
                loss_task = F.mse_loss(outputs, targets)

                # compute the gradient wrt current model
                params = model_inner.weights + model_inner.biases
                grads = torch.autograd.grad(loss_task,
                                            params,
                                            create_graph=True,
                                            retain_graph=True)

                # make an update on the inner model using the current model (to build up computation graph)
                for i in range(len(model_inner.weights)):
                    if not args.first_order:
                        model_inner.weights[i] = (model_inner.weights[i] -
                                                  args.lr_inner * grads[i])
                    else:
                        model_inner.weights[i] = (
                            model_inner.weights[i] -
                            args.lr_inner * grads[i].detach())
                for j in range(len(model_inner.biases)):
                    if not args.first_order:
                        model_inner.biases[j] = (
                            model_inner.biases[j] -
                            args.lr_inner * grads[i + j + 1])
                    else:
                        model_inner.biases[j] = (
                            model_inner.biases[j] -
                            args.lr_inner * grads[i + j + 1].detach())

            # ------------ compute meta-gradient on test loss of current task ------------

            # get test data
            test_inputs = task_family_train.sample_inputs(args.k_meta_test).to(
                args.device)

            # get outputs after update
            test_outputs = model_inner(test_inputs)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)

            # compute loss (will backprop through inner loop)
            if args.detector == "minimax":
                importance = task_probs[t]
            else:
                importance = 1.0 / args.tasks_per_metaupdate
            loss_meta_raw = F.mse_loss(test_outputs, test_targets)
            loss_meta = loss_meta_raw * importance
            if args.detector == "neyman-pearson":
                amplitude_idxs, phase_idxs = task_idxs
                aux_loss = constrainer(amplitude_idxs[t], phase_idxs[t],
                                       loss_meta_raw)
                loss_meta = loss_meta + aux_loss

            # compute gradient w.r.t. *outer model*
            outer_params = model_outer.weights + model_outer.biases
            if args.detector == "minimax":
                outer_params += [
                    task_sampler.tau_amplitude, task_sampler.tau_phase
                ]
            elif args.detector == "neyman-pearson":
                outer_params += [
                    constrainer.tau_amplitude, constrainer.tau_phase
                ]

            task_grads = torch.autograd.grad(
                loss_meta,
                outer_params,
                retain_graph=(args.detector != "bayes"))
            for i in range(len(outer_params)):
                meta_gradient[i] += task_grads[i].detach()

        # ------------ meta update ------------

        meta_optimiser.zero_grad()
        # print(meta_gradient)

        # assign meta-gradient
        for i in range(len(model_outer.weights)):
            model_outer.weights[i].grad = meta_gradient[i]
            meta_gradient[i] = 0
        for j in range(len(model_outer.biases)):
            model_outer.biases[j].grad = meta_gradient[i + j + 1]
            meta_gradient[i + j + 1] = 0
        if args.detector == "minimax":
            task_sampler.tau_amplitude.grad = -meta_gradient[i + j + 2]
            task_sampler.tau_phase.grad = -meta_gradient[i + j + 3]
            meta_gradient[i + j + 2] = 0
            meta_gradient[i + j + 3] = 0
        elif args.detector == "neyman-pearson":
            constrainer.tau_amplitude.grad = -meta_gradient[i + j + 2]
            constrainer.tau_phase.grad = -meta_gradient[i + j + 3]
            meta_gradient[i + j + 2] = 0
            meta_gradient[i + j + 3] = 0

        # do update step on outer model
        meta_optimiser.step()

        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            losses = eval(
                args,
                copy.copy(model_outer),
                task_family=task_family_train,
                num_updates=args.num_inner_updates,
            )
            loss_mean, loss_conf = utils.get_stats(np.array(losses))
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on valid set
            losses = eval(
                args,
                copy.copy(model_outer),
                task_family=task_family_valid,
                num_updates=args.num_inner_updates,
            )
            loss_mean, loss_conf = utils.get_stats(np.array(losses))
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # save best model
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print("saving best model at iter", i_iter)
                logger.best_valid_model = copy.copy(model_outer)

            # save logging results
            utils.save_obj(logger, os.path.join(path, "logs"))

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return logger
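The inner loop above performs manual SGD on cloned weights with create_graph=True so the meta-gradient can backpropagate through the inner updates. A minimal self-contained sketch of that pattern on a toy one-parameter model (everything below is illustrative, not the repository's code):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
w = torch.randn(1, requires_grad=True)           # meta-parameter (outer model)
x, y = torch.randn(8, 1), torch.randn(8, 1)      # one task's support data
x_q, y_q = torch.randn(8, 1), torch.randn(8, 1)  # the task's query data
lr_inner = 0.01

# inner updates: keep the graph so the outer gradient sees d(fast_w)/d(w)
fast_w = w.clone()
for _ in range(2):
    loss = F.mse_loss(x * fast_w, y)
    (grad,) = torch.autograd.grad(loss, fast_w, create_graph=True)
    fast_w = fast_w - lr_inner * grad

# outer (meta) gradient: query loss of the adapted parameters w.r.t. the original w
meta_loss = F.mse_loss(x_q * fast_w, y_q)
(meta_grad,) = torch.autograd.grad(meta_loss, w)
w.grad = meta_grad  # an optimiser such as Adam would now take a step on w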
Example #4
def run(args, num_workers=1, log_interval=100, verbose=True, save_path=None):
    """
    Run the model
    :param args: set of arguments
    :param num_workers: number of workers
    :param log_interval: logging interval
    """
    # Current root path
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root, args.task)):
        os.mkdir('{}/{}_result_files/'.format(code_root, args.task))

    # Result path
    path = '{}/{}_result_files/'.format(
        code_root, args.task) + utils.get_path_from_args(args)
    print('File saved in {}'.format(path))

    if os.path.exists(path + '.pkl') and not args.rerun:
        print('Results file already exists. Use --rerun to recompute.')
        return utils.load_obj(path)

    start_time = time.time()
    utils.set_seed(args.seed)

    # ---------------------------------------------------------
    # -------------------- Training ---------------------------

    # Initialize the model
    model = user_preference_estimator(args)
    # Train the model
    model.train()
    print('Number of model parameters:', sum(param.nelement() for param in model.parameters()))

    # Set up the meta-optimiser for model parameters
    meta_optimiser = torch.optim.Adam(model.parameters(), args.lr_meta)
    # Set up a scheduler for the meta-optimizer
    scheduler = torch.optim.lr_scheduler.StepLR(meta_optimiser, 5000,
                                                args.lr_meta_decay)

    # Initialize logger
    logger = Logger()
    logger.args = args

    # Initialize the starting point for the meta gradient (it's faster to copy this than to create new object)
    meta_grad_init = [0 for _ in range(len(model.state_dict()))]

    # Create a PyTorch data loader
    dataloader_train = DataLoader(DataPrep(args),
                                  batch_size=1,
                                  num_workers=args.num_workers)

    # Loop through all epochs
    for epoch in range(args.num_epoch):

        x_spt, y_spt, x_qry, y_qry = [], [], [], []
        iter_counter = 0
        for step, batch in enumerate(dataloader_train):
            if len(x_spt) < args.tasks_per_metaupdate:
                x_spt.append(batch[0][0])
                y_spt.append(batch[1][0])
                x_qry.append(batch[2][0])
                y_qry.append(batch[3][0])
                if len(x_spt) != args.tasks_per_metaupdate:
                    continue

            if len(x_spt) != args.tasks_per_metaupdate:
                continue

            # Initialize the meta-gradient
            meta_grad = copy.deepcopy(meta_grad_init)
            loss_pre = []
            loss_after = []
            for i in range(args.tasks_per_metaupdate):
                loss_pre.append(F.mse_loss(model(x_qry[i]), y_qry[i]).item())
                fast_parameters = list(model.final_part.parameters())
                for weight in model.final_part.parameters():
                    weight.fast = None
                for k in range(args.num_grad_steps_inner):
                    logits = model(x_spt[i])
                    loss = F.mse_loss(logits, y_spt[i])
                    grad = torch.autograd.grad(loss,
                                               fast_parameters,
                                               create_graph=True)
                    fast_parameters = []
                    # use a separate index so the inner-step counter k is not shadowed
                    for w_idx, weight in enumerate(model.final_part.parameters()):
                        if weight.fast is None:
                            weight.fast = weight - args.lr_inner * grad[w_idx]  # create weight.fast
                        else:
                            weight.fast = weight.fast - args.lr_inner * grad[w_idx]
                        fast_parameters.append(weight.fast)

                logits_q = model(x_qry[i])
                # loss_q will be overwritten and just keep the loss_q on last update step.
                loss_q = F.mse_loss(logits_q, y_qry[i])
                loss_after.append(loss_q.item())
                task_grad_test = torch.autograd.grad(loss_q,
                                                     model.parameters())

                for g in range(len(task_grad_test)):
                    meta_grad[g] += task_grad_test[g].detach()

            # -------------- Meta Update --------------
            # Zero the gradients of meta-optimiser
            meta_optimiser.zero_grad()

            # Set gradients of model parameters manually
            for c, param in enumerate(model.parameters()):
                param.grad = meta_grad[c] / float(args.tasks_per_metaupdate)
                param.grad.data.clamp_(-10, 10)

            # The meta-optimiser only operates on the shared parameters, not the context parameters
            meta_optimiser.step()
            scheduler.step()
            x_spt, y_spt, x_qry, y_qry = [], [], [], []

            loss_pre = np.array(loss_pre)
            loss_after = np.array(loss_after)
            logger.train_loss.append(np.mean(loss_pre))
            logger.valid_loss.append(np.mean(loss_after))
            logger.train_conf.append(1.96 * np.std(loss_pre, ddof=0) /
                                     np.sqrt(len(loss_pre)))
            logger.valid_conf.append(1.96 * np.std(loss_after, ddof=0) /
                                     np.sqrt(len(loss_after)))
            logger.test_loss.append(0)
            logger.test_conf.append(0)

            # Save the logger object
            utils.save_obj(logger, path)
            # Print current results
            logger.print_info(epoch, iter_counter, start_time)
            # Initialize start time again
            start_time = time.time()
            # Increment iteration counter
            iter_counter += 1
        if epoch % 2 == 0:
            print('saving model at iter', epoch)
            logger.valid_model.append(copy.deepcopy(model))

    return logger, model
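The inner loop above never rebuilds the model; it stores the adapted parameters on each parameter's .fast attribute and relies on the model's forward pass to use them when present. The user_preference_estimator class is not shown here, but the convention matches the common "fast weight" layer pattern; a minimal sketch of such a layer, under that assumption (LinearFW is an illustrative name, not the repository's class):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearFW(nn.Linear):
    """Linear layer that uses temporarily adapted ('fast') weights when they are set."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.weight.fast = None  # populated during the inner loop, reset between tasks
        self.bias.fast = None

    def forward(self, x):
        if self.weight.fast is not None and self.bias.fast is not None:
            return F.linear(x, self.weight.fast, self.bias.fast)
        return super().forward(x)

layer = LinearFW(4, 2)
x = torch.randn(3, 4)
out_slow = layer(x)  # uses the registered parameters
layer.weight.fast = layer.weight - 0.1 * torch.ones_like(layer.weight)
layer.bias.fast = layer.bias.clone()
out_fast = layer(x)  # uses the adapted weights instead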
Example #5
def run(args, log_interval=5000, rerun=False):
    global temp
    assert not args.maml
    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root, args.task)):
        os.mkdir('{}/{}_result_files/'.format(code_root, args.task))
    path = '{}/{}_result_files/'.format(
        code_root, args.task) + utils.get_path_from_args(args)

    if os.path.exists(path + '.pkl') and not rerun:
        return utils.load_obj(path)

    start_time = time.time()
    utils.set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    if args.task == 'sine':
        task_family_train = tasks_sine.RegressionTasksSinusoidal()
        task_family_valid = tasks_sine.RegressionTasksSinusoidal()
        task_family_test = tasks_sine.RegressionTasksSinusoidal()
    elif args.task == 'celeba':
        task_family_train = tasks_celebA.CelebADataset('train',
                                                       device=args.device)
        task_family_valid = tasks_celebA.CelebADataset('valid',
                                                       device=args.device)
        task_family_test = tasks_celebA.CelebADataset('test',
                                                      device=args.device)
    elif args.task == 'multi':
        task_family_train = multi()
        task_family_valid = multi()
        task_family_test = multi()
    else:
        raise NotImplementedError

    # initialise network
    model = CaviaModel(n_in=task_family_train.num_inputs,
                       n_out=task_family_train.num_outputs,
                       num_context_params=args.num_context_params,
                       n_hidden=args.num_hidden_layers,
                       device=args.device).to(args.device)
    # initialise meta-optimiser
    # (only on shared params - context parameters are *not* registered parameters of the model)
    meta_optimiser = optim.Adam(model.parameters(), args.lr_meta)
    encoder = pool_encoder().to(args.device)
    encoder_optimiser = optim.Adam(encoder.parameters(), lr=1e-3)
    decoder = pool_decoder().to(args.device)
    decoder_optimiser = optim.Adam(decoder.parameters(), lr=1e-3)
    #encoder.load_state_dict(torch.load('./model/encoder'))
    p_encoder = place().to(args.device)
    p_optimiser = optim.Adam(p_encoder.parameters(), lr=1e-3)
    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model)

    # --- main training loop ---

    for i_iter in range(args.n_iter):
        # initialise meta-gradient
        meta_gradient = [0 for _ in range(len(model.state_dict()))]
        place_gradient = [0 for _ in range(len(p_encoder.state_dict()))]
        encoder_gradient = [0 for _ in range(len(encoder.state_dict()))]
        #print(meta_gradient)

        # sample tasks
        target_functions, ty = task_family_train.sample_tasks(
            args.tasks_per_metaupdate, True)

        # --- inner loop ---

        for t in range(args.tasks_per_metaupdate):

            # reset private network weights
            model.reset_context_params()

            # get data for current task
            x = task_family_train.sample_inputs(
                args.k_meta_train, args.use_ordered_pixels).to(args.device)

            y = target_functions[t](x)
            train_inputs = torch.cat([x, y], dim=1)
            a = encoder(train_inputs)
            #embedding,_ = torch.max(a,dim=0)
            embedding = torch.mean(a, dim=0)

            logits = p_encoder(embedding)
            logits = logits.reshape([latent_dim, categorical_dim])

            y = gumbel_softmax(logits, temp, hard=True)
            y = y[:, 1]
            #print(temp)

            #model.set_context_params(embedding)
            #print(model.context_params)

            for _ in range(args.num_inner_updates):
                # forward through model
                train_outputs = model(x)

                # get targets
                train_targets = target_functions[t](x)

                # ------------ update on current task ------------

                # compute loss for current task
                task_loss = F.mse_loss(train_outputs, train_targets)

                # compute gradient wrt context params
                task_gradients = \
                    torch.autograd.grad(task_loss, model.context_params, create_graph=not args.first_order)[0]

                # update context params (this will set up the computation graph correctly)
                model.context_params = model.context_params - args.lr_inner * task_gradients * y

            #print(model.context_params)
            # ------------ compute meta-gradient on test loss of current task ------------

            # get test data
            test_inputs = task_family_train.sample_inputs(
                args.k_meta_test, args.use_ordered_pixels).to(args.device)

            # get outputs after update
            test_outputs = model(test_inputs)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)

            # compute loss after updating context (will backprop through inner loop)
            loss_meta = F.mse_loss(test_outputs, test_targets)
            #print(torch.norm(y,1)/1000)
            #loss_meta += torch.norm(y,1)/700
            qy = F.softmax(logits, dim=-1)
            log_ratio = torch.log(qy * categorical_dim + 1e-20)
            KLD = torch.sum(qy * log_ratio, dim=-1).mean() / 5
            # print(KLD)
            loss_meta += KLD

            # compute gradient + save for current task
            task_grad = torch.autograd.grad(loss_meta,
                                            model.parameters(),
                                            retain_graph=True)

            for i in range(len(task_grad)):
                # clip the gradient
                meta_gradient[i] += task_grad[i].detach().clamp_(-10, 10)

            task_grad_place = torch.autograd.grad(loss_meta,
                                                  p_encoder.parameters(),
                                                  retain_graph=True)

            for i in range(len(task_grad_place)):
                # clip the gradient
                place_gradient[i] += task_grad_place[i].detach().clamp_(
                    -10, 10)

            task_grad_encoder = torch.autograd.grad(loss_meta,
                                                    encoder.parameters())
            for i in range(len(task_grad_encoder)):
                # clip the gradient
                encoder_gradient[i] += task_grad_encoder[i].detach().clamp_(
                    -10, 10)

        # ------------ meta update ------------

        # assign meta-gradient
        for i, param in enumerate(model.parameters()):
            param.grad = meta_gradient[i] / args.tasks_per_metaupdate
        meta_optimiser.step()

        # do update step on shared model
        for i, param in enumerate(p_encoder.parameters()):
            param.grad = place_gradient[i] / args.tasks_per_metaupdate
        p_optimiser.step()

        for i, param in enumerate(encoder.parameters()):
            param.grad = encoder_gradient[i] / args.tasks_per_metaupdate
        encoder_optimiser.step()

        # reset context params
        model.reset_context_params()

        if i_iter % 350 == 1:
            temp = np.maximum(temp * np.exp(-ANNEAL_RATE * i_iter), 0.5)
            print(temp)
        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                task_family=task_family_train,
                num_updates=args.num_inner_updates,
                encoder=encoder,
                p_encoder=p_encoder)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                task_family=task_family_valid,
                num_updates=args.num_inner_updates,
                encoder=encoder,
                p_encoder=p_encoder)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                task_family=task_family_test,
                num_updates=args.num_inner_updates,
                encoder=encoder,
                p_encoder=p_encoder)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            utils.save_obj(logger, path)

            # save best model
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.deepcopy(model)
                logger.best_encoder_valid_model = copy.deepcopy(encoder)
                logger.best_place_valid_model = copy.deepcopy(p_encoder)

            if i_iter % (4 * log_interval) == 0:
                print('saving model at iter', i_iter)
                logger.valid_model.append(copy.deepcopy(model))
                logger.encoder_valid_model.append(copy.deepcopy(encoder))
                logger.place_valid_model.append(copy.deepcopy(p_encoder))

            # visualise results
            if args.task == 'celeba':
                task_family_train.visualise(
                    task_family_train, task_family_test,
                    copy.deepcopy(logger.best_valid_model), args, i_iter)

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return logger
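This example relies on gumbel_softmax, temp, latent_dim, categorical_dim and ANNEAL_RATE defined elsewhere in the module. For reference, a self-contained sketch of a standard straight-through Gumbel-Softmax, which may differ in details from the module's own implementation:

import torch
import torch.nn.functional as F

def sample_gumbel(shape, eps=1e-20):
    """Sample from Gumbel(0, 1)."""
    u = torch.rand(shape)
    return -torch.log(-torch.log(u + eps) + eps)

def gumbel_softmax(logits, tau, hard=False):
    """Relaxed one-hot sample; with hard=True the forward pass is discrete,
    while gradients flow through the soft sample (straight-through estimator)."""
    y_soft = F.softmax((logits + sample_gumbel(logits.shape)) / tau, dim=-1)
    if not hard:
        return y_soft
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    return (y_hard - y_soft).detach() + y_soft

logits = torch.randn(8, 2, requires_grad=True)  # e.g. latent_dim=8, categorical_dim=2
mask = gumbel_softmax(logits, tau=1.0, hard=True)[:, 1]  # one 0/1 gate per latent dimension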
Example #6
def run(args, log_interval=5000, rerun=False):
    assert args.maml

    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root, args.task)):
        os.mkdir('{}/{}_result_files/'.format(code_root, args.task))
    path = '{}/{}_result_files/'.format(code_root, args.task) + utils.get_path_from_args(args)

    if os.path.exists(path + '.pkl') and not rerun:
        return utils.load_obj(path)

    start_time = time.time()

    # correctly seed everything
    utils.set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    if args.task == 'sine':
        task_family_train = tasks_sine.RegressionTasksSinusoidal()
        task_family_valid = tasks_sine.RegressionTasksSinusoidal()
        task_family_test = tasks_sine.RegressionTasksSinusoidal()
    elif args.task == 'celeba':
        task_family_train = tasks_celebA.CelebADataset('train', args.device)
        task_family_valid = tasks_celebA.CelebADataset('valid', args.device)
        task_family_test = tasks_celebA.CelebADataset('test', args.device)
    else:
        raise NotImplementedError

    #initialize transformer
    transformer = FCNet(task_family_train.num_inputs, 3, 128, 128).to(args.device)

    # initialise network
    model_inner = MamlModel(128,
                            task_family_train.num_outputs,
                            n_weights=args.num_hidden_layers,
                            num_context_params=args.num_context_params,
                            device=args.device
                            ).to(args.device)
    model_outer = copy.deepcopy(model_inner)
    
    print("MAML: ", model_outer)
    print("Transformer: ", transformer)
    # initialise meta-optimiser
    meta_optimiser = optim.Adam(model_outer.weights + model_outer.biases + [model_outer.task_context],
                                args.lr_meta)
    opt_transformer = torch.optim.Adam(transformer.parameters(), 0.01)

    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model_outer)

    for i_iter in range(args.n_iter):
        #meta_train_error = 0.0
        # copy weights of network
        copy_weights = [w.clone() for w in model_outer.weights]
        copy_biases = [b.clone() for b in model_outer.biases]
        copy_context = model_outer.task_context.clone()

        # get all shared parameters and initialise cumulative gradient
        meta_gradient = [0 for _ in range(len(copy_weights + copy_biases) + 1)]

        # sample tasks
        target_functions = task_family_train.sample_tasks(args.tasks_per_metaupdate)

        for t in range(args.tasks_per_metaupdate):
            
            #gradient initialization for transformer
            acc_grads = fsn.phi_gradients(transformer, args.device)

            # reset network weights
            model_inner.weights = [w.clone() for w in copy_weights]
            model_inner.biases = [b.clone() for b in copy_biases]
            model_inner.task_context = copy_context.clone()

            # get data for current task
            train_inputs = task_family_train.sample_inputs(args.k_meta_train, args.use_ordered_pixels).to(args.device)

            # get test data
            test_inputs = task_family_train.sample_inputs(args.k_meta_test, args.use_ordered_pixels).to(args.device)

            transformed_train_inputs = transformer(train_inputs)#.to(args.device)
            transformed_test_inputs = transformer(test_inputs)#.to(args.device)

            # transformer task loss
            # with torch.no_grad():
            targets0 = target_functions[t](train_inputs)
            L0 = F.mse_loss(model_inner(transformed_train_inputs), targets0)
            targets1 = target_functions[t](test_inputs)
            L1 = F.mse_loss(model_inner(transformed_test_inputs), targets1)
            trans_loss = fsn.cosine_loss(L0, L1, model_inner, args.device)
            # trans_loss = evaluation_error + trans_loss

            for step in range(args.num_inner_updates):
               # print("iteration:" , i_iter, "innerstep: ", step)
                outputs = model_inner(transformed_train_inputs)

                # make prediction using the current model
                #outputs = model_inner(train_inputs)

                # get targets
                targets = target_functions[t](train_inputs)

                # ------------ update on current task ------------

                # compute loss for current task
                loss_task = F.mse_loss(outputs, targets)

                # compute the gradient wrt current model
                params = [w for w in model_inner.weights] + [b for b in model_inner.biases] + [model_inner.task_context]
                grads = torch.autograd.grad(loss_task, params, create_graph=True, retain_graph=True)

                # make an update on the inner model using the current model (to build up computation graph)
                for i in range(len(model_inner.weights)):
                    if not args.first_order:
                        model_inner.weights[i] = model_inner.weights[i] - args.lr_inner * grads[i].clamp_(-10, 10)
                    else:
                        model_inner.weights[i] = model_inner.weights[i] - args.lr_inner * grads[i].detach().clamp_(-10, 10)
                for j in range(len(model_inner.biases)):
                    if not args.first_order:
                        model_inner.biases[j] = model_inner.biases[j] - args.lr_inner * grads[i + j + 1].clamp_(-10, 10)
                    else:
                        model_inner.biases[j] = model_inner.biases[j] - args.lr_inner * grads[i + j + 1].detach().clamp_(-10, 10)
                if not args.first_order:
                    model_inner.task_context = model_inner.task_context - args.lr_inner * grads[i + j + 2].clamp_(-10, 10)
                else:
                    model_inner.task_context = model_inner.task_context - args.lr_inner * grads[i + j + 2].detach().clamp_(-10, 10)

            # ------------ compute meta-gradient on test loss of current task ------------

            # get outputs after update
            test_outputs = model_inner(transformed_test_inputs)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)

            # compute loss (will backprop through inner loop)
            loss_meta = F.mse_loss(test_outputs, test_targets)


            #meta_train_error += loss_meta.item()

            # transformer gradients
            trans_loss = loss_meta
            grads_phi = list(torch.autograd.grad(trans_loss, transformer.parameters(), retain_graph=True, create_graph=True))

            for p, l in zip(acc_grads, grads_phi):
                # accumulate the clipped transformer gradient, averaged over tasks
                p.data = p.data + l.detach().clamp_(-10, 10) / args.tasks_per_metaupdate


            # compute gradient w.r.t. *outer model*
            task_grads = torch.autograd.grad(loss_meta,
                                             model_outer.weights + model_outer.biases + [model_outer.task_context])
            for i in range(len(model_inner.weights + model_inner.biases) + 1):
                meta_gradient[i] += task_grads[i].detach().clamp_(-10, 10)

        # ------------ meta update ------------

        opt_transformer.zero_grad()
        meta_optimiser.zero_grad()

        # parameter gradient attributes of transformer updated
        for k, p in zip(transformer.parameters(), acc_grads):
            k.grad = p
        # print(meta_gradient)

        # assign meta-gradient
        for i in range(len(model_outer.weights)):
            model_outer.weights[i].grad = meta_gradient[i] / args.tasks_per_metaupdate
            meta_gradient[i] = 0
        for j in range(len(model_outer.biases)):
            model_outer.biases[j].grad = meta_gradient[i + j + 1] / args.tasks_per_metaupdate
            meta_gradient[i + j + 1] = 0
        model_outer.task_context.grad = meta_gradient[i + j + 2] / args.tasks_per_metaupdate
        meta_gradient[i + j + 2] = 0

        # do update step on outer model
        meta_optimiser.step()
        opt_transformer.step()
        # ------------ logging ------------

        if i_iter % log_interval == 0:  # and i_iter > 0:
            # evaluate on training set
            loss_mean, loss_conf = eval(args, copy.copy(model_outer), task_family=task_family_train,
                                        num_updates=args.num_inner_updates, transformer=transformer)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval(args, copy.copy(model_outer), task_family=task_family_valid,
                                        num_updates=args.num_inner_updates, transformer=transformer)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval(args, copy.copy(model_outer), task_family=task_family_test,
                                        num_updates=args.num_inner_updates, transformer=transformer)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            utils.save_obj(logger, path)

            # save best model
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.copy(model_outer)

            # visualise results
            if args.task == 'celeba':
                task_family_train.visualise(task_family_train, task_family_test, copy.copy(logger.best_valid_model),
                                       args, i_iter, transformer)

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return logger
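fsn.phi_gradients and fsn.cosine_loss come from an external module that is not shown, so their exact behaviour is unknown. One plausible reading of cosine_loss, penalising disagreement between the support-loss and query-loss gradients of the inner model, is sketched below purely as an illustration; it is an assumption, not the repository's definition:

import torch
import torch.nn.functional as F

def cosine_loss(loss_support, loss_query, model_params):
    """1 - cosine similarity between the two losses' gradients w.r.t. model_params.
    Illustrative only; the actual fsn.cosine_loss may be defined differently."""
    g_s = torch.autograd.grad(loss_support, model_params, retain_graph=True, create_graph=True)
    g_q = torch.autograd.grad(loss_query, model_params, retain_graph=True, create_graph=True)
    g_s = torch.cat([g.reshape(-1) for g in g_s])
    g_q = torch.cat([g.reshape(-1) for g in g_q])
    return 1.0 - F.cosine_similarity(g_s, g_q, dim=0)

# toy usage on a single linear parameter
w = torch.randn(3, requires_grad=True)
x_s, y_s = torch.randn(5, 3), torch.randn(5)
x_q, y_q = torch.randn(5, 3), torch.randn(5)
penalty = cosine_loss(F.mse_loss(x_s @ w, y_s), F.mse_loss(x_q @ w, y_q), [w])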
Example #7
    # keep track of best models
    logger.update_best_model(model, save_path)


if __name__ == '__main__':

    args = arguments.parse_args()

    # --- settings ---

    if not os.path.exists(os.path.join(utils.get_base_path(), 'result_files')):
        os.mkdir(os.path.join(utils.get_base_path(), 'result_files'))
    if not os.path.exists(os.path.join(utils.get_base_path(), 'result_plots')):
        os.mkdir(os.path.join(utils.get_base_path(), 'result_plots'))

    path = os.path.join(utils.get_base_path(), 'result_files', utils.get_path_from_args(args))
    log_interval = 100

    if (not os.path.exists(path + '.npy')) or args.rerun:
        print('Starting experiment. Logging under filename {}'.format(path + '.npy'))
        run(args, num_workers=1, log_interval=log_interval, save_path=path)
    else:
        print('Found results in {}. If you want to re-run, use the argument --rerun'.format(path))

    # -------------- plot -----------------

    plt.switch_backend('agg')
    training_stats, validation_stats = np.load(path + '.npy', allow_pickle=True)

    plt.figure(figsize=(10, 5))
    x_ticks = np.arange(1, log_interval * len(training_stats['train_accuracy_pre_update']), log_interval)
Example #8
def run(args, log_interval=5000, rerun=False):
    assert not args.maml

    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root, args.task)):
        os.mkdir('{}/{}_result_files/'.format(code_root, args.task))
    path = '{}/{}_result_files/'.format(
        code_root, args.task) + utils.get_path_from_args(args)

    if os.path.exists(path + '.pkl') and not rerun:
        return utils.load_obj(path)

    start_time = time.time()
    utils.set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    task_family_train = multi()
    task_family_valid = multi()
    task_family_test = multi()

    L = get_l(5251, 1)

    # initialise network
    model = simple_MLP().to(args.device)

    # initialise meta-optimiser (it operates on the shared matrix L)
    L_optimiser = optim.Adam([L], 0.001)

    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model)

    # --- main training loop ---

    for i_iter in range(args.n_iter):

        # sample tasks
        target_functions = task_family_train.sample_tasks(
            args.tasks_per_metaupdate)

        # --- inner loop ---
        meta_gradient = 0
        for t in range(args.tasks_per_metaupdate):

            # get data for current task
            train_inputs = task_family_train.sample_inputs(
                args.k_meta_train).to(args.device)

            #initialise st
            #s = get_s(args.n)
            # s_optimizer = optim.Adam([s], args.lr_s)

            new_params = L[:, 0].clone()

            for _ in range(args.num_inner_updates):
                # forward through model
                train_outputs = model(train_inputs, new_params)

                # get targets
                train_targets = target_functions[t](train_inputs)

                # ------------ update on current task ------------

                # compute loss for current task
                task_loss = F.mse_loss(train_outputs, train_targets)

                # compute gradient wrt context params
                task_gradients = \
                    torch.autograd.grad(task_loss, new_params, create_graph=not args.first_order)[0]

                # update context params (this will set up the computation graph correctly)
                new_params = new_params - args.lr_inner * task_gradients
                #print('l1',L.grad)
                # forward through model
                '''
                train_outputs = model(train_inputs, new_params)
                train_targets = target_functions[t](train_inputs)
                task_loss = F.mse_loss(train_outputs, train_targets)
                task_loss.backward()
                s_optimizer.step()
                L_optimizer.zero_grad()
                '''

            # ------------ compute meta-gradient on test loss of current task ------------

            # get test data
            test_inputs = task_family_train.sample_inputs(
                args.k_meta_test, args.use_ordered_pixels).to(args.device)

            # get outputs after update
            test_outputs = model(test_inputs, new_params)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)

            # compute loss after updating context (will backprop through inner loop)
            loss_meta = F.mse_loss(test_outputs, test_targets)

            # compute gradient + save for current task
            task_grad = torch.autograd.grad(loss_meta, L)[0]

            #for i in range(len(task_grad)):
            # clip the gradient
            #   meta_gradient[i] += task_grad[i].detach().clamp_(-10, 10)
            meta_gradient += task_grad.detach().clamp_(-10, 10)

        # ------------ meta update ------------

        # assign meta-gradient
        L.grad = meta_gradient / args.tasks_per_metaupdate

        # do update step on shared model

        L_optimiser.step()
        L.grad = None

        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                L,
                task_family=task_family_train,
                num_updates=args.num_inner_updates)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                L,
                task_family=task_family_valid,
                num_updates=args.num_inner_updates)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                L,
                task_family=task_family_test,
                num_updates=args.num_inner_updates)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            utils.save_obj(logger, path)

            # save best model
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.deepcopy(L)

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return L
Example #9
def run(args, log_interval=5000, rerun=False):
    assert not args.maml

    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root, args.task)):
        os.mkdir('{}/{}_result_files/'.format(code_root, args.task))
    path = '{}/{}_result_files/'.format(
        code_root, args.task) + utils.get_path_from_args(args)

    if os.path.exists(path + '.pkl') and not rerun:
        return utils.load_obj(path)

    start_time = time.time()
    utils.set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    if args.task == 'sine':
        task_family_train = tasks_sine.RegressionTasksSinusoidal()
        task_family_valid = tasks_sine.RegressionTasksSinusoidal()
        task_family_test = tasks_sine.RegressionTasksSinusoidal()
    elif args.task == 'celeba':
        task_family_train = tasks_celebA.CelebADataset('train',
                                                       device=args.device)
        task_family_valid = tasks_celebA.CelebADataset('valid',
                                                       device=args.device)
        task_family_test = tasks_celebA.CelebADataset('test',
                                                      device=args.device)
    else:
        raise NotImplementedError

    # initialise network
    model = CaviaModel(n_in=task_family_train.num_inputs,
                       n_out=task_family_train.num_outputs,
                       num_context_params=args.num_context_params,
                       n_hidden=args.num_hidden_layers,
                       device=args.device).to(args.device)

    # initialise meta-optimiser
    # (only on shared params - context parameters are *not* registered parameters of the model)
    meta_optimiser = optim.Adam(model.parameters(), args.lr_meta)

    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model)

    # --- main training loop ---

    for i_iter in range(args.n_iter):

        # initialise meta-gradient
        meta_gradient = [0 for _ in range(len(model.state_dict()))]

        # sample tasks
        target_functions = task_family_train.sample_tasks(
            args.tasks_per_metaupdate)

        # --- inner loop ---

        for t in range(args.tasks_per_metaupdate):

            # reset private network weights
            model.reset_context_params()

            # get data for current task
            train_inputs = task_family_train.sample_inputs(
                args.k_meta_train, args.use_ordered_pixels).to(args.device)

            for _ in range(args.num_inner_updates):
                # forward through model
                train_outputs = model(train_inputs)

                # get targets
                train_targets = target_functions[t](train_inputs)

                # ------------ update on current task ------------

                # compute loss for current task
                task_loss = F.mse_loss(train_outputs, train_targets)

                # compute gradient wrt context params
                task_gradients = \
                    torch.autograd.grad(task_loss, model.context_params, create_graph=not args.first_order)[0]

                # update context params (this will set up the computation graph correctly)
                model.context_params = model.context_params - args.lr_inner * task_gradients

            # ------------ compute meta-gradient on test loss of current task ------------

            # get test data
            test_inputs = task_family_train.sample_inputs(
                args.k_meta_test, args.use_ordered_pixels).to(args.device)

            # get outputs after update
            test_outputs = model(test_inputs)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)

            # compute loss after updating context (will backprop through inner loop)
            loss_meta = F.mse_loss(test_outputs, test_targets)

            # compute gradient + save for current task
            task_grad = torch.autograd.grad(loss_meta, model.parameters())

            for i in range(len(task_grad)):
                # clip the gradient
                meta_gradient[i] += task_grad[i].detach().clamp_(-10, 10)

        # ------------ meta update ------------

        # assign meta-gradient
        for i, param in enumerate(model.parameters()):
            param.grad = meta_gradient[i] / args.tasks_per_metaupdate

        # do update step on shared model
        meta_optimiser.step()

        # reset context params
        model.reset_context_params()

        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                task_family=task_family_train,
                num_updates=args.num_inner_updates)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                task_family=task_family_valid,
                num_updates=args.num_inner_updates)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                task_family=task_family_test,
                num_updates=args.num_inner_updates)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            utils.save_obj(logger, path)

            # save best model
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.deepcopy(model)

            # visualise results
            if args.task == 'celeba':
                task_family_train.visualise(
                    task_family_train, task_family_test,
                    copy.deepcopy(logger.best_valid_model), args, i_iter)

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return logger
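In CAVIA, only the context parameters are adapted in the inner loop; they are an input-side vector that is deliberately not registered with the meta-optimiser. A minimal self-contained sketch of that mechanism (TinyCavia and its sizes are illustrative, not the repository's CaviaModel):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCavia(nn.Module):
    def __init__(self, n_in=1, n_out=1, n_context=4, n_hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_in + n_context, n_hidden), nn.ReLU(),
                                 nn.Linear(n_hidden, n_out))
        self.n_context = n_context
        self.reset_context_params()

    def reset_context_params(self):
        # plain tensor, *not* an nn.Parameter: the meta-optimiser never sees it
        self.context_params = torch.zeros(self.n_context, requires_grad=True)

    def forward(self, x):
        ctx = self.context_params.expand(x.shape[0], -1)
        return self.net(torch.cat([x, ctx], dim=1))

model = TinyCavia()
x, y = torch.randn(10, 1), torch.randn(10, 1)
loss = F.mse_loss(model(x), y)
grad = torch.autograd.grad(loss, model.context_params, create_graph=True)[0]
model.context_params = model.context_params - 0.1 * grad  # inner update, graph preserved
meta_loss = F.mse_loss(model(x), y)
meta_grads = torch.autograd.grad(meta_loss, model.parameters())  # shared-parameter gradient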
Example #10
def run(args, num_workers=1, log_interval=100, verbose=True, save_path=None):
    utils.set_seed(args.seed)

    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root,
                                                      'Mini-imagenet')):
        os.mkdir('{}/{}_result_files/'.format(code_root, 'Mini-imagenet'))
    save_path = '{}/{}_result_files/'.format(
        code_root, 'Mini-imagenet') + utils.get_path_from_args(args)

    if os.path.exists(save_path + '.pkl') and not args.rerun:
        return utils.load_obj(save_path)

    start_time = time.time()
    utils.set_seed(args.seed)

    # ---------------------------------------------------------
    # -------------------- training ---------------------------

    # initialise model
    model = CondConvNet(num_context_params=args.num_context_params,
                        context_in=args.context_in,
                        num_classes=args.n_way,
                        num_filters=args.num_filters,
                        max_pool=not args.no_max_pool,
                        num_film_hidden_layers=args.num_film_hidden_layers,
                        imsize=args.imsize,
                        initialisation=args.nn_initialisation,
                        device=args.device)
    model.train()

    # set up meta-optimiser for model parameters
    meta_optimiser = torch.optim.Adam(model.parameters(), args.lr_meta)
    scheduler = torch.optim.lr_scheduler.StepLR(meta_optimiser, 5000,
                                                args.lr_meta_decay)

    # initialise logger
    logger = Logger(log_interval, args, verbose=verbose)

    # initialise the starting point for the meta gradient (it's faster to copy this than to create new object)
    meta_grad_init = [0 for _ in range(len(model.state_dict()))]

    iter_counter = 0
    while iter_counter < args.n_iter:
        print(iter_counter)
        # batchsz here means total episode number
        dataset_train = MiniImagenet(mode='train',
                                     n_way=args.n_way,
                                     k_shot=args.k_shot,
                                     k_query=args.k_query,
                                     batchsz=10000,
                                     imsize=args.imsize,
                                     data_path=args.data_path)
        # fetch meta_batchsz num of episode each time
        dataloader_train = DataLoader(dataset_train,
                                      args.tasks_per_metaupdate,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      pin_memory=False)

        # initialise dataloader
        dataset_valid = MiniImagenet(mode='val',
                                     n_way=args.n_way,
                                     k_shot=args.k_shot,
                                     k_query=args.k_query,
                                     batchsz=500,
                                     imsize=args.imsize,
                                     data_path=args.data_path)
        dataloader_valid = DataLoader(dataset_valid,
                                      batch_size=num_workers,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      pin_memory=True)

        logger.print_header()

        for step, batch in enumerate(dataloader_train):

            scheduler.step()

            support_x = batch[0].to(args.device)
            support_y = batch[1].to(args.device)
            query_x = batch[2].to(args.device)
            query_y = batch[3].to(args.device)

            # skip batch if we don't have enough tasks in the current batch (might happen in last batch)
            if support_x.shape[0] != args.tasks_per_metaupdate:
                continue

            # initialise meta-gradient
            meta_grad = copy.deepcopy(meta_grad_init)

            logger.prepare_inner_loop(iter_counter)

            for inner_batch_idx in range(args.tasks_per_metaupdate):

                # reset context parameters
                model.reset_context_params()

                # -------------- inner update --------------

                logger.log_pre_update(iter_counter, support_x[inner_batch_idx],
                                      support_y[inner_batch_idx],
                                      query_x[inner_batch_idx],
                                      query_y[inner_batch_idx], model)

                for _ in range(args.num_grad_steps_inner):
                    # forward train data through net
                    pred_train = model(support_x[inner_batch_idx])

                    # compute loss
                    task_loss_train = F.cross_entropy(
                        pred_train, support_y[inner_batch_idx])

                    # compute gradient for context parameters
                    task_grad_train = torch.autograd.grad(task_loss_train,
                                                          model.context_params,
                                                          create_graph=True)[0]

                    # set context parameters to their updated values
                    model.context_params = model.context_params - args.lr_inner * task_grad_train

                # -------------- get meta gradient --------------

                # forward test data through updated net
                pred_test = model(query_x[inner_batch_idx])

                # compute loss on test data
                task_loss_test = F.cross_entropy(pred_test,
                                                 query_y[inner_batch_idx])

                # compute gradient for shared parameters
                task_grad_test = torch.autograd.grad(task_loss_test,
                                                     model.parameters())

                # add to meta-gradient
                for g in range(len(task_grad_test)):
                    meta_grad[g] += task_grad_test[g].detach()

                # ------------------------------------------------

                logger.log_post_update(iter_counter,
                                       support_x[inner_batch_idx],
                                       support_y[inner_batch_idx],
                                       query_x[inner_batch_idx],
                                       query_y[inner_batch_idx], model)

            # reset context parameters
            model.reset_context_params()

            # summarise inner loop and get validation performance
            logger.summarise_inner_loop(mode='train')

            if iter_counter % log_interval == 0:
                # evaluate how good the current model is (*before* updating so we can compare better)
                evaluate(iter_counter, args, model, logger, dataloader_valid,
                         save_path)
                if save_path is not None:
                    np.save(save_path,
                            [logger.training_stats, logger.validation_stats])
                    # save model to CPU
                    save_model = model
                    if args.device == 'cuda:0':
                        save_model = copy.deepcopy(model).to(
                            torch.device('cpu'))
                    torch.save(save_model, save_path)

            logger.print(iter_counter, task_grad_train, meta_grad)
            iter_counter += 1
            if iter_counter > args.n_iter:
                break

            # -------------- meta update --------------

            meta_optimiser.zero_grad()

            # set gradients of parameters manually
            for c, param in enumerate(model.parameters()):
                param.grad = meta_grad[c] / float(args.tasks_per_metaupdate)
                param.grad.data.clamp_(-10, 10)

            # the meta-optimiser only operates on the shared parameters, not the context parameters
            meta_optimiser.step()

    model.reset_context_params()
    return logger, model
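The manual meta-update used above — accumulate per-task gradients with torch.autograd.grad, average, clamp, write them into .grad by hand, and only then call the optimiser — can be illustrated in isolation. The following is a minimal sketch of that pattern; the linear model, random data and clamp bound are placeholders and not part of the example above.

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 2)
meta_optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
tasks_per_metaupdate = 8

# zero-initialised accumulator, one entry per shared parameter
meta_grad = [torch.zeros_like(p) for p in model.parameters()]

for _ in range(tasks_per_metaupdate):
    x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
    task_loss = F.cross_entropy(model(x), y)
    task_grad = torch.autograd.grad(task_loss, model.parameters())
    for g, tg in zip(meta_grad, task_grad):
        g += tg.detach()

meta_optimiser.zero_grad()
for param, g in zip(model.parameters(), meta_grad):
    param.grad = g / float(tasks_per_metaupdate)  # average over tasks
    param.grad.data.clamp_(-10, 10)               # clip, as in the loop above
meta_optimiser.step()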
Ejemplo n.º 11
0
def run(args, log_interval=5000, rerun=False):
    assert args.maml

    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root, args.task)):
        os.mkdir('{}/{}_result_files/'.format(code_root, args.task))
    path = '{}/{}_result_files/'.format(
        code_root, args.task) + utils.get_path_from_args(args)

    if os.path.exists(path + '.pkl') and not rerun:
        return utils.load_obj(path)

    start_time = time.time()

    # correctly seed everything
    utils.set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    if args.task == 'sine':
        task_family_train = tasks_sine.RegressionTasksSinusoidal()
        task_family_valid = tasks_sine.RegressionTasksSinusoidal()
        task_family_test = tasks_sine.RegressionTasksSinusoidal()
    elif args.task == 'celeba':
        task_family_train = tasks_celebA.CelebADataset('train')
        task_family_valid = tasks_celebA.CelebADataset('valid')
        task_family_test = tasks_celebA.CelebADataset('test')
    else:
        raise NotImplementedError

    # initialise network
    model_inner = MamlModel(task_family_train.num_inputs,
                            task_family_train.num_outputs,
                            n_weights=args.num_hidden_layers,
                            num_context_params=args.num_context_params,
                            device=args.device).to(args.device)
    model_outer = copy.deepcopy(model_inner)

    # initialise meta-optimiser
    meta_optimiser = optim.Adam(
        model_outer.weights + model_outer.biases + [model_outer.task_context],
        args.lr_meta)

    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model_outer)

    for i_iter in range(args.n_iter):

        # copy weights of network
        copy_weights = [w.clone() for w in model_outer.weights]
        copy_biases = [b.clone() for b in model_outer.biases]
        copy_context = model_outer.task_context.clone()
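        # these clones are the fixed starting point that each task's inner loop resets to below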

        # initialise the cumulative meta-gradient: one slot per weight and bias, plus one for the task context
        meta_gradient = [0 for _ in range(len(copy_weights + copy_biases) + 1)]

        # sample tasks
        target_functions = task_family_train.sample_tasks(
            args.tasks_per_metaupdate)

        for t in range(args.tasks_per_metaupdate):

            # reset network weights
            model_inner.weights = [w.clone() for w in copy_weights]
            model_inner.biases = [b.clone() for b in copy_biases]
            model_inner.task_context = copy_context.clone()

            # get data for current task
            train_inputs = task_family_train.sample_inputs(
                args.k_meta_train, args.use_ordered_pixels).to(args.device)

            for _ in range(args.num_inner_updates):

                # forward through network
                outputs = model_outer(train_inputs)

                # get targets
                targets = target_functions[t](train_inputs)

                # ------------ update on current task ------------

                # compute loss for current task
                loss_task = F.mse_loss(outputs, targets)

                # update private parts of network and keep correct computation graph
                params = [w for w in model_outer.weights] + [
                    b for b in model_outer.biases
                ] + [model_outer.task_context]
                grads = torch.autograd.grad(loss_task,
                                            params,
                                            create_graph=True,
                                            retain_graph=True)
                for i in range(len(model_inner.weights)):
                    if not args.first_order:
                        model_inner.weights[i] = model_outer.weights[
                            i] - args.lr_inner * grads[i]
                    else:
                        model_inner.weights[i] = model_outer.weights[
                            i] - args.lr_inner * grads[i].detach()
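                # after the loop above, i == len(model_inner.weights) - 1, so the bias
                # gradients sit at grads[i + j + 1] and the task context at grads[i + j + 2]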
                for j in range(len(model_inner.biases)):
                    if not args.first_order:
                        model_inner.biases[j] = model_outer.biases[
                            j] - args.lr_inner * grads[i + j + 1]
                    else:
                        model_inner.biases[j] = model_outer.biases[
                            j] - args.lr_inner * grads[i + j + 1].detach()
                if not args.first_order:
                    model_inner.task_context = model_outer.task_context - args.lr_inner * grads[
                        i + j + 2]
                else:
                    model_inner.task_context = model_outer.task_context - args.lr_inner * grads[
                        i + j + 2].detach()

            # ------------ compute meta-gradient on test loss of current task ------------

            # get test data
            test_inputs = task_family_train.sample_inputs(
                args.k_meta_test, args.use_ordered_pixels).to(args.device)

            # get outputs after update
            test_outputs = model_inner(test_inputs)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)

            # compute loss (will backprop through inner loop)
            loss_meta = F.mse_loss(test_outputs, test_targets)

            # compute gradient w.r.t. *outer model*
            task_grads = torch.autograd.grad(
                loss_meta, model_outer.weights + model_outer.biases +
                [model_outer.task_context])
            for i in range(len(model_inner.weights + model_inner.biases) + 1):
                meta_gradient[i] += task_grads[i].detach()

        # ------------ meta update ------------

        meta_optimiser.zero_grad()
        # print(meta_gradient)

        # assign meta-gradient
        for i in range(len(model_outer.weights)):
            model_outer.weights[
                i].grad = meta_gradient[i] / args.tasks_per_metaupdate
            meta_gradient[i] = 0
        for j in range(len(model_outer.biases)):
            model_outer.biases[j].grad = meta_gradient[
                i + j + 1] / args.tasks_per_metaupdate
            meta_gradient[i + j + 1] = 0
        model_outer.task_context.grad = meta_gradient[
            i + j + 2] / args.tasks_per_metaupdate
        meta_gradient[i + j + 2] = 0

        # do update step on outer model
        meta_optimiser.step()

        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            loss_mean, loss_conf = eval(args,
                                        copy.deepcopy(model_outer),
                                        task_family=task_family_train,
                                        num_updates=args.num_inner_updates)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval(args,
                                        copy.deepcopy(model_outer),
                                        task_family=task_family_valid,
                                        num_updates=args.num_inner_updates)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval(args,
                                        copy.deepcopy(model_outer),
                                        task_family=task_family_test,
                                        num_updates=args.num_inner_updates)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            utils.save_obj(logger, path)

            # save best model
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.deepcopy(model_outer)

            # visualise results
            if args.task == 'celeba':
                tasks_celebA.visualise(task_family_train, task_family_test,
                                       copy.deepcopy(logger.best_valid_model),
                                       args, i_iter)

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return logger
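The only role of args.first_order in the inner loop above is whether the inner gradients are detached before the update. Below is a minimal, self-contained sketch of that distinction, with a single toy parameter standing in for the outer model (not the MamlModel used above).

import torch
import torch.nn.functional as F

w = torch.randn(3, requires_grad=True)        # stands in for one outer parameter
x, y = torch.randn(8, 3), torch.randn(8)
lr_inner, first_order = 0.01, False

inner_loss = F.mse_loss(x @ w, y)
(grad_w,) = torch.autograd.grad(inner_loss, w, create_graph=True)

# first-order MAML detaches the inner gradient, so the outer gradient ignores
# how the inner update itself depends on w; the full version keeps the graph
w_adapted = w - lr_inner * (grad_w.detach() if first_order else grad_w)

outer_loss = F.mse_loss(x @ w_adapted, y)
(meta_grad_w,) = torch.autograd.grad(outer_loss, w)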
Ejemplo n.º 12
0
    # --- settings ---

    args.k_shot = 1
    args.lr_inner = 1.0
    args.lr_meta = 'decay'
    args.num_grad_steps_inner = 2
    args.num_grad_steps_eval = 2
    args.model = 'cnn'
    args.num_context_params = 100

    if args.k_shot == 1:
        args.tasks_per_metaupdate = 4
    else:
        args.tasks_per_metaupdate = 2

    path = os.path.join(utils.get_base_path(), 'result_files', datetime_folder,
                        utils.get_path_from_args(args))
    try:
        # the stats are stored as Python objects, so allow_pickle is required
        training_stats, validation_stats = np.load(path + '.npy',
                                                   allow_pickle=True)
    except FileNotFoundError:
        print(
            'You need to run the experiments first and make sure the results are saved at {}'
            .format(path))
        raise

    # initialise logger
    logger = Logger(args)
    logger.print_header()

    for num_grad_steps in [2]:

        print('\n --- ', num_grad_steps, '--- \n')

        # initialise logger