Example #1
def test(args, n_tasks=5000):
    # load the saved logger and restore the best validation model
    logger = utils.load_obj('./multi_result_files/9bb41b077266b414060f375af037ded9')
    model = logger.best_valid_model
    utils.set_seed(args.seed * 2)
    task_family_test = multi()
    loss_mean, loss_conf = eval(args, copy.copy(model), task_family=task_family_test,
                                num_updates=args.num_inner_updates, n_tasks=n_tasks)
    print(loss_mean, loss_conf)
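These examples persist and restore results through `utils.save_obj` / `utils.load_obj`. A minimal sketch of what those helpers likely look like, assuming straightforward pickle serialisation (the repo's versions may differ):

import pickle

# hedged sketch of the assumed pickle-based helpers behind
# utils.save_obj / utils.load_obj; note that '.pkl' is appended
# internally, matching how the callers pass extension-less paths
def save_obj(obj, path):
    with open(path + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(path):
    with open(path + '.pkl', 'rb') as f:
        return pickle.load(f)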
Example #2
def test(args, n_tasks=5000):
    # load the saved logger and restore the best validation model
    logger = utils.load_obj('./multi_result_files/5123eae98eedec114f92c97a98de5824')
    model = logger.best_valid_model
    utils.set_seed(args.seed * 2)
    task_family_test = multi()
    loss_mean, loss_conf = eval_cavia(args, copy.copy(model), task_family=task_family_test,
                                      num_updates=args.num_inner_updates, n_tasks=n_tasks)
    print(loss_mean, loss_conf)
Example #3
def test(args, n_tasks=5000):
    # load the saved logger and restore the best model, task encoder,
    # and gating ("place") encoder from validation
    global temp
    temp = 0.3
    logger = utils.load_obj(
        './multi_result_files/39d86a43969d0773fe4a6dca39446053')
    model = logger.best_valid_model
    model.num_context_params = args.num_context_params
    encoder = logger.best_encoder_valid_model
    p_encoder = logger.best_place_valid_model
    utils.set_seed(args.seed * 2)
    task_family_test = multi()
    loss_mean, loss_conf = eval_cavia(args, copy.copy(model), task_family=task_family_test,
                                      num_updates=args.num_inner_updates, n_tasks=n_tasks,
                                      encoder=encoder, p_encoder=p_encoder)
    print(loss_mean, loss_conf)
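For context, a hypothetical invocation of these `test` functions, assuming an argparse-style namespace carrying the fields they read (`seed`, `num_inner_updates`, `num_context_params`); the real scripts presumably build `args` with argparse:

from argparse import Namespace

# hypothetical args object; field values are illustrative only
args = Namespace(seed=42, num_inner_updates=2, num_context_params=5)
test(args, n_tasks=5000)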
Example #4
def test(args, n_tasks=5000):
    # load the saved logger and restore the best model, task encoder,
    # and gradient-placement encoder from validation
    global temp
    temp = 0.5
    logger = utils.load_obj(
        './multi_result_files/22d766b4a3616fc920e376a9e6a5d5ee')
    model = logger.best_valid_model
    model.num_context_params = args.num_context_params
    encoder = logger.best_encoder_valid_model
    gradient_place = logger.best_place_valid_model
    utils.set_seed(args.seed * 2)
    task_family_test = multi()
    loss_mean, loss_conf = eval_cavia(args,
                                      copy.copy(model),
                                      task_family=task_family_test,
                                      num_updates=5,  # fixed at 5 here rather than args.num_inner_updates
                                      n_tasks=n_tasks,
                                      encoder=encoder,
                                      gradient_place=gradient_place)
    print(loss_mean, loss_conf)
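All of these examples read and write a shared `Logger` container. A minimal sketch of its assumed shape, reconstructed from the attributes used throughout this listing (the repo's class may differ in details):

import time

class Logger:
    # hedged sketch of the Logger these examples rely on
    def __init__(self):
        # running loss statistics per evaluation split
        self.train_loss, self.train_conf = [], []
        self.valid_loss, self.valid_conf = [], []
        self.test_loss, self.test_conf = [], []
        # periodic model snapshots
        self.valid_model = []
        self.encoder_valid_model = []
        self.place_valid_model = []
        # best-so-far snapshots (set during training)
        self.best_valid_model = None
        self.best_encoder_valid_model = None
        self.best_place_valid_model = None

    def print_info(self, i_iter, start_time):
        print('iter {:>6} | train {:.4f} | valid {:.4f} | test {:.4f} | {:.1f}s'.format(
            i_iter, self.train_loss[-1], self.valid_loss[-1],
            self.test_loss[-1], time.time() - start_time))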
Example #5
def run(args, log_interval=5000, rerun=False):
    global temp
    assert not args.maml
    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root, args.task)):
        os.mkdir('{}/{}_result_files/'.format(code_root, args.task))
    path = '{}/{}_result_files/'.format(
        code_root, args.task) + utils.get_path_from_args(args)

    if os.path.exists(path + '.pkl') and not rerun:
        return utils.load_obj(path)

    start_time = time.time()
    utils.set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    if args.task == 'sine':
        task_family_train = tasks_sine.RegressionTasksSinusoidal()
        task_family_valid = tasks_sine.RegressionTasksSinusoidal()
        task_family_test = tasks_sine.RegressionTasksSinusoidal()
    elif args.task == 'celeba':
        task_family_train = tasks_celebA.CelebADataset('train',
                                                       device=args.device)
        task_family_valid = tasks_celebA.CelebADataset('valid',
                                                       device=args.device)
        task_family_test = tasks_celebA.CelebADataset('test',
                                                      device=args.device)
    elif args.task == 'multi':
        task_family_train = multi()
        task_family_valid = multi()
        task_family_test = multi()
    else:
        raise NotImplementedError

    # initialise network
    model = CaviaModel(n_in=task_family_train.num_inputs,
                       n_out=task_family_train.num_outputs,
                       num_context_params=args.num_context_params,
                       n_hidden=args.num_hidden_layers,
                       device=args.device).to(args.device)
    # initialise meta-optimiser
    # (only on shared params - context parameters are *not* registered parameters of the model)
    meta_optimiser = optim.Adam(model.parameters(), args.lr_meta)
    encoder = pool_encoder().to(args.device)
    encoder_optimiser = optim.Adam(encoder.parameters(), lr=1e-3)
    # note: the decoder is initialised but never used in this training loop
    decoder = pool_decoder().to(args.device)
    decoder_optimiser = optim.Adam(decoder.parameters(), lr=1e-3)
    #encoder.load_state_dict(torch.load('./model/encoder'))
    p_encoder = place().to(args.device)
    p_optimiser = optim.Adam(p_encoder.parameters(), lr=1e-3)
    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model)

    # --- main training loop ---

    for i_iter in range(args.n_iter):
        # initialise meta-gradient
        meta_gradient = [0 for _ in range(len(model.state_dict()))]
        place_gradient = [0 for _ in range(len(p_encoder.state_dict()))]
        encoder_gradient = [0 for _ in range(len(encoder.state_dict()))]

        # sample tasks
        target_functions, ty = task_family_train.sample_tasks(
            args.tasks_per_metaupdate, True)

        # --- inner loop ---

        for t in range(args.tasks_per_metaupdate):

            # reset private network weights
            model.reset_context_params()

            # get data for current task
            x = task_family_train.sample_inputs(
                args.k_meta_train, args.use_ordered_pixels).to(args.device)

            y = target_functions[t](x)
            train_inputs = torch.cat([x, y], dim=1)

            # encode the task from the (x, y) pairs and mean-pool over the batch
            # (max-pooling, torch.max(a, dim=0), is an alternative)
            a = encoder(train_inputs)
            embedding = torch.mean(a, dim=0)

            # predict per-context-parameter gate logits from the task embedding
            logits = p_encoder(embedding)
            logits = logits.reshape([latent_dim, categorical_dim])

            # sample a hard binary mask with the straight-through Gumbel-softmax;
            # note this reuses the name y, which now holds the mask, not the targets
            y = gumbel_softmax(logits, temp, hard=True)
            y = y[:, 1]

            for _ in range(args.num_inner_updates):
                # forward through model
                train_outputs = model(x)

                # get targets
                train_targets = target_functions[t](x)

                # ------------ update on current task ------------

                # compute loss for current task
                task_loss = F.mse_loss(train_outputs, train_targets)

                # compute gradient wrt context params
                task_gradients = \
                    torch.autograd.grad(task_loss, model.context_params, create_graph=not args.first_order)[0]

                # update context params through the gate mask y (this keeps the
                # computation graph so the outer loss can backprop through the inner step)
                model.context_params = model.context_params - args.lr_inner * task_gradients * y

            # ------------ compute meta-gradient on test loss of current task ------------

            # get test data
            test_inputs = task_family_train.sample_inputs(
                args.k_meta_test, args.use_ordered_pixels).to(args.device)

            # get outputs after update
            test_outputs = model(test_inputs)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)

            # compute loss after updating context (will backprop through inner loop)
            loss_meta = F.mse_loss(test_outputs, test_targets)
            # regularise the gate distribution towards a uniform categorical prior
            # with a scaled KL term (an L1 penalty on the mask is an alternative)
            qy = F.softmax(logits, dim=-1)
            log_ratio = torch.log(qy * categorical_dim + 1e-20)
            KLD = torch.sum(qy * log_ratio, dim=-1).mean() / 5
            loss_meta += KLD

            # compute gradient + save for current task
            task_grad = torch.autograd.grad(loss_meta,
                                            model.parameters(),
                                            retain_graph=True)

            for i in range(len(task_grad)):
                # clip the gradient
                meta_gradient[i] += task_grad[i].detach().clamp_(-10, 10)

            task_grad_place = torch.autograd.grad(loss_meta,
                                                  p_encoder.parameters(),
                                                  retain_graph=True)

            for i in range(len(task_grad_place)):
                # clip the gradient
                place_gradient[i] += task_grad_place[i].detach().clamp_(
                    -10, 10)

            task_grad_encoder = torch.autograd.grad(loss_meta,
                                                    encoder.parameters())
            for i in range(len(task_grad_encoder)):
                # clip the gradient
                encoder_gradient[i] += task_grad_encoder[i].detach().clamp_(
                    -10, 10)

        # ------------ meta update ------------

        # assign meta-gradient
        for i, param in enumerate(model.parameters()):
            param.grad = meta_gradient[i] / args.tasks_per_metaupdate
        meta_optimiser.step()

        # do update step on shared model
        for i, param in enumerate(p_encoder.parameters()):
            param.grad = place_gradient[i] / args.tasks_per_metaupdate
        p_optimiser.step()

        for i, param in enumerate(encoder.parameters()):
            param.grad = encoder_gradient[i] / args.tasks_per_metaupdate
        encoder_optimiser.step()

        # reset context params
        model.reset_context_params()

        # anneal the Gumbel-softmax temperature on a fixed schedule,
        # clipped from below at 0.5
        if i_iter % 350 == 1:
            temp = np.maximum(temp * np.exp(-ANNEAL_RATE * i_iter), 0.5)
            print(temp)
        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                task_family=task_family_train,
                num_updates=args.num_inner_updates,
                encoder=encoder,
                p_encoder=p_encoder)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                task_family=task_family_valid,
                num_updates=args.num_inner_updates,
                encoder=encoder,
                p_encoder=p_encoder)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                task_family=task_family_test,
                num_updates=args.num_inner_updates,
                encoder=encoder,
                p_encoder=p_encoder)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            utils.save_obj(logger, path)

            # save best model
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.deepcopy(model)
                logger.best_encoder_valid_model = copy.deepcopy(encoder)
                logger.best_place_valid_model = copy.deepcopy(p_encoder)

            if i_iter % (4 * log_interval) == 0:
                print('saving model at iter', i_iter)
                logger.valid_model.append(copy.deepcopy(model))
                logger.encoder_valid_model.append(copy.deepcopy(encoder))
                logger.place_valid_model.append(copy.deepcopy(p_encoder))

            # visualise results
            if args.task == 'celeba':
                task_family_train.visualise(
                    task_family_train, task_family_test,
                    copy.deepcopy(logger.best_valid_model), args, i_iter)

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return logger
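Example #5 calls `gumbel_softmax(logits, temp, hard=True)` without defining it. Sketched below is the standard straight-through implementation; this is an assumption about the repo's version, which may differ in details such as device handling:

import torch
import torch.nn.functional as F

def sample_gumbel(shape, eps=1e-20):
    # sample from Gumbel(0, 1) via inverse transform
    U = torch.rand(shape)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    # draw a soft sample from the Gumbel-softmax distribution
    y = logits + sample_gumbel(logits.size()).to(logits.device)
    return F.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits, temperature, hard=False):
    # straight-through estimator: the forward pass emits a one-hot
    # sample, the backward pass uses the soft sample's gradients
    y = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return y
    index = y.max(dim=-1, keepdim=True)[1]
    y_hard = torch.zeros_like(y).scatter_(-1, index, 1.0)
    return (y_hard - y).detach() + y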
Example #6
def run(args, log_interval=5000, rerun=False):
    assert not args.maml

    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root, args.task)):
        os.mkdir('{}/{}_result_files/'.format(code_root, args.task))
    path = '{}/{}_result_files/'.format(
        code_root, args.task) + utils.get_path_from_args(args)

    if os.path.exists(path + '.pkl') and not rerun:
        return utils.load_obj(path)

    start_time = time.time()
    utils.set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    task_family_train = multi()
    task_family_valid = multi()
    task_family_test = multi()

    L = get_l(5251, 1)

    # initialise network
    model = simple_MLP().to(args.device)

    # initialise the meta-optimiser (only L is meta-learned; the MLP is
    # evaluated functionally, with its parameters passed in at call time)
    L_optimiser = optim.Adam([L], 0.001)

    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model)

    # --- main training loop ---

    for i_iter in range(args.n_iter):

        # sample tasks
        target_functions = task_family_train.sample_tasks(
            args.tasks_per_metaupdate)

        # --- inner loop ---
        meta_gradient = 0
        for t in range(args.tasks_per_metaupdate):

            # get data for current task
            train_inputs = task_family_train.sample_inputs(
                args.k_meta_train).to(args.device)

            # initialise the fast weights for this task from the shared column of L
            # (clone keeps the graph so the meta-gradient can flow back into L)
            new_params = L[:, 0].clone()

            for _ in range(args.num_inner_updates):
                # forward through model
                train_outputs = model(train_inputs, new_params)

                # get targets
                train_targets = target_functions[t](train_inputs)

                # ------------ update on current task ------------

                # compute loss for current task
                task_loss = F.mse_loss(train_outputs, train_targets)

                # compute gradient wrt context params
                task_gradients = \
                    torch.autograd.grad(task_loss, new_params, create_graph=not args.first_order)[0]

                # update context params (this will set up the computation graph correctly)
                new_params = new_params - args.lr_inner * task_gradients

            # ------------ compute meta-gradient on test loss of current task ------------

            # get test data
            test_inputs = task_family_train.sample_inputs(
                args.k_meta_test, args.use_ordered_pixels).to(args.device)

            # get outputs after update
            test_outputs = model(test_inputs, new_params)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)

            # compute loss after updating context (will backprop through inner loop)
            loss_meta = F.mse_loss(test_outputs, test_targets)

            # compute gradient + save for current task
            task_grad = torch.autograd.grad(loss_meta, L)[0]

            # clip the per-task gradient and accumulate it
            meta_gradient += task_grad.detach().clamp_(-10, 10)

        # ------------ meta update ------------

        # assign meta-gradient
        L.grad = meta_gradient / args.tasks_per_metaupdate

        # do update step on shared model

        L_optimiser.step()
        L.grad = None

        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                L,
                task_family=task_family_train,
                num_updates=args.num_inner_updates)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                L,
                task_family=task_family_valid,
                num_updates=args.num_inner_updates)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval_cavia(
                args,
                copy.deepcopy(model),
                L,
                task_family=task_family_test,
                num_updates=args.num_inner_updates)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            utils.save_obj(logger, path)

            # save best parameters (here the matrix L plays the role of the model)
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.deepcopy(L)

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return L
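Example #6 meta-learns a free matrix `L` produced by `get_l(5251, 1)` and feeds its first column into `simple_MLP` as the network's flattened parameters. A minimal sketch of what `get_l` plausibly returns, under the assumption that it is simply a randomly initialised leaf tensor with gradients enabled (the actual initialisation is not shown in the source):

import torch

def get_l(n_params, n_columns):
    # hypothetical sketch: a learnable matrix with one flattened
    # parameter vector per column; the repo's get_l may use a
    # different initialisation scheme
    L = torch.randn(n_params, n_columns) * 0.01
    L.requires_grad_(True)
    return L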
Example #7
def run(args, log_interval=5000, rerun=False):
    assert not args.maml

    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root, args.task)):
        os.mkdir('{}/{}_result_files/'.format(code_root, args.task))
    path = '{}/{}_result_files/'.format(code_root, args.task) + utils.get_path_from_args(args)

    if os.path.exists(path + '.pkl') and not rerun:
        return utils.load_obj(path)

    start_time = time.time()
    utils.set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    if args.task == 'sine':
        task_family_train = tasks_sine.RegressionTasksSinusoidal()
        task_family_valid = tasks_sine.RegressionTasksSinusoidal()
        task_family_test = tasks_sine.RegressionTasksSinusoidal()
    elif args.task == 'celeba':
        task_family_train = tasks_celebA.CelebADataset('train', device=args.device)
        task_family_valid = tasks_celebA.CelebADataset('valid', device=args.device)
        task_family_test = tasks_celebA.CelebADataset('test', device=args.device)
    elif args.task == 'multi':
        task_family_train = multi()
        task_family_valid = multi()
        task_family_test = multi()
    else:
        raise NotImplementedError

    # initialise network
    model = CaviaModel(n_in=task_family_train.num_inputs,
                       n_out=task_family_train.num_outputs,
                       num_context_params=args.num_context_params,
                       n_hidden=args.num_hidden_layers,
                       device=args.device
                       ).to(args.device)

    # initialise meta-optimiser
    # (only on shared params - context parameters are *not* registered parameters of the model)
    meta_optimiser = optim.Adam(model.parameters(), args.lr_meta)

    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model)

    # --- main training loop ---

    for i_iter in range(args.n_iter):

        # initialise meta-gradient
        meta_gradient = [0 for _ in range(len(model.state_dict()))]

        # sample tasks
        target_functions = task_family_train.sample_tasks(args.tasks_per_metaupdate)

        # --- inner loop ---

        for t in range(args.tasks_per_metaupdate):

            # reset private network weights
            model.reset_context_params()

            # get data for current task
            train_inputs = task_family_train.sample_inputs(args.k_meta_train, args.use_ordered_pixels).to(args.device)

            for _ in range(args.num_inner_updates):
                # forward through model
                train_outputs = model(train_inputs)

                # get targets
                train_targets = target_functions[t](train_inputs)

                # ------------ update on current task ------------

                # compute loss for current task
                task_loss = F.mse_loss(train_outputs, train_targets)

                # compute gradient wrt context params
                task_gradients = \
                    torch.autograd.grad(task_loss, model.context_params, create_graph=not args.first_order)[0]

                # update context params (this will set up the computation graph correctly)
                model.context_params = model.context_params - args.lr_inner * task_gradients

            # ------------ compute meta-gradient on test loss of current task ------------

            # get test data
            test_inputs = task_family_train.sample_inputs(args.k_meta_test, args.use_ordered_pixels).to(args.device)

            # get outputs after update
            test_outputs = model(test_inputs)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)

            # compute loss after updating context (will backprop through inner loop)
            loss_meta = F.mse_loss(test_outputs, test_targets)

            # compute gradient + save for current task
            task_grad = torch.autograd.grad(loss_meta, model.parameters())

            for i in range(len(task_grad)):
                # clip the gradient
                meta_gradient[i] += task_grad[i].detach().clamp_(-10, 10)

        # ------------ meta update ------------

        # assign meta-gradient
        for i, param in enumerate(model.parameters()):
            param.grad = meta_gradient[i] / args.tasks_per_metaupdate

        # do update step on shared model
        meta_optimiser.step()

        # reset context params
        model.reset_context_params()

        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            loss_mean, loss_conf = eval_cavia(args, copy.deepcopy(model), task_family=task_family_train,
                                              num_updates=args.num_inner_updates)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval_cavia(args, copy.deepcopy(model), task_family=task_family_valid,
                                              num_updates=args.num_inner_updates)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval_cavia(args, copy.deepcopy(model), task_family=task_family_test,
                                              num_updates=args.num_inner_updates)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            utils.save_obj(logger, path)

            # save best model
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.deepcopy(model)

            # visualise results
            if args.task == 'celeba':
                task_family_train.visualise(task_family_train, task_family_test, copy.deepcopy(logger.best_valid_model),
                                            args, i_iter)

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return logger
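The key mechanic in Example #7 is that `model.context_params = model.context_params - args.lr_inner * task_gradients` replaces the context vector with a non-leaf tensor, so the outer loss can backpropagate through the inner gradient step. A self-contained toy reproducing that mechanic (names here are illustrative, not from the repo):

import torch

phi = torch.zeros(2, requires_grad=True)          # context parameters (leaf)
w = torch.tensor([1.0, 2.0], requires_grad=True)  # shared parameters

# inner step: gradient of the inner loss w.r.t. the context parameters,
# with create_graph=True so the step itself stays differentiable
inner_loss = ((w * phi - 1.0) ** 2).sum()
g = torch.autograd.grad(inner_loss, phi, create_graph=True)[0]
phi = phi - 0.1 * g                                # phi is now a non-leaf tensor

# outer step: the meta-gradient w.r.t. w flows back through the inner update
outer_loss = ((w * phi - 2.0) ** 2).sum()
meta_grad = torch.autograd.grad(outer_loss, w)[0]
print(meta_grad)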
Example #8
def run(args, log_interval=5000, rerun=False):
    assert args.maml

    # see if we already ran this experiment
    code_root = os.path.dirname(os.path.realpath(__file__))
    if not os.path.isdir('{}/{}_result_files/'.format(code_root, args.task)):
        os.mkdir('{}/{}_result_files/'.format(code_root, args.task))
    path = '{}/{}_result_files/'.format(code_root, args.task) + utils.get_path_from_args(args)

    if os.path.exists(path + '.pkl') and not rerun:
        return utils.load_obj(path)

    start_time = time.time()

    # correctly seed everything
    utils.set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    if args.task == 'sine':
        task_family_train = tasks_sine.RegressionTasksSinusoidal()
        task_family_valid = tasks_sine.RegressionTasksSinusoidal()
        task_family_test = tasks_sine.RegressionTasksSinusoidal()
    elif args.task == 'celeba':
        task_family_train = tasks_celebA.CelebADataset('train', args.device)
        task_family_valid = tasks_celebA.CelebADataset('valid', args.device)
        task_family_test = tasks_celebA.CelebADataset('test', args.device)
    else:
        # any other task name falls back to the multi-task family
        task_family_train = multi()
        task_family_valid = multi()
        task_family_test = multi()

    # initialise network
    model_inner = MamlModel(task_family_train.num_inputs,
                            task_family_train.num_outputs,
                            n_weights=args.num_hidden_layers,
                            num_context_params=args.num_context_params,
                            device=args.device
                            ).to(args.device)
    model_outer = copy.deepcopy(model_inner)

    # initialise meta-optimiser
    meta_optimiser = optim.Adam(model_outer.weights + model_outer.biases + [model_outer.task_context],
                                args.lr_meta)

    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model_outer)

    for i_iter in range(args.n_iter):

        # copy weights of network
        copy_weights = [w.clone() for w in model_outer.weights]
        copy_biases = [b.clone() for b in model_outer.biases]
        copy_context = model_outer.task_context.clone()

        # get all shared parameters and initialise cumulative gradient
        meta_gradient = [0 for _ in range(len(copy_weights + copy_biases) + 1)]

        # sample tasks
        target_functions = task_family_train.sample_tasks(args.tasks_per_metaupdate)

        for t in range(args.tasks_per_metaupdate):

            # reset network weights
            model_inner.weights = [w.clone() for w in copy_weights]
            model_inner.biases = [b.clone() for b in copy_biases]
            model_inner.task_context = copy_context.clone()

            # get data for current task
            train_inputs = task_family_train.sample_inputs(args.k_meta_train, args.use_ordered_pixels).to(args.device)

            for _ in range(args.num_inner_updates):

                # make prediction using the current model
                outputs = model_inner(train_inputs)

                # get targets
                targets = target_functions[t](train_inputs)

                # ------------ update on current task ------------

                # compute loss for current task
                loss_task = F.mse_loss(outputs, targets)

                # compute the gradient wrt current model
                params = [w for w in model_inner.weights] + [b for b in model_inner.biases] + [model_inner.task_context]
                grads = torch.autograd.grad(loss_task, params, create_graph=True, retain_graph=True)

                # make an update on the inner model using the current model (to build up computation graph);
                # note: only the first bias vector is updated here - the full update over
                # all weights, biases, and the task context is the commented block below
                model_inner.biases[0] = model_inner.biases[0] - args.lr_inner * grads[len(model_inner.weights)].detach()
                '''
                for i in range(len(model_inner.weights)):
                    if not args.first_order:
                        model_inner.weights[i] = model_inner.weights[i] - args.lr_inner * grads[i]
                    else:
                        model_inner.weights[i] = model_inner.weights[i] - args.lr_inner * grads[i].detach()
                for j in range(len(model_inner.biases)):
                    if not args.first_order:
                        model_inner.biases[j] = model_inner.biases[j] - args.lr_inner * grads[i + j + 1]
                    else:
                        model_inner.biases[j] = model_inner.biases[j] - args.lr_inner * grads[i + j + 1].detach()
                if not args.first_order:
                    model_inner.task_context = model_inner.task_context - args.lr_inner * grads[i + j + 2]
                else:
                    model_inner.task_context = model_inner.task_context - args.lr_inner * grads[i + j + 2].detach()
                '''
            # ------------ compute meta-gradient on test loss of current task ------------

            # get test data
            test_inputs = task_family_train.sample_inputs(args.k_meta_test, args.use_ordered_pixels).to(args.device)

            # get outputs after update
            test_outputs = model_inner(test_inputs)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)

            # compute loss (will backprop through inner loop)
            loss_meta = F.mse_loss(test_outputs, test_targets)

            # compute gradient w.r.t. *outer model*
            task_grads = torch.autograd.grad(loss_meta,
                                             model_outer.weights + model_outer.biases + [model_outer.task_context])
            for i in range(len(model_inner.weights + model_inner.biases) + 1):
                meta_gradient[i] += task_grads[i].detach()

        # ------------ meta update ------------

        meta_optimiser.zero_grad()

        # assign meta-gradient
        n_w, n_b = len(model_outer.weights), len(model_outer.biases)
        for i in range(n_w):
            model_outer.weights[i].grad = meta_gradient[i] / args.tasks_per_metaupdate
            meta_gradient[i] = 0
        for j in range(n_b):
            model_outer.biases[j].grad = meta_gradient[n_w + j] / args.tasks_per_metaupdate
            meta_gradient[n_w + j] = 0
        model_outer.task_context.grad = meta_gradient[n_w + n_b] / args.tasks_per_metaupdate
        meta_gradient[n_w + n_b] = 0

        # do update step on outer model
        meta_optimiser.step()

        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            loss_mean, loss_conf = eval(args, copy.copy(model_outer), task_family=task_family_train,
                                        num_updates=args.num_inner_updates)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval(args, copy.copy(model_outer), task_family=task_family_valid,
                                        num_updates=args.num_inner_updates)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval(args, copy.copy(model_outer), task_family=task_family_test,
                                        num_updates=args.num_inner_updates)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            utils.save_obj(logger, path)

            # save best model
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.copy(model_outer)

            # visualise results
            if args.task == 'celeba':
                task_family_train.visualise(task_family_train, task_family_test,
                                            copy.copy(logger.best_valid_model), args, i_iter)

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return logger
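Finally, all four training loops share the same manual meta-update pattern: per-task gradients are clipped, accumulated, averaged over the task batch, written into `.grad`, and applied with a single optimiser step. A stripped-down toy of that pattern (illustrative only):

import torch
from torch import optim

w = torch.nn.Parameter(torch.zeros(3))
opt = optim.Adam([w], lr=0.1)

meta_gradient = torch.zeros_like(w)
for target in (torch.ones(3), -torch.ones(3)):  # two toy "tasks"
    loss = ((w - target) ** 2).sum()
    task_grad = torch.autograd.grad(loss, w)[0]
    meta_gradient += task_grad.detach().clamp_(-10, 10)  # clip and accumulate

w.grad = meta_gradient / 2  # average over tasks, then step
opt.step()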