    def outer_loop(self, batch, is_train):

        train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(
            batch)

        loss_log = 0
        acc_log = 0
        grad_list = []
        loss_list = []

        for (train_input, train_target, test_input,
             test_target) in zip(train_inputs, train_targets, test_inputs,
                                 test_targets):

            with higher.innerloop_ctx(self.network,
                                      self.inner_optimizer,
                                      track_higher_grads=False) as (fmodel,
                                                                    diffopt):

                for step in range(self.args.n_inner):
                    self.inner_loop(fmodel, diffopt, train_input, train_target)

                train_logit = fmodel(train_input)
                in_loss = F.cross_entropy(train_logit, train_target)

                test_logit = fmodel(test_input)
                outer_loss = F.cross_entropy(test_logit, test_target)
                loss_log += outer_loss.item() / self.batch_size

                with torch.no_grad():
                    acc_log += get_accuracy(
                        test_logit, test_target).item() / self.batch_size

                if is_train:
                    params = list(fmodel.parameters(time=-1))
                    in_grad = torch.nn.utils.parameters_to_vector(
                        torch.autograd.grad(in_loss, params,
                                            create_graph=True))
                    outer_grad = torch.nn.utils.parameters_to_vector(
                        torch.autograd.grad(outer_loss, params))
                    implicit_grad = self.neumann_approx(
                        in_grad, outer_grad, params)
                    grad_list.append(implicit_grad)
                    loss_list.append(outer_loss.item())

        if is_train:
            self.outer_optimizer.zero_grad()
            weight = torch.ones(len(grad_list))
            weight = weight / torch.sum(weight)
            grad = mix_grad(grad_list, weight)
            grad_log = apply_grad(self.network, grad)
            self.outer_optimizer.step()

            return loss_log, acc_log, grad_log
        else:
            return loss_log, acc_log
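The self.neumann_approx call above is what turns the support-loss gradient (in_grad) and the query-loss gradient (outer_grad) into an iMAML-style implicit meta-gradient, but its body is not part of this listing. The standalone function below is only a minimal sketch of such a truncated Neumann-series inverse-Hessian-vector product; the n_steps and lam arguments are assumed hyper-parameters, not values taken from this code.

import torch

def neumann_approx(in_grad, outer_grad, params, n_steps=5, lam=1.0):
    # Approximate (I + H/lam)^(-1) @ outer_grad with a truncated Neumann series,
    # where H is the Hessian of the inner (support) loss at the adapted parameters.
    # in_grad must be a flat gradient built with create_graph=True, as above.
    v = outer_grad.detach().clone()
    p = outer_grad.detach().clone()
    for _ in range(n_steps):
        # Hessian-vector product: d(in_grad . v)/d(params), flattened to match v.
        hvp = torch.nn.utils.parameters_to_vector(
            torch.autograd.grad(in_grad, params, grad_outputs=v, retain_graph=True))
        v = -hvp / lam   # next series term (-H/lam)^j @ outer_grad
        p = p + v        # accumulate the partial sum
    return p.detach()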
Example 2
    def outer_loop(self, batch, is_train):

        self.network.zero_grad()

        train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(
            batch)

        loss_log = 0
        acc_log = 0
        grad_list = []
        loss_list = []

        for (train_input, train_target, test_input,
             test_target) in zip(train_inputs, train_targets, test_inputs,
                                 test_targets):
            with higher.innerloop_ctx(
                    self.network,
                    self.inner_optimizer,
                    track_higher_grads=is_train) as (fmodel, diffopt):

                for step in range(self.args.n_inner):
                    self.inner_loop(fmodel, diffopt, train_input, train_target)

                test_logit = fmodel(test_input)
                outer_loss = F.cross_entropy(test_logit, test_target)
                loss_log += outer_loss.item() / self.batch_size

                with torch.no_grad():
                    acc_log += get_accuracy(
                        test_logit, test_target).item() / self.batch_size

                if is_train:
                    outer_grad = torch.autograd.grad(outer_loss,
                                                     fmodel.parameters(time=0))
                    grad_list.append(outer_grad)
                    loss_list.append(outer_loss.item())

        if is_train:
            weight = torch.ones(len(grad_list))
            weight = weight / torch.sum(weight)
            grad = mix_grad(grad_list, weight)
            grad_log = apply_grad(self.network, grad)
            self.outer_optimizer.step()

            return loss_log, acc_log, grad_log
        else:
            return loss_log, acc_log
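Every variant in this listing finishes the same way: the per-task gradients collected in grad_list are averaged by mix_grad and written onto self.network by apply_grad before outer_optimizer.step(). Neither helper is shown here; the sketch below is a plausible minimal version for the common case where each entry is a tuple of per-parameter tensors (the flat vectors produced by the implicit-gradient examples would first need torch.nn.utils.vector_to_parameters), and the returned grad-norm value is an assumption about what grad_log holds.

import torch

def mix_grad(grad_list, weight):
    # Weighted sum of per-task gradients. grad_list[i] is a tuple of tensors,
    # one per model parameter; weight is a 1-D tensor that sums to one.
    mixed = []
    for per_param in zip(*grad_list):
        stacked = torch.stack([w * g for w, g in zip(weight, per_param)])
        mixed.append(stacked.sum(dim=0))
    return mixed

def apply_grad(model, grad):
    # Write an explicit gradient into model.parameters().grad and return its
    # overall L2 norm, which the callers above log as grad_log.
    grad_norm = 0.0
    for p, g in zip(model.parameters(), grad):
        if p.grad is None:
            p.grad = g.detach().clone()
        else:
            p.grad += g.detach()
        grad_norm += float(torch.sum(g ** 2))
    return grad_norm ** 0.5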
Example 3
    def outer_loop(self, batch, is_train):

        self.network.zero_grad()
        
        train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(batch)

        loss_log = 0
        acc_log = 0
        grad_list = []
        loss_list = []

        for (train_input, train_target, test_input, test_target) in zip(train_inputs, train_targets, test_inputs, test_targets):
            override = self.inner_optimizer if is_train else None
            
            with higher.innerloop_ctx(self.network, self.inner_optimizer, track_higher_grads=False) as (fmodel, diffopt):

                for step in range(self.args.n_inner):
                    if is_train:
                        index = np.random.permutation(np.arange(len(test_input)))[:10]
                        train_input = test_input[index]
                        train_target = test_target[index]
                    self.inner_loop(fmodel, diffopt, train_input, train_target)
                
                with torch.no_grad():
                    test_logit = fmodel(test_input)
                    outer_loss = F.cross_entropy(test_logit, test_target)
                    loss_log += outer_loss.item()/self.batch_size
                    loss_list.append(outer_loss.item())
                    acc_log += get_accuracy(test_logit, test_target).item()/self.batch_size
            
                if is_train:
                    outer_grad = []
                    for p_0, p_T in zip(fmodel.parameters(time=0), fmodel.parameters(time=-1)):  # time=-1: parameters after the final inner step
                        outer_grad.append(-(p_T - p_0).detach())
                    grad_list.append(outer_grad)

        if is_train:
            weight = torch.ones(len(grad_list))/len(grad_list)
            grad = mix_grad(grad_list, weight)
            grad_log = apply_grad(self.network, grad)
            self.outer_optimizer.step()

            return loss_log, acc_log, grad_log
        else:
            return loss_log, acc_log
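All of the examples log accuracy through a get_accuracy helper that is not included in this listing. A minimal version consistent with how it is called here (class logits and integer targets in, a scalar tensor out) would be:

import torch

def get_accuracy(logits, targets):
    # Fraction of query examples whose arg-max prediction matches the target label.
    preds = torch.argmax(logits, dim=-1)
    return (preds == targets).float().mean()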
Example 4
    def outer_loop(self, batch, is_train):

        self.network.zero_grad()
        
        train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(batch)

        loss_log = 0
        acc_log = 0
        grad_list = []
        loss_list = []

        for (train_input, train_target, test_input, test_target) in zip(train_inputs, train_targets, test_inputs, test_targets):
            context = torch.zeros(self.context_dim).cuda().requires_grad_()

            for step in range(self.args.n_inner):
                context = self.inner_loop(context, train_input, train_target, is_train)

            test_logit = self.cavia_forward(test_input, context)
            outer_loss = F.cross_entropy(test_logit, test_target)
            loss_log += outer_loss.item()/self.batch_size

            with torch.no_grad():
                acc_log += get_accuracy(test_logit, test_target).item()/self.batch_size
        
            if is_train:
                outer_grad = torch.autograd.grad(outer_loss, self.network.parameters())
                grad_list.append(outer_grad)
                loss_list.append(outer_loss.item())

        if is_train:
            weight = torch.ones(len(grad_list))
            weight = weight / torch.sum(weight)
            grad = mix_grad(grad_list, weight)
            grad_log = apply_grad(self.network, grad)

            self.outer_optimizer.step()
            return loss_log, acc_log, grad_log
        else:
            return loss_log, acc_log
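Example 4 adapts only a task-specific context vector (CAVIA-style) and never touches the network weights in the inner loop; the inner_loop and cavia_forward helpers it relies on are not shown. Below is a minimal sketch of one such context update; the inner_lr argument and the way the context conditions the forward pass are assumptions for illustration, not details taken from this code.

import torch
import torch.nn.functional as F

def cavia_inner_step(cavia_forward, context, train_input, train_target,
                     inner_lr=0.1, is_train=True):
    # One gradient step on the context vector only; the network weights stay fixed.
    logits = cavia_forward(train_input, context)
    loss = F.cross_entropy(logits, train_target)
    # create_graph only at meta-training time, so the outer loss can
    # back-propagate through the context updates into the network weights.
    (grad,) = torch.autograd.grad(loss, context, create_graph=is_train)
    return context - inner_lr * grad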
Example 5
    def outer_loop(self, batch, reverse_dict_list, is_train):

        self.network.zero_grad()

        train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(
            batch)

        loss_log = 0
        acc_log = 0
        grad_list = []
        loss_list = []
        for i, (train_input, train_target, test_input,
                test_target) in enumerate(
                    zip(train_inputs, train_targets, test_inputs,
                        test_targets)):
            self.network.init_decoder()
            inner_optimizer = torch.optim.SGD(self.network.decoder,
                                              lr=self.args.inner_lr)
            with higher.innerloop_ctx(
                    self.network, inner_optimizer,
                    track_higher_grads=is_train) as (fmodel, diffopt):

                fmodel(torch.zeros(1, 3, 80, 80).type(torch.float32).cuda())

                # Convert numerical label to word label
                target = [[
                    reverse_dict_list[i][j] for j in range(self.args.num_way)
                ]]

                # Get transformed word embeddings and lambda
                word_proto = fmodel.word_embedding(target,
                                                   is_train)[0].permute([1, 0])

                train_input = fmodel(train_input)
                for step in range(self.args.n_inner):
                    self.inner_loop(fmodel.decoder, diffopt, train_input,
                                    train_target, word_proto)

                test_logit = fmodel(test_input)
                test_logit = torch.matmul(test_logit, 2 * word_proto * fmodel.decoder[0]) \
                            - ((word_proto * fmodel.decoder[0])**2).sum(dim=0, keepdim=True) + fmodel.decoder[1]
                outer_loss = F.cross_entropy(test_logit, test_target)
                loss_log += outer_loss.item() / self.batch_size

                with torch.no_grad():
                    acc_log += get_accuracy(
                        test_logit, test_target).item() / self.batch_size

                if is_train:
                    params = fmodel.parameters(time=0)
                    outer_grad = torch.autograd.grad(outer_loss, params)
                    grad_list.append(outer_grad)
                    loss_list.append(outer_loss.item())

        # self._lambda = _lambda.detach()

        if is_train:
            weight = torch.ones(len(grad_list))
            weight = weight / torch.sum(weight)
            grad = mix_grad(grad_list, weight)
            grad_log = apply_grad(self.network, grad)
            self.outer_optimizer.step()

            return loss_log, acc_log, grad_log
        else:
            return loss_log, acc_log
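The higher-based variants (Examples 1, 2, 3 and 6) delegate every adaptation step to self.inner_loop(fmodel, diffopt, ...), which this listing also omits. Assuming a plain cross-entropy support loss (Example 6 presumably substitutes its bce_dice_loss), a minimal version built on higher's differentiable optimizer would be:

import torch.nn.functional as F

def inner_loop(fmodel, diffopt, train_input, train_target):
    # One adaptation step: compute the support loss, then take a differentiable
    # optimizer step so higher records the new fast weights inside fmodel.
    train_logit = fmodel(train_input)
    inner_loss = F.cross_entropy(train_logit, train_target)
    diffopt.step(inner_loss)
    return inner_loss.item()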
Example 6
    def outer_loop(self, batch, is_train):

        train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(
            batch)

        loss_log = 0
        dice_log = 0
        grad_list = []
        loss_list = []

        for (train_input, train_target, test_input,
             test_target) in zip(train_inputs, train_targets, test_inputs,
                                 test_targets):

            with higher.innerloop_ctx(self.network,
                                      self.inner_optimizer,
                                      track_higher_grads=False) as (fmodel,
                                                                    diffopt):

                for step in range(self.args.n_inner):
                    self.inner_loop(fmodel, diffopt, train_input, train_target)

                train_logit = fmodel(train_input)
                in_loss = bce_dice_loss(train_logit, train_target)

                test_logit = fmodel(test_input)

                outer_loss = bce_dice_loss(test_logit, test_target)
                loss_log += outer_loss.item() / self.batch_size

                out_cut = np.copy(test_logit.data.cpu().numpy())
                out_cut[np.nonzero(out_cut < 0.3)] = 0.0  #threshold
                out_cut[np.nonzero(out_cut >= 0.3)] = 1.0

                with torch.no_grad():
                    dice_log += dice_coef(out_cut,
                                          test_target.data.cpu().numpy()).item(
                                          ) / self.batch_size

                if is_train:
                    params = list(fmodel.parameters(time=-1))
                    in_grad = torch.nn.utils.parameters_to_vector(
                        torch.autograd.grad(in_loss, params,
                                            create_graph=True))
                    outer_grad = torch.nn.utils.parameters_to_vector(
                        torch.autograd.grad(outer_loss, params))
                    implicit_grad = self.cg(in_grad, outer_grad, params)
                    grad_list.append(implicit_grad)
                    loss_list.append(outer_loss.item())

        if is_train:
            self.outer_optimizer.zero_grad()
            weight = torch.ones(len(grad_list))
            weight = weight / torch.sum(weight)
            grad = mix_grad(grad_list, weight)
            grad_log = apply_grad(self.network, grad)
            self.outer_optimizer.step()

            return loss_log, dice_log, grad_log
        else:
            return loss_log, dice_log
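Example 6 solves the same implicit-gradient linear system as Example 1, but with a conjugate-gradient routine (self.cg) in place of the Neumann series. The standalone sketch below shows what such a solver can look like; the iteration count n_cg and damping lam are assumed hyper-parameters, not values from this code.

import torch

def cg_solve(in_grad, outer_grad, params, n_cg=5, lam=1.0):
    # Solve (I + H/lam) x = outer_grad with conjugate gradient, where H is the
    # Hessian of the inner loss (in_grad was built with create_graph=True).
    def matvec(v):
        hvp = torch.nn.utils.parameters_to_vector(
            torch.autograd.grad(in_grad, params, grad_outputs=v, retain_graph=True))
        return v + hvp / lam

    x = torch.zeros_like(outer_grad)
    r = outer_grad.detach().clone()      # residual b - A @ x, with x = 0
    p = r.clone()
    rs_old = torch.dot(r, r)
    for _ in range(n_cg):
        Ap = matvec(p)
        alpha = rs_old / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = torch.dot(r, r)
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x.detach()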