Example #1
    def __init__(self, model_func, n_way, n_support, tf_path=None):
        super(GnnNet, self).__init__(model_func,
                                     n_way,
                                     n_support,
                                     tf_path=tf_path)

        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function
        self.fc = nn.Sequential(
            nn.Linear(self.feat_dim, 128),
            nn.BatchNorm1d(128, track_running_stats=False)) if not (
                self.maml or self.maml_adain) else nn.Sequential(
                    backbone.Linear_fw(self.feat_dim, 128),
                    backbone.BatchNorm1d_fw(128, track_running_stats=False))
        self.gnn = GNN_nl(128 + self.n_way, 96, self.n_way)
        self.method = 'GnnNet'

        # fix label for training the metric function   1*nw(1 + ns)*nw
        support_label = torch.from_numpy(
            np.repeat(range(self.n_way), self.n_support)).unsqueeze(1)
        support_label = torch.zeros(
            self.n_way * self.n_support,
            self.n_way).scatter(1, support_label,
                                1).view(self.n_way, self.n_support, self.n_way)
        support_label = torch.cat(
            [support_label,
             torch.zeros(self.n_way, 1, self.n_way)], dim=1)
        self.support_label = support_label.view(1, -1, self.n_way)
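
The support_label block above builds a fixed one-hot label tensor of shape 1 x n_way*(n_support + 1) x n_way, with an all-zero slot reserved for the query node of each class. A standalone sketch of the same construction, assuming a 5-way 5-shot episode (only torch and numpy needed):

import numpy as np
import torch

n_way, n_support = 5, 5  # assumed episode size for illustration

# class index of every support image, shape (n_way * n_support, 1)
support_label = torch.from_numpy(
    np.repeat(range(n_way), n_support)).unsqueeze(1)

# one-hot encode and group by class: (n_way, n_support, n_way)
support_label = torch.zeros(n_way * n_support, n_way).scatter(
    1, support_label, 1).view(n_way, n_support, n_way)

# append an all-zero label slot per class for the (unlabeled) query node
support_label = torch.cat(
    [support_label, torch.zeros(n_way, 1, n_way)], dim=1)

support_label = support_label.view(1, -1, n_way)
print(support_label.shape)  # torch.Size([1, 30, 5])
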
Example #2
    def __init__(self, model_func, n_way, n_support):
        super(GnnNet, self).__init__(model_func, n_way, n_support)
        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function
        self.fc = nn.Sequential(nn.Linear(
            self.feat_dim, 128), nn.BatchNorm1d(
                128,
                track_running_stats=False)) if not self.FWT else nn.Sequential(
                    backbone.Linear_fw(self.feat_dim, 128),
                    backbone.BatchNorm1d_fw(128, track_running_stats=False))
        self.gnn = GNN_nl(128 + self.n_way, 96, self.n_way)
        self.method = 'GnnNet'
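
A minimal sketch of the non-FWT branch of the metric head above: backbone features are projected to 128 dimensions and batch-normalized without running statistics before being passed to the GNN. feat_dim=512 and the episode sizes are assumptions:

import torch
import torch.nn as nn

feat_dim, n_way, n_support, n_query = 512, 5, 5, 16  # assumed sizes
fc = nn.Sequential(nn.Linear(feat_dim, 128),
                   nn.BatchNorm1d(128, track_running_stats=False))

feats = torch.randn(n_way * (n_support + n_query), feat_dim)  # stand-in backbone output
z = fc(feats)                     # (n_way * (n_support + n_query), 128)
z = z.view(n_way, -1, z.size(1))  # (n_way, n_support + n_query, 128)
print(z.shape)                    # torch.Size([5, 21, 128])
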
Example #3
class Ours_gnn(MetaTemplate):
    maml = False

    def __init__(self,
                 model_func,
                 n_way,
                 n_support,
                 domain_specific,
                 fine_tune,
                 train_lr,
                 tf_path=None):
        super(Ours_gnn, self).__init__(model_func,
                                       n_way,
                                       n_support,
                                       domain_specific=domain_specific,
                                       fine_tune=fine_tune,
                                       train_lr=train_lr,
                                       tf_path=tf_path)

        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function
        self.fc = nn.Sequential(nn.Linear(self.feat_dim, 128),
                                nn.BatchNorm1d(128, track_running_stats=False)
                                ) if not self.maml else nn.Sequential(
                                    backbone.Linear_fw(self.feat_dim, 128),
                                    backbone.BatchNorm1d_fw(
                                        128, track_running_stats=False))
        self.gnn = GNN_nl(128 + self.n_way, 96, self.n_way)

        # fix label for training the metric function   1*nw(1 + ns)*nw
        support_label = torch.from_numpy(
            np.repeat(range(self.n_way), self.n_support)).unsqueeze(1)
        support_label = torch.zeros(
            self.n_way * self.n_support,
            self.n_way).scatter(1, support_label,
                                1).view(self.n_way, self.n_support, self.n_way)
        support_label = torch.cat(
            [support_label, torch.zeros(self.n_way, 1, n_way)], dim=1)
        self.support_label = support_label.view(1, -1, self.n_way)

    def cuda(self):
        self.feature.cuda()
        self.fc.cuda()
        self.gnn.cuda()
        self.support_label = self.support_label.cuda()
        return self

    def set_forward(self, x, is_feature=False):
        x = x.cuda()
        if is_feature:
            # reshape the feature tensor: n_way * n_s + 15 * f
            assert (x.size(1) == self.n_support + 15)
            z = self.fc(x.view(-1, *x.size()[2:]))
            z = z.view(self.n_way, -1, z.size(1))
        else:
            # get feature using encoder
            x = x.view(-1, *x.size()[2:])
            z = self.feature(x)
            z = self.fc(z.float())
            z = z.view(self.n_way, -1, z.size(1))

        # stack the feature for metric function: n_way * n_s + n_q * f -> n_q * [1 * n_way(n_s + 1) * f]
        z_stack = [
            torch.cat([
                z[:, :self.n_support],
                z[:, self.n_support + i:self.n_support + i + 1]
            ],
                      dim=1).view(1, -1, z.size(2))
            for i in range(self.n_query)
        ]
        assert (z_stack[0].size(1) == self.n_way * (self.n_support + 1))
        scores = self.forward_gnn(z_stack)
        return scores

    def forward_gnn(self, zs):
        # gnn inp: n_q * n_way(n_s + 1) * f
        nodes = torch.cat(
            [torch.cat([z, self.support_label], dim=2) for z in zs], dim=0)
        scores = self.gnn(nodes)

        # n_q * n_way(n_s + 1) * n_way -> (n_way * n_q) * n_way
        scores = scores.view(self.n_query, self.n_way,
                             self.n_support + 1, self.n_way)[:, :, -1].permute(
                                 1, 0, 2).contiguous().view(-1, self.n_way)
        return scores
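
A shape-only sketch of the reshaping at the end of forward_gnn above, with a random tensor standing in for the GNN output (episode sizes are assumptions): only the query node of each class slot is kept, and the rows are reordered so they come out grouped by class.

import torch

n_way, n_support, n_query = 5, 5, 16  # assumed episode size
# stand-in for self.gnn(nodes): one n_way-dim score per node, per query graph
gnn_out = torch.randn(n_query, n_way * (n_support + 1), n_way)

scores = gnn_out.view(n_query, n_way, n_support + 1, n_way)[:, :, -1]
scores = scores.permute(1, 0, 2).contiguous().view(-1, n_way)
print(scores.shape)  # torch.Size([80, 5]) == (n_way * n_query, n_way)
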
Example #4
class GnnNet(MetaTemplate):
    maml = False
    maml_adain = False
    assert (maml and maml_adain) == False

    def __init__(self, model_func, n_way, n_support, tf_path=None):
        super(GnnNet, self).__init__(model_func,
                                     n_way,
                                     n_support,
                                     tf_path=tf_path)

        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function
        self.fc = nn.Sequential(
            nn.Linear(self.feat_dim, 128),
            nn.BatchNorm1d(128, track_running_stats=False)) if not (
                self.maml or self.maml_adain) else nn.Sequential(
                    backbone.Linear_fw(self.feat_dim, 128),
                    backbone.BatchNorm1d_fw(128, track_running_stats=False))
        self.gnn = GNN_nl(128 + self.n_way, 96, self.n_way)
        self.method = 'GnnNet'

        # fix label for training the metric function   1*nw(1 + ns)*nw
        support_label = torch.from_numpy(
            np.repeat(range(self.n_way), self.n_support)).unsqueeze(1)
        support_label = torch.zeros(
            self.n_way * self.n_support,
            self.n_way).scatter(1, support_label,
                                1).view(self.n_way, self.n_support, self.n_way)
        support_label = torch.cat(
            [support_label,
             torch.zeros(self.n_way, 1, self.n_way)], dim=1)
        self.support_label = support_label.view(1, -1, self.n_way)

    def cuda(self):
        self.feature.cuda()
        self.fc.cuda()
        self.gnn.cuda()
        self.support_label = self.support_label.cuda()
        return self

    def set_forward(self, x, is_feature=False):
        x = x.cuda()

        if is_feature:
            # reshape the feature tensor: n_way * n_s + 15 * f
            assert (x.size(1) == self.n_support + 15)
            z = self.fc(x.view(-1, *x.size()[2:]))
            z = z.view(self.n_way, -1, z.size(1))
        else:
            # get feature using encoder
            x = x.view(-1, *x.size()[2:])
            z = self.fc(
                self.feature(x)
            )  # further encode the image features into a feature vector (128, )
            z = z.view(
                self.n_way, -1, z.size(1)
            )  # reshape to (num_way, num_support + num_query, feature_dim)

        # stack the feature for metric function: n_way * n_s + n_q * f -> n_q * [1 * n_way(n_s + 1) * f]; each query image's feature is concatenated with the support features
        z_stack = [
            torch.cat([
                z[:, :self.n_support],
                z[:, self.n_support + i:self.n_support + i + 1]
            ],
                      dim=1).view(1, -1, z.size(2))
            for i in range(self.n_query)
        ]
        assert (z_stack[0].size(1) == self.n_way * (self.n_support + 1))
        scores = self.forward_gnn(z_stack)

        return scores

    def forward_gnn(self, zs):
        # gnn inp: n_q * n_way(n_s + 1) * f
        nodes = torch.cat(
            [torch.cat([z, self.support_label], dim=2) for z in zs], dim=0)
        scores = self.gnn(nodes)

        # n_q * n_way(n_s + 1) * n_way -> (n_way * n_q) * n_way
        scores = scores.view(self.n_query, self.n_way,
                             self.n_support + 1, self.n_way)[:, :, -1].permute(
                                 1, 0, 2).contiguous().view(-1, self.n_way)
        return scores

    def set_forward_loss(self, x, epoch=None):
        y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
        y_query = y_query.cuda()
        scores = self.set_forward(x)
        loss = self.loss_fn(scores, y_query)
        return scores, loss
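
A standalone sketch of the z_stack construction in set_forward above: each query image gets its own small graph made of all support features plus that single query feature. Sizes are assumptions and a random tensor stands in for the encoded features:

import torch

n_way, n_support, n_query, d = 5, 5, 16, 128  # assumed sizes
z = torch.randn(n_way, n_support + n_query, d)  # stand-in for self.fc(self.feature(x))

z_stack = [
    torch.cat([z[:, :n_support],
               z[:, n_support + i:n_support + i + 1]], dim=1).view(1, -1, d)
    for i in range(n_query)
]
print(len(z_stack), z_stack[0].shape)  # 16 torch.Size([1, 30, 128])
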
Example #5
    def __init__(self, model_func, n_way, n_support):
        super(DampNet, self).__init__(model_func, n_way, n_support)

        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function
        self.gnn_dim = 128
        self.fc = nn.Sequential(nn.Linear(self.feat_dim, self.gnn_dim),
                                nn.BatchNorm1d(128, track_running_stats=False)
                                ) if not self.maml else nn.Sequential(
                                    backbone.Linear_fw(self.feat_dim, 128),
                                    backbone.BatchNorm1d_fw(
                                        128, track_running_stats=False))
        self.gnn = GNN_nl(self.gnn_dim + self.n_way, 96, self.n_way)
        self.method = 'DampNet'

        self.num_ex = 20  ##making change to 50?
        #self.meta_store_mean = torch.zeros((self.num_ex,self.feat_dim))
        #self.meta_store_std = torch.zeros((self.num_ex,self.n_support*self.n_way,self.feat_dim))
        #self.corruption = torch.from_numpy(np.diag(np.ones(self.feat_dim)))

        ### comparison / recovery network

        self.W_R = nn.Bilinear(self.feat_dim, self.feat_dim, 300,
                               bias=False).cuda()
        self.V_R = nn.Linear(self.feat_dim * 2, 300).cuda()

        self.W_R_std = nn.Bilinear(self.feat_dim,
                                   self.feat_dim,
                                   300,
                                   bias=False).cuda()
        self.V_R_std = nn.Linear(self.feat_dim * 2, 300).cuda()

        ## MLP
        self.tanh = nn.Tanh()
        self.layer1 = nn.Linear(300 * 2, 500)
        self.layer2 = nn.Linear(500, 500)
        self.layer3 = nn.Linear(500, self.feat_dim)
        self.layer1_add = nn.Linear(300 * 2, 500)
        self.layer2_add = nn.Linear(500, 500)
        self.layer3_add = nn.Linear(500, self.feat_dim)

        self.final_meta_prototype = torch.zeros(self.feat_dim)
        self.final_meta_prototype_std = torch.zeros(self.feat_dim)
        self.final_meta_prototypes_initialized = False
        self.final_all_feats = torch.zeros(
            5, 100, self.n_way * self.n_support,
            self.feat_dim)  ##replace first and second dim with desired

        #self.meta_prototype_mean = torch.mean((1, self.feat_dim))
        #self.meta_prototype_std = torch.mean((1, self.feat_dim))
        self.call_count = 150  ##if restart
        self.first = True

        # fix label for training the metric function   1*nw(1 + ns)*nw
        support_label = torch.from_numpy(
            np.repeat(range(self.n_way), self.n_support)).unsqueeze(1)
        support_label = torch.zeros(
            self.n_way * self.n_support,
            self.n_way).scatter(1, support_label,
                                1).view(self.n_way, self.n_support, self.n_way)
        support_label = torch.cat(
            [support_label, torch.zeros(self.n_way, 1, n_way)], dim=1)
        self.support_label = support_label.view(1, -1, self.n_way)

        self.cuda()
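
A minimal sketch of the bilinear-plus-linear comparison used for W_R/V_R above, mapping a stored meta prototype and the current episode statistic to a joint 300-dimensional code. feat_dim=512 is an assumption and the inputs are random stand-ins:

import torch
import torch.nn as nn

feat_dim, out_dim = 512, 300  # assumed feature size; 300 matches this example
W_R = nn.Bilinear(feat_dim, feat_dim, out_dim, bias=False)
V_R = nn.Linear(feat_dim * 2, out_dim)

meta_prototype_mean = torch.randn(feat_dim)  # stand-in stored prototype
x_mean = torch.randn(feat_dim)               # stand-in episode support mean

NTN_out = W_R(meta_prototype_mean, x_mean) + V_R(
    torch.cat((meta_prototype_mean, x_mean)))
print(NTN_out.shape)  # torch.Size([300])
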
Example #6
class DampNet(MetaTemplate):
    maml = False

    def __init__(self, model_func, n_way, n_support):
        super(DampNet, self).__init__(model_func, n_way, n_support)

        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function
        self.gnn_dim = 128
        self.fc = nn.Sequential(nn.Linear(self.feat_dim, self.gnn_dim),
                                nn.BatchNorm1d(128, track_running_stats=False)
                                ) if not self.maml else nn.Sequential(
                                    backbone.Linear_fw(self.feat_dim, 128),
                                    backbone.BatchNorm1d_fw(
                                        128, track_running_stats=False))
        self.gnn = GNN_nl(self.gnn_dim + self.n_way, 96, self.n_way)
        self.method = 'DampNet'

        self.num_ex = 20  ##making change to 50?
        #self.meta_store_mean = torch.zeros((self.num_ex,self.feat_dim))
        #self.meta_store_std = torch.zeros((self.num_ex,self.n_support*self.n_way,self.feat_dim))
        #self.corruption = torch.from_numpy(np.diag(np.ones(self.feat_dim)))

        ### comparison / recovery network

        self.W_R = nn.Bilinear(self.feat_dim, self.feat_dim, 300,
                               bias=False).cuda()
        self.V_R = nn.Linear(self.feat_dim * 2, 300).cuda()

        self.W_R_std = nn.Bilinear(self.feat_dim,
                                   self.feat_dim,
                                   300,
                                   bias=False).cuda()
        self.V_R_std = nn.Linear(self.feat_dim * 2, 300).cuda()

        ## MLP
        self.tanh = nn.Tanh()
        self.layer1 = nn.Linear(300 * 2, 500)
        self.layer2 = nn.Linear(500, 500)
        self.layer3 = nn.Linear(500, self.feat_dim)
        self.layer1_add = nn.Linear(300 * 2, 500)
        self.layer2_add = nn.Linear(500, 500)
        self.layer3_add = nn.Linear(500, self.feat_dim)

        self.final_meta_prototype = torch.zeros(self.feat_dim)
        self.final_meta_prototype_std = torch.zeros(self.feat_dim)
        self.final_meta_prototypes_initialized = False
        self.final_all_feats = torch.zeros(
            5, 100, self.n_way * self.n_support,
            self.feat_dim)  ##replace first and second dim with desired

        #self.meta_prototype_mean = torch.mean((1, self.feat_dim))
        #self.meta_prototype_std = torch.mean((1, self.feat_dim))
        self.call_count = 150  ##if restart
        self.first = True

        # fix label for training the metric function   1*nw(1 + ns)*nw
        support_label = torch.from_numpy(
            np.repeat(range(self.n_way), self.n_support)).unsqueeze(1)
        support_label = torch.zeros(
            self.n_way * self.n_support,
            self.n_way).scatter(1, support_label,
                                1).view(self.n_way, self.n_support, self.n_way)
        support_label = torch.cat(
            [support_label, torch.zeros(self.n_way, 1, n_way)], dim=1)
        self.support_label = support_label.view(1, -1, self.n_way)

        self.cuda()

    def cuda(self):
        self.feature.cuda()
        self.fc.cuda()
        self.gnn.cuda()
        #self.meta_store_mean.cuda()
        #self.meta_store_std.cuda()
        self.W_R.cuda()
        self.V_R.cuda()
        self.W_R_std.cuda()
        self.V_R_std.cuda()
        self.tanh.cuda()
        self.layer1.cuda()
        self.layer2.cuda()
        self.layer3.cuda()
        self.layer1_add.cuda()
        self.layer2_add.cuda()
        self.layer3_add.cuda()
        self.final_all_feats.cuda()
        self.final_meta_prototype.cuda()
        self.final_meta_prototype_std.cuda()
        self.support_label = self.support_label.cuda()
        return self

    def get_all_feat(self, all_feat):
        all_feat = all_feat.cuda().detach()
        self.final_meta_prototype = torch.mean(all_feat,
                                               axis=0).cuda().detach()
        self.final_meta_prototype_std = all_feat.std(axis=0).cuda().detach()
        self.final_meta_prototypes_initialized = True
        return self

    def set_forward(self, x, is_feature=False, domain_shift=False):
        x = x.cuda()
        if domain_shift == False:
            if is_feature:
                # reshape the feature tensor: n_way * n_s + 15 * f
                assert (x.size(1) == self.n_support + 15)
                z = self.fc(x.view(-1, *x.size()[2:]))
                z = z.view(self.n_way, -1, z.size(1))
            else:
                # get feature using encoder ## brought it to higher level
                x2 = x.view(self.n_way, -1, x.size(1))
                x_mean = torch.mean(x2[:, :self.n_support, :],
                                    axis=(0, 1)).detach()
                x_std = x2[:, :self.n_support, :].std(axis=(0, 1)).detach()

            # stack the feature for metric function: n_way * n_s + n_q * f -> n_q * [1 * n_way(n_s + 1) * f]
            if self.final_meta_prototypes_initialized == False:
                self.fc[0].weight.requires_grad = True
                self.fc[0].bias.requires_grad = True
                self.gnn = self.gnn.train()

                z = self.fc(x)
                z = z.view(self.n_way, -1, z.size(1))
                #z_mean = torch.mean(z[:,:self.n_support,:], axis = (0,1), keepdim = True)
                #print("AVG SHAPE")
                #z = z - z_mean
                #z_norm = torch.norm(z, dim = 2, keepdim = True)
                #z = z / z_norm
                z_stack = [
                    torch.cat([
                        z[:, :self.n_support],
                        z[:, self.n_support + i:self.n_support + i + 1]
                    ],
                              dim=1).view(1, -1, z.size(2))
                    for i in range(self.n_query)
                ]
                assert (z_stack[0].size(1) == self.n_way *
                        (self.n_support + 1))
                scores = self.forward_gnn(z_stack)
                idx = self.call_count % self.num_ex
                #self.meta_store_mean[idx] =  x_mean
                #self.meta_store_std[idx] = x2[:,:self.n_support,:].detach().reshape(-1,self.feat_dim)
                self.call_count += 1
                return scores
            elif self.call_count % 2 == 1:
                ### corruption vector
                a = 0.5
                b = 0.8
                perc = (b - a) * np.random.random_sample() + a
                perc_zeros = perc / 2
                a2 = 1.5
                b2 = 4
                m_fac = (b2 - a2) * np.random.random_sample() + a2
                meta_prototype_mean = self.final_meta_prototype
                meta_prototype_std = self.final_meta_prototype_std
                one_zeros = np.concatenate(
                    (np.ones(self.feat_dim -
                             math.floor(self.feat_dim * perc_zeros)),
                     np.zeros(math.floor(self.feat_dim * perc_zeros))))
                np.random.shuffle(one_zeros)
                corruption = torch.from_numpy(
                    np.diag(one_zeros)).cuda().float()
                corruption_bias = torch.from_numpy(np.zeros(
                    self.feat_dim)).cuda().float()
                temp = np.asarray(list(range(0, self.feat_dim)))
                random_idx = np.random.choice(temp,
                                              math.floor(perc * self.feat_dim))
                random_idx2 = np.random.choice(
                    temp, math.floor(perc * self.feat_dim))
                rand_idx_col = np.random.choice(random_idx2, 1)
                ad_sub = np.concatenate(
                    (np.ones(self.feat_dim - math.floor(self.feat_dim * 0.5)),
                     -np.ones(math.floor(self.feat_dim * 0.5))))
                np.random.shuffle(ad_sub)
                t_sample = m_fac * np.reshape(
                    np.random.standard_t(5, self.feat_dim * self.feat_dim),
                    (self.feat_dim, self.feat_dim))
                t_sample_bias = np.random.standard_t(5, self.feat_dim) + ad_sub
                t_sample_bias = torch.from_numpy(
                    -np.squeeze(t_sample[:, rand_idx_col]) +
                    t_sample_bias).cuda().float()
                t_sample = torch.from_numpy(t_sample).cuda().float()
                corruption[random_idx, random_idx2] += t_sample[random_idx,
                                                                random_idx2]
                corruption_bias[random_idx2] += t_sample_bias[random_idx2]
                corrupt_x = torch.matmul(
                    x, corruption).detach().cuda()  ## new input
                corrupt_x += (m_fac * corruption_bias)
                corrupt_x2 = corrupt_x.view(self.n_way, -1, x.size(1))
                corrupt_x_mean = torch.mean(corrupt_x2[:, :self.n_support, :],
                                            axis=(0, 1)).detach()
                corrupt_x_std = corrupt_x2[:, :self.n_support, :].std(
                    axis=(0, 1)).detach()

                W_out_m = self.W_R(meta_prototype_mean, corrupt_x_mean).cuda()
                V_out_m = self.V_R(
                    torch.cat((meta_prototype_mean, corrupt_x_mean)))
                NTN_out = W_out_m + V_out_m

                W_out_m_std = self.W_R_std(meta_prototype_std.cuda(),
                                           corrupt_x_std.cuda()).cuda()
                V_out_m_std = self.V_R_std(
                    torch.cat((meta_prototype_std.cuda(),
                               corrupt_x_std.cuda())).cuda())
                NTN_out_std = W_out_m_std + V_out_m_std

                compare_input = self.tanh(torch.cat((NTN_out, NTN_out_std)))
                mult_ = F.relu(self.layer1(compare_input))
                mult_ = F.relu(self.layer2(mult_))
                #upper = torch.tensor([1]).float().cuda()
                #lower = torch.tensor([1]).float().cuda()
                mult_ = self.layer3(mult_)

                add_ = F.relu(self.layer1_add(compare_input))
                add_ = F.relu(self.layer2_add(add_))
                add_ = self.layer3_add(add_)  ## sparse add

                recovered_x = torch.mul(corrupt_x.detach(), mult_) + add_
                self.fc[0].weight.requires_grad = False
                self.fc[0].bias.requires_grad = False
                self.gnn = self.gnn.eval()
                r_z = self.fc(recovered_x)

                r_z = r_z.view(self.n_way, -1, r_z.size(1))
                #r_z_mean = torch.mean(r_z[:,:self.n_support,:], axis = (0,1), keepdim = True)
                #print("AVG SHAPE")
                #r_z = r_z - r_z_mean
                #r_z_norm = torch.norm(r_z, dim = 2, keepdim = True)
                #r_z = r_z / r_z_norm
                r_z_stack = [
                    torch.cat([
                        r_z[:, :self.n_support],
                        r_z[:, self.n_support + i:self.n_support + i + 1]
                    ],
                              dim=1).view(1, -1, r_z.size(2))
                    for i in range(self.n_query)
                ]
                assert (r_z_stack[0].size(1) == self.n_way *
                        (self.n_support + 1))

                r_scores = self.forward_gnn(r_z_stack)
                scores = r_scores
                #self.meta_store_mean[idx] =  x_mean
                #self.meta_store_std[idx] = x2[:,:self.n_support,:].detach().reshape(-1,self.feat_dim)
                self.call_count += 1
                return scores
            elif self.call_count % 2 == 0:
                meta_prototype_mean = self.final_meta_prototype
                meta_prototype_std = self.final_meta_prototype_std
                self.fc[0].weight.requires_grad = True
                self.fc[0].bias.requires_grad = True
                self.gnn = self.gnn.train()

                W_out_m = self.W_R(meta_prototype_mean, x_mean.detach()).cuda()
                V_out_m = self.V_R(
                    torch.cat((meta_prototype_mean, x_mean.detach())))
                NTN_out = W_out_m + V_out_m

                W_out_m_std = self.W_R_std(meta_prototype_std.cuda(),
                                           x_std.cuda()).cuda()
                V_out_m_std = self.V_R_std(
                    torch.cat(
                        (meta_prototype_std.cuda(), x_std.cuda())).cuda())
                NTN_out_std = W_out_m_std + V_out_m_std

                compare_input = self.tanh(torch.cat((NTN_out, NTN_out_std)))
                mult_ = F.relu(self.layer1(compare_input))
                mult_ = F.relu(self.layer2(mult_))
                mult_ = self.layer3(mult_)

                add_ = F.relu(self.layer1_add(compare_input))
                add_ = F.relu(self.layer2_add(add_))
                #thresh_add = (add_ > 1).float() * 1
                add_ = self.layer3_add(add_)

                recovered_x = torch.mul(x, mult_) + add_  ### use back normal x
                r_z = self.fc(recovered_x)

                r_z = r_z.view(self.n_way, -1, r_z.size(1))
                #r_z_mean = torch.mean(r_z[:,:self.n_support,:], axis = (0,1), keepdim = True)
                #print("AVG SHAPE")
                #r_z = r_z - r_z_mean
                #r_z_norm = torch.norm(r_z, dim = 2, keepdim = True)
                #r_z = r_z / r_z_norm
                r_z_stack = [
                    torch.cat([
                        r_z[:, :self.n_support],
                        r_z[:, self.n_support + i:self.n_support + i + 1]
                    ],
                              dim=1).view(1, -1, r_z.size(2))
                    for i in range(self.n_query)
                ]
                assert (r_z_stack[0].size(1) == self.n_way *
                        (self.n_support + 1))
                r_scores = self.forward_gnn(r_z_stack)
                scores = r_scores
                idx = self.call_count % self.num_ex
                #self.meta_store_mean[idx] =  x_mean
                #self.meta_store_std[idx] = x2[:,:self.n_support,:].detach().reshape(-1,self.feat_dim)
                self.call_count += 1
                return scores
        elif domain_shift == True:
            #print("DOMAIN SHIFT")
            if is_feature:
                assert (x.size(1) == self.n_support + 15)
                x = x.view(-1, *x.size()[2:])
                x2 = x.view(self.n_way, -1, x.size(1))
                ### LOAD PROTOTYPES
                meta_prototype_mean = self.final_meta_prototype
                meta_prototype_std = self.final_meta_prototype_std
                x_mean = torch.mean(x2[:, :self.n_support, :],
                                    axis=(0, 1)).detach()
                x_std = x2[:, :self.n_support, :].std(axis=(0, 1)).detach()

                W_out_m = self.W_R(meta_prototype_mean, x_mean).cuda()
                V_out_m = self.V_R(torch.cat((meta_prototype_mean, x_mean)))
                NTN_out = W_out_m + V_out_m

                W_out_m_std = self.W_R_std(meta_prototype_std.cuda(),
                                           x_std.cuda()).cuda()
                V_out_m_std = self.V_R_std(
                    torch.cat(
                        (meta_prototype_std.cuda(), x_std.cuda())).cuda())
                NTN_out_std = W_out_m_std + V_out_m_std

                compare_input = self.tanh(torch.cat((NTN_out, NTN_out_std)))
                mult_ = F.relu(self.layer1(compare_input))
                mult_ = F.relu(self.layer2(mult_))
                mult_ = self.layer3(mult_)

                add_ = F.relu(self.layer1_add(compare_input))
                add_ = F.relu(self.layer2_add(add_))
                add_ = self.layer3_add(add_)

                recovered_x = torch.mul(x, mult_) + add_  ### use back normal x
                r_z = self.fc(recovered_x)

                r_z = r_z.view(self.n_way, -1, r_z.size(1))
                #r_z_mean = torch.mean(r_z[:,:self.n_support,:], axis = (0,1), keepdim = True)
                #print("AVG SHAPE")
                #r_z = r_z - r_z_mean
                #r_z_norm = torch.norm(r_z, dim = 2, keepdim = True)
                # r_z = r_z / r_z_norm
                r_z_stack = [
                    torch.cat([
                        r_z[:, :self.n_support],
                        r_z[:, self.n_support + i:self.n_support + i + 1]
                    ],
                              dim=1).view(1, -1, r_z.size(2))
                    for i in range(self.n_query)
                ]
                assert (r_z_stack[0].size(1) == self.n_way *
                        (self.n_support + 1))
                r_scores = self.forward_gnn(r_z_stack)
                scores = r_scores
                idx = self.call_count % self.num_ex

                return scores
            else:
                print("NOT IMPLEMENTED YET")

    def set_forward_unsup(self,
                          x,
                          x_u_mean,
                          x_u_std,
                          is_feature=False,
                          domain_shift=True):
        x = x.cuda()
        assert (domain_shift == True)
        if domain_shift == True:
            #print("DOMAIN SHIFT")
            if is_feature:
                assert (x.size(1) == self.n_support + 15)
                x = x.view(-1, *x.size()[2:])
                #x2 = x.view(self.n_way, -1, x.size(1))
                ### LOAD PROTOTYPES
                meta_prototype_mean = self.final_meta_prototype
                meta_prototype_std = self.final_meta_prototype_std
                #x_mean = torch.mean(x2[:,:self.n_support,:], axis = (0,1)).detach()
                #x_std = x2[:,:self.n_support,:].std(axis = (0,1)).detach()

                W_out_m = self.W_R(meta_prototype_mean, x_u_mean).cuda()
                V_out_m = self.V_R(torch.cat((meta_prototype_mean, x_u_mean)))
                NTN_out = W_out_m + V_out_m

                W_out_m_std = self.W_R_std(meta_prototype_std.cuda(),
                                           x_u_std.cuda()).cuda()
                V_out_m_std = self.V_R_std(
                    torch.cat(
                        (meta_prototype_std.cuda(), x_u_std.cuda())).cuda())
                NTN_out_std = W_out_m_std + V_out_m_std

                compare_input = self.tanh(torch.cat((NTN_out, NTN_out_std)))
                mult_ = F.relu(self.layer1(compare_input))
                mult_ = F.relu(self.layer2(mult_))
                mult_ = self.layer3(mult_)

                add_ = F.relu(self.layer1_add(compare_input))
                add_ = F.relu(self.layer2_add(add_))
                add_ = self.layer3_add(add_)

                recovered_x = torch.mul(x, mult_) + add_  ### use back normal x
                r_z = self.fc(recovered_x)

                r_z = r_z.view(self.n_way, -1, r_z.size(1))
                #r_z_mean = torch.mean(r_z[:,:self.n_support,:], axis = (0,1), keepdim = True)
                #print("AVG SHAPE")
                #r_z = r_z - r_z_mean
                #r_z_norm = torch.norm(r_z, dim = 2, keepdim = True)
                # r_z = r_z / r_z_norm
                r_z_stack = [
                    torch.cat([
                        r_z[:, :self.n_support],
                        r_z[:, self.n_support + i:self.n_support + i + 1]
                    ],
                              dim=1).view(1, -1, r_z.size(2))
                    for i in range(self.n_query)
                ]
                assert (r_z_stack[0].size(1) == self.n_way *
                        (self.n_support + 1))
                r_scores = self.forward_gnn(r_z_stack)
                scores = r_scores
                idx = self.call_count % self.num_ex

                return scores

        else:
            print("NOT IMPLEMENTED")

    def forward_gnn(self, zs):
        # gnn inp: n_q * n_way(n_s + 1) * f
        nodes = torch.cat(
            [torch.cat([z, self.support_label], dim=2) for z in zs], dim=0)
        scores = self.gnn(nodes)

        # n_q * n_way(n_s + 1) * n_way -> (n_way * n_q) * n_way
        scores = scores.view(self.n_query, self.n_way,
                             self.n_support + 1, self.n_way)[:, :, -1].permute(
                                 1, 0, 2).contiguous().view(-1, self.n_way)
        return scores

    def set_forward_loss(self, x):
        y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
        y_query = y_query.cuda()
        scores = self.set_forward(x)
        loss = self.loss_fn(scores, y_query)
        return loss

    def train_loop_full(self, epoch, train_loader, optimizer, final_epoch):
        print_freq = 10
        num_reset = 5
        avg_loss = 0
        start = 206

        if epoch == 208:
            print(self.final_all_feats)
            print(self.W_R)
            print(self.V_R)
            print(self.W_R_std)
            print(self.V_R_std)

        for i, (x, _) in enumerate(train_loader):
            self.n_query = x.size(1) - self.n_support
            if self.change_way:
                self.n_way = x.size(0)
            optimizer.zero_grad()
            x = self.feature(x.view(-1, *x.size()[2:]).cuda())
            loss = self.set_forward_loss(x)
            loss.backward()
            optimizer.step()
            avg_loss = avg_loss + loss.item()
            feats = x.detach()
            feats = feats.view(self.n_way, -1, feats.size(1))
            feats = feats[:, :self.n_support, :]
            #print(feats.shape)
            feats = feats.reshape(self.n_way * self.n_support, self.feat_dim)
            if i % print_freq == 0:
                #print(optimizer.state_dict()['param_groups'][0]['lr'])
                print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(
                    epoch, i, len(train_loader), avg_loss / float(i + 1)))
            if i == 0:
                all_feats = torch.zeros(len(train_loader), feats.shape[0],
                                        feats.shape[1])
            all_feats[i] = feats
        #if epoch % 5 == 0:
        self.final_all_feats[(epoch % 5)] = all_feats
        if epoch >= start:
            self = self.get_all_feat(
                self.final_all_feats.view(
                    5 * len(train_loader) * feats.shape[0], feats.shape[1]))
        if epoch == (final_epoch - 1):
            proto_numpy = self.final_meta_prototype.detach().cpu().numpy()
            proto_numpy_std = self.final_meta_prototype_std.detach().cpu(
            ).numpy()
            name1 = "proto_numpy_" + str(epoch) + ".npy"
            name2 = "proto_numpy_std_" + str(epoch) + ".npy"
            np.save(name1, proto_numpy)
            np.save(name2, proto_numpy_std)

    def set_forward_adaptation_full(
        self,
        x,
        is_feature=True
    ):  # further adaptation, default is fixing the feature and training a new softmax classifier
        assert is_feature == True, 'Feature is fixed in further adaptation'
        x = x.cuda()
        original_x = x.cuda()
        x = x.view(-1, *x.size()[2:])
        x2 = x.view(self.n_way, -1, x.size(1))
        ### LOAD PROTOTYPES
        meta_prototype_mean = self.final_meta_prototype
        meta_prototype_std = self.final_meta_prototype_std
        x_mean = torch.mean(x2[:, :self.n_support, :], axis=(0, 1)).detach()
        x_std = x2[:, :self.n_support, :].std(axis=(0, 1)).detach()

        W_out_m = self.W_R(meta_prototype_mean, x_mean).cuda()
        V_out_m = self.V_R(torch.cat((meta_prototype_mean, x_mean)))
        NTN_out = W_out_m + V_out_m

        W_out_m_std = self.W_R_std(meta_prototype_std.cuda(),
                                   x_std.cuda()).cuda()
        V_out_m_std = self.V_R_std(
            torch.cat((meta_prototype_std.cuda(), x_std.cuda())).cuda())
        NTN_out_std = W_out_m_std + V_out_m_std

        compare_input = self.tanh(torch.cat((NTN_out, NTN_out_std)))
        mult_ = F.relu(self.layer1(compare_input))
        mult_ = F.relu(self.layer2(mult_))
        mult_ = self.layer3(mult_)

        add_ = F.relu(self.layer1_add(compare_input))
        add_ = F.relu(self.layer2_add(add_))
        add_ = self.layer3_add(add_)

        recovered_x = torch.mul(original_x,
                                mult_) + add_  ### use back normal x
        z_support, z_query = self.parse_feature(recovered_x.detach(),
                                                is_feature)

        z_support = z_support.contiguous().view(self.n_way * self.n_support,
                                                -1)
        z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)

        y_support = torch.from_numpy(
            np.repeat(range(self.n_way), self.n_support))
        y_support = Variable(y_support.cuda())

        linear_clf = nn.Linear(self.feat_dim, self.n_way)
        linear_clf = linear_clf.cuda()

        set_optimizer = torch.optim.SGD(linear_clf.parameters(),
                                        lr=0.01,
                                        momentum=0.9,
                                        dampening=0.9,
                                        weight_decay=0.001)

        loss_function = nn.CrossEntropyLoss()
        loss_function = loss_function.cuda()

        batch_size = 4
        support_size = self.n_way * self.n_support
        for epoch in range(100):
            rand_id = np.random.permutation(support_size)
            for i in range(0, support_size, batch_size):
                set_optimizer.zero_grad()
                selected_id = torch.from_numpy(
                    rand_id[i:min(i + batch_size, support_size)]).cuda()
                z_batch = z_support[selected_id]
                y_batch = y_support[selected_id]
                scores = linear_clf(z_batch)
                loss = loss_function(scores, y_batch)
                loss.backward()
                set_optimizer.step()

        scores = linear_clf(z_query)
        return scores
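
A standalone sketch of the random feature corruption applied in the odd call_count branch of set_forward above: a random fraction of feature dimensions is zeroed out and a heavy-tailed Student-t perturbation is added to random entries. feat_dim is kept small here so it runs quickly; the sampling ranges mirror the code:

import math
import numpy as np
import torch

feat_dim = 64  # assumed small value for illustration
a, b = 0.5, 0.8
perc = (b - a) * np.random.random_sample() + a
perc_zeros = perc / 2

# diagonal mask that zeroes roughly perc_zeros of the feature dimensions
one_zeros = np.concatenate(
    (np.ones(feat_dim - math.floor(feat_dim * perc_zeros)),
     np.zeros(math.floor(feat_dim * perc_zeros))))
np.random.shuffle(one_zeros)
corruption = torch.from_numpy(np.diag(one_zeros)).float()

# heavy-tailed perturbation on randomly chosen (row, column) entries
random_idx = np.random.choice(feat_dim, math.floor(perc * feat_dim))
random_idx2 = np.random.choice(feat_dim, math.floor(perc * feat_dim))
t_sample = torch.from_numpy(
    np.random.standard_t(5, (feat_dim, feat_dim))).float()
corruption[random_idx, random_idx2] += t_sample[random_idx, random_idx2]

x = torch.randn(10, feat_dim)            # stand-in episode features
corrupt_x = torch.matmul(x, corruption)  # corrupted view of the episode
print(corrupt_x.shape)                   # torch.Size([10, 64])
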
Example #7
class GnnNet(MetaTemplate):
  maml=False
  def __init__(self, model_func,  n_way, n_support):
    super(GnnNet, self).__init__(model_func, n_way, n_support)

    # loss function
    self.loss_fn = nn.CrossEntropyLoss()
    self.first = True

    # metric function
    self.fc = nn.Sequential(nn.Linear(self.feat_dim, 128), nn.BatchNorm1d(128, track_running_stats=False)) if not self.maml else nn.Sequential(backbone.Linear_fw(self.feat_dim, 128), backbone.BatchNorm1d_fw(128, track_running_stats=False))
    self.gnn = GNN_nl(128 + self.n_way, 96, self.n_way)
    self.method = 'GnnNet'

    # fix label for training the metric function   1*nw(1 + ns)*nw
    support_label = torch.from_numpy(np.repeat(range(self.n_way), self.n_support)).unsqueeze(1)
    support_label = torch.zeros(self.n_way*self.n_support, self.n_way).scatter(1, support_label, 1).view(self.n_way, self.n_support, self.n_way)
    support_label = torch.cat([support_label, torch.zeros(self.n_way, 1, n_way)], dim=1)
    self.support_label = support_label.view(1, -1, self.n_way)

  def cuda(self):
    self.feature.cuda()
    self.fc.cuda()
    self.gnn.cuda()
    self.support_label = self.support_label.cuda()
    return self

  def set_forward(self,x,is_feature=False):
    x = x.cuda()

    if is_feature:
      # reshape the feature tensor: n_way * n_s + 15 * f
      assert(x.size(1) == self.n_support + 15)
      z = self.fc(x.view(-1, *x.size()[2:]))
      z = z.view(self.n_way, -1, z.size(1))
    else:
      # get feature using encoder
      x = x.view(-1, *x.size()[2:])
      z = self.fc(self.feature(x))
      z = z.view(self.n_way, -1, z.size(1))

    # stack the feature for metric function: n_way * n_s + n_q * f -> n_q * [1 * n_way(n_s + 1) * f]
    z_stack = [torch.cat([z[:, :self.n_support], z[:, self.n_support + i:self.n_support + i + 1]], dim=1).view(1, -1, z.size(2)) for i in range(self.n_query)]
    
    assert(z_stack[0].size(1) == self.n_way*(self.n_support + 1))
    scores = self.forward_gnn(z_stack)
    return scores

  def MAML_update(self):
    names = []
    for name, param in self.feature.named_parameters():
      if param.requires_grad:
        #print(name)
        names.append(name)
    
    names_sub = names[:-9]
    if not self.first:
      for (name, param), (name1, param1), (name2, param2) in zip(self.feature.named_parameters(), self.feature2.named_parameters(), self.feature3.named_parameters()):
        if name not in names_sub:
          dat_change = param2.data - param1.data ### Y - X
          new_dat = param.data - dat_change ### (Y- V) - (Y-X) = X-V
          param.data.copy_(new_dat)

  
  def set_forward_finetune(self,x,is_feature=False):
    x = x.cuda()

    
    # get feature using encoder
    batch_size = 4
    support_size = self.n_way * self.n_support 

    for name, param  in self.feature.named_parameters():
      param.requires_grad = True

    x_var = Variable(x)
      
    y_a_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_support ) )).cuda() # (25,)

    #print(y_a_i)
    self.MAML_update() ## call MAML update
    
    x_b_i = x_var[:, self.n_support:,:,:,:].contiguous().view( self.n_way* self.n_query,   *x.size()[2:]) 
    x_a_i = x_var[:,:self.n_support,:,:,:].contiguous().view( self.n_way* self.n_support, *x.size()[2:]) # (25, 3, 224, 224)
    feat_network = copy.deepcopy(self.feature)
    classifier = Classifier(self.feat_dim, self.n_way)
    delta_opt = torch.optim.Adam(filter(lambda p: p.requires_grad, feat_network.parameters()), lr = 0.01)
    loss_fn = nn.CrossEntropyLoss().cuda() ## change this code up ## drop n_way
    classifier_opt = torch.optim.Adam(classifier.parameters(), lr = 0.01, weight_decay=0.001) ##try it with weight_decay
    
    names = []
    for name, param in feat_network.named_parameters():
      if param.requires_grad:
        #print(name)
        names.append(name)
    
    names_sub = names[:-9] ### last Resnet block can adapt

    for name, param in feat_network.named_parameters():
      if name in names_sub:
        param.requires_grad = False    
  
      
    total_epoch = 15

    classifier.train()
    feat_network.train()

    classifier.cuda()
    feat_network.cuda()

    for epoch in range(total_epoch):
          rand_id = np.random.permutation(support_size)

          for j in range(0, support_size, batch_size):
              classifier_opt.zero_grad()
              
              delta_opt.zero_grad()

              #####################################
              selected_id = torch.from_numpy( rand_id[j: min(j+batch_size, support_size)]).cuda()
              
              z_batch = x_a_i[selected_id]
              y_batch = y_a_i[selected_id] 
              #####################################

              output = feat_network(z_batch)
              scores = classifier(output)
              loss = loss_fn(scores, y_batch)
              #grad = torch.autograd.grad(set_loss, fast_parameters, create_graph=True)

              #####################################
              loss.backward() ### think about how to compute gradients and achieve a good initialization

              classifier_opt.step()
              delta_opt.step()
    

    #feat_network.eval() ## fix this
    #classifier.eval()
    #self.train() ## continue training this!
    if self.first == True:
      self.first = False
    self.feature2 = copy.deepcopy(self.feature)
    self.feature3 = copy.deepcopy(feat_network) ## before the new state_dict is copied over
    self.feature.load_state_dict(feat_network.state_dict())
    
    for name, param  in self.feature.named_parameters():
        param.requires_grad = True
    
    output_support = self.feature(x_a_i.cuda()).view(self.n_way, self.n_support, -1)
    output_query = self.feature(x_b_i.cuda()).view(self.n_way,self.n_query,-1)

    final = torch.cat((output_support, output_query), dim =1).cuda()
    #print(x.size(1))
    #print(x.shape)
    assert(final.size(1) == self.n_support + 16) ##16 query samples in each batch
    z = self.fc(final.view(-1, *final.size()[2:]))
    z = z.view(self.n_way, -1, z.size(1))

    z_stack = [torch.cat([z[:, :self.n_support], z[:, self.n_support + i:self.n_support + i + 1]], dim=1).view(1, -1, z.size(2)) for i in range(self.n_query)]
    
    assert(z_stack[0].size(1) == self.n_way*(self.n_support + 1))
    
    scores = self.forward_gnn(z_stack)
    
    return scores

  def forward_gnn(self, zs):
    # gnn inp: n_q * n_way(n_s + 1) * f
    nodes = torch.cat([torch.cat([z, self.support_label], dim=2) for z in zs], dim=0)
    scores = self.gnn(nodes)

    # n_q * n_way(n_s + 1) * n_way -> (n_way * n_q) * n_way
    scores = scores.view(self.n_query, self.n_way, self.n_support + 1, self.n_way)[:, :, -1].permute(1, 0, 2).contiguous().view(-1, self.n_way)
    return scores

  def set_forward_loss(self, x):
    y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query))
    y_query = y_query.cuda()
    scores = self.set_forward(x)
    loss = self.loss_fn(scores, y_query)
    return loss

  def set_forward_loss_finetune(self, x):
    y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query))
    y_query = y_query.cuda()
    scores = self.set_forward_finetune(x)
    loss = self.loss_fn(scores, y_query)
    return loss
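
# A toy, standalone sketch of the parameter arithmetic in MAML_update above:
# the delta introduced by the fine-tuned copy (feature3) relative to the
# pre-finetuning snapshot (feature2) is subtracted from the current
# parameters. Tiny Linear modules stand in for the three feature networks.
import copy
import torch
import torch.nn as nn

feature = nn.Linear(4, 4)          # stand-in for self.feature
feature2 = copy.deepcopy(feature)  # snapshot before episode fine-tuning (X)
feature3 = copy.deepcopy(feature)  # stand-in for the fine-tuned copy (Y)
with torch.no_grad():
    for p in feature3.parameters():  # pretend fine-tuning changed the copy
        p.add_(0.1)

for (name, param), (_, param1), (_, param2) in zip(
        feature.named_parameters(), feature2.named_parameters(),
        feature3.named_parameters()):
    dat_change = param2.data - param1.data       # Y - X (fine-tuning delta)
    param.data.copy_(param.data - dat_change)    # undo the episode-specific delta
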
class DampNet(MetaTemplate):
    maml = False

    def __init__(self, model_func, n_way, n_support):
        super(DampNet, self).__init__(model_func, n_way, n_support)

        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function
        self.gnn_dim = 128
        self.fc = nn.Sequential(nn.Linear(self.feat_dim, self.gnn_dim),
                                nn.BatchNorm1d(128, track_running_stats=False)
                                ) if not self.maml else nn.Sequential(
                                    backbone.Linear_fw(self.feat_dim, 128),
                                    backbone.BatchNorm1d_fw(
                                        128, track_running_stats=False))
        self.gnn = GNN_nl(self.gnn_dim + self.n_way, 96, self.n_way)
        self.method = 'DampNet'

        self.num_ex = 20  ##making change to 50?
        self.meta_store_mean = torch.zeros((self.num_ex, self.feat_dim))
        self.meta_store_std = torch.zeros(
            (self.num_ex, self.n_support * self.n_way, self.feat_dim))
        #self.corruption = torch.from_numpy(np.diag(np.ones(self.feat_dim)))

        ### comparison / recovery network

        self.W_R = nn.Bilinear(self.feat_dim, self.feat_dim, 500,
                               bias=False).cuda()
        self.V_R = nn.Linear(self.feat_dim * 2, 500).cuda()

        self.W_R_std = nn.Bilinear(self.feat_dim,
                                   self.feat_dim,
                                   500,
                                   bias=False).cuda()
        self.V_R_std = nn.Linear(self.feat_dim * 2, 500).cuda()

        ## MLP
        self.tanh = nn.Tanh()
        self.layer1 = nn.Linear(500 * 2, 900)
        self.layer2 = nn.Linear(900, 800)
        self.layer3 = nn.Linear(800, self.feat_dim)
        self.layer1_add = nn.Linear(500 * 2, 900)
        self.layer2_add = nn.Linear(900, 800)
        self.layer3_add = nn.Linear(800, self.feat_dim)

        self.final_meta_prototype = torch.zeros(self.feat_dim)
        self.final_meta_prototype_std = torch.zeros(self.feat_dim)
        self.final_meta_prototypes_initialized = False

        #self.meta_prototype_mean = torch.mean((1, self.feat_dim))
        #self.meta_prototype_std = torch.mean((1, self.feat_dim))
        self.call_count = 150  ##if restart
        self.first = True

        # fix label for training the metric function   1*nw(1 + ns)*nw
        support_label = torch.from_numpy(
            np.repeat(range(self.n_way), self.n_support)).unsqueeze(1)
        support_label = torch.zeros(
            self.n_way * self.n_support,
            self.n_way).scatter(1, support_label,
                                1).view(self.n_way, self.n_support, self.n_way)
        support_label = torch.cat(
            [support_label, torch.zeros(self.n_way, 1, n_way)], dim=1)
        self.support_label = support_label.view(1, -1, self.n_way)

        self.cuda()

    def cuda(self):
        self.feature.cuda()
        self.fc.cuda()
        self.gnn.cuda()
        self.meta_store_mean.cuda()
        self.meta_store_std.cuda()
        self.W_R.cuda()
        self.V_R.cuda()
        self.W_R_std.cuda()
        self.V_R_std.cuda()
        self.tanh.cuda()
        self.layer1.cuda()
        self.layer2.cuda()
        self.layer3.cuda()
        self.layer1_add.cuda()
        self.layer2_add.cuda()
        self.layer3_add.cuda()
        self.final_meta_prototype.cuda()
        self.final_meta_prototype_std.cuda()
        self.support_label = self.support_label.cuda()
        return self

    def get_all_feat(self, all_feat):
        all_feat = all_feat.cuda().detach()
        self.final_meta_prototype = torch.mean(all_feat,
                                               axis=0).cuda().detach()
        self.final_meta_prototype_std = all_feat.std(axis=0).cuda().detach()
        self.final_meta_prototypes_initialized = True
        return self

    def set_forward(self, x, is_feature=False, domain_shift=False):
        x = x.cuda()
        if domain_shift == False:
            if is_feature:
                # reshape the feature tensor: n_way * n_s + 15 * f
                assert (x.size(1) == self.n_support + 15)
                z = self.fc(x.view(-1, *x.size()[2:]))
                z = z.view(self.n_way, -1, z.size(1))
            else:
                # get feature using encoder
                x = self.feature(x.view(-1, *x.size()[2:]))
                x2 = x.view(self.n_way, -1, x.size(1))
                x_mean = torch.mean(x2[:, :self.n_support, :],
                                    axis=(0, 1)).detach()

                #print(self.call_count)
                if self.call_count == 151 or not self.training:
                    self.first = False
                    print(self.W_R)
                    print(self.V_R)
                    print(self.W_R_std)
                    print(self.V_R_std)

                ## standard dev shape

            # stack the feature for metric function: n_way * n_s + n_q * f -> n_q * [1 * n_way(n_s + 1) * f]
            if self.first == True:
                z = self.fc(x)
                z = z.view(self.n_way, -1, z.size(1))
                z_mean = torch.mean(z[:, :self.n_support, :],
                                    axis=(0, 1),
                                    keepdim=True)
                #print("AVG SHAPE")
                z = z - z_mean
                z_norm = torch.norm(z, dim=2, keepdim=True)
                z = z / z_norm
                z_stack = [
                    torch.cat([
                        z[:, :self.n_support],
                        z[:, self.n_support + i:self.n_support + i + 1]
                    ],
                              dim=1).view(1, -1, z.size(2))
                    for i in range(self.n_query)
                ]
                assert (z_stack[0].size(1) == self.n_way *
                        (self.n_support + 1))
                scores = self.forward_gnn(z_stack)
                idx = self.call_count % self.num_ex
                self.meta_store_mean[idx] = x_mean
                self.meta_store_std[idx] = x2[:, :self.n_support, :].detach(
                ).reshape(-1, self.feat_dim)
                self.call_count += 1
                return scores
            elif self.call_count % 2 != 0:
                ### corruption vector
                a = 0.5
                b = 0.7
                perc = 0.6  #(b - a) * np.random.random_sample() + a
                perc_zeros = perc / 2
                a2 = 1
                b2 = 3
                m_fac = 1.5  #(b2 - a2) * np.random.random_sample() + a2
                meta_prototype_mean = torch.mean(self.meta_store_mean,
                                                 axis=0).cuda().detach()
                meta_prototype_std = self.meta_store_std.std(
                    axis=(0, 1)).detach()
                one_zeros = np.concatenate(
                    (np.ones(self.feat_dim -
                             math.floor(self.feat_dim * perc_zeros)),
                     np.zeros(math.floor(self.feat_dim * perc_zeros))))
                np.random.shuffle(one_zeros)
                corruption = torch.from_numpy(
                    np.diag(one_zeros)).cuda().float()
                corruption_bias = torch.from_numpy(np.zeros(
                    self.feat_dim)).cuda().float()
                temp = np.asarray(list(range(0, self.feat_dim)))
                random_idx = np.random.choice(temp,
                                              math.floor(perc * self.feat_dim))
                random_idx2 = np.random.choice(
                    temp, math.floor(perc * self.feat_dim))
                rand_idx_col = np.random.choice(random_idx2, 1)
                ad_sub = np.concatenate(
                    (np.ones(self.feat_dim - math.floor(self.feat_dim * 0.5)),
                     -np.ones(math.floor(self.feat_dim * 0.5))))
                np.random.shuffle(ad_sub)
                t_sample = m_fac * np.reshape(
                    np.random.standard_t(5, self.feat_dim * self.feat_dim),
                    (self.feat_dim, self.feat_dim))
                t_sample_bias = np.random.standard_t(5, self.feat_dim) + ad_sub
                t_sample_bias = torch.from_numpy(
                    -np.squeeze(t_sample[:, rand_idx_col]) +
                    t_sample_bias).cuda().float()
                t_sample = torch.from_numpy(t_sample).cuda().float()
                corruption[random_idx, random_idx2] += t_sample[random_idx,
                                                                random_idx2]
                corruption_bias[random_idx2] += t_sample_bias[random_idx2]
                corrupt_x = torch.matmul(
                    x, corruption).detach().cuda()  ## new input
                corrupt_x += corruption_bias
                corrupt_x2 = corrupt_x.view(self.n_way, -1, x.size(1))
                corrupt_x_mean = torch.mean(corrupt_x2[:, :self.n_support, :],
                                            axis=(0, 1)).detach()
                corrupt_x_std = corrupt_x2[:, :self.n_support, :].std(
                    axis=(0, 1)).detach()
                if self.call_count == 154:
                    print(corruption)
                    print(corruption[random_idx, random_idx2])
                    print(corruption_bias)
                    print(corruption_bias[random_idx2])
                W_out_m = self.W_R(meta_prototype_mean, corrupt_x_mean).cuda()
                V_out_m = self.V_R(
                    torch.cat((meta_prototype_mean, corrupt_x_mean)))
                NTN_out = W_out_m + V_out_m

                W_out_m_std = self.W_R_std(meta_prototype_std.cuda(),
                                           corrupt_x_std.cuda()).cuda()
                V_out_m_std = self.V_R_std(
                    torch.cat((meta_prototype_std.cuda(),
                               corrupt_x_std.cuda())).cuda())
                NTN_out_std = W_out_m_std + V_out_m_std

                compare_input = self.tanh(torch.cat((NTN_out, NTN_out_std)))
                mult_ = F.relu(self.layer1(compare_input))
                mult_ = F.relu(self.layer2(mult_))
                mult_ = self.layer3(mult_)

                add_ = F.relu(self.layer1_add(compare_input))
                add_ = F.relu(self.layer2_add(add_))
                add_ = self.layer3_add(add_)  ## sparse add

                recovered_x = torch.mul(corrupt_x, mult_) + add_
                r_z = self.fc(recovered_x)

                r_z = r_z.view(self.n_way, -1, r_z.size(1))
                r_z_mean = torch.mean(r_z[:, :self.n_support, :],
                                      axis=(0, 1),
                                      keepdim=True)
                r_z = r_z - r_z_mean
                r_z_norm = torch.norm(r_z, dim=2, keepdim=True)
                r_z = r_z / r_z_norm
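                # One graph per query: all support embeddings plus that single query,
                # i.e. n_way * (n_support + 1) nodes.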
                r_z_stack = [
                    torch.cat([
                        r_z[:, :self.n_support],
                        r_z[:, self.n_support + i:self.n_support + i + 1]
                    ],
                              dim=1).view(1, -1, r_z.size(2))
                    for i in range(self.n_query)
                ]
                assert (r_z_stack[0].size(1) == self.n_way *
                        (self.n_support + 1))
                r_scores = self.forward_gnn(r_z_stack)
                scores = r_scores
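                # Update the circular buffer of episode statistics used to form the
                # meta-prototype (indexed by call_count modulo num_ex).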
                idx = self.call_count % self.num_ex
                self.meta_store_mean[idx] = x_mean
                self.meta_store_std[idx] = x2[:, :self.n_support, :].detach().reshape(
                    -1, self.feat_dim)
                self.call_count += 1
                return scores
            elif self.call_count % 2 == 0:
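                # Even call counts: compare the stored meta-prototype against the current
                # episode statistics and run the recovery network on the uncorrupted
                # features x.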
                meta_prototype_mean = torch.mean(self.meta_store_mean,
                                                 axis=0).cuda().detach()
                meta_prototype_std = self.meta_store_std.std(
                    axis=(0, 1)).detach()
                x_std = x2[:, :self.n_support, :].std(axis=(0, 1)).detach()

                W_out_m = self.W_R(meta_prototype_mean, x_mean.detach()).cuda()
                V_out_m = self.V_R(
                    torch.cat((meta_prototype_mean, x_mean.detach())))
                NTN_out = W_out_m + V_out_m

                W_out_m_std = self.W_R_std(meta_prototype_std.cuda(),
                                           x_std.cuda()).cuda()
                V_out_m_std = self.V_R_std(
                    torch.cat(
                        (meta_prototype_std.cuda(), x_std.cuda())).cuda())
                NTN_out_std = W_out_m_std + V_out_m_std

                compare_input = self.tanh(torch.cat((NTN_out, NTN_out_std)))
                mult_ = F.relu(self.layer1(compare_input))
                mult_ = F.relu(self.layer2(mult_))
                mult_ = self.layer3(mult_)

                add_ = F.relu(self.layer1_add(compare_input))
                add_ = F.relu(self.layer2_add(add_))
                add_ = self.layer3_add(add_)

                recovered_x = torch.mul(x, mult_) + add_  # apply corrections to the original (uncorrupted) features
                r_z = self.fc(recovered_x)

                r_z = r_z.view(self.n_way, -1, r_z.size(1))
                r_z_mean = torch.mean(r_z[:, :self.n_support, :],
                                      axis=(0, 1),
                                      keepdim=True)
                r_z = r_z - r_z_mean
                r_z_norm = torch.norm(r_z, dim=2, keepdim=True)
                r_z = r_z / r_z_norm
                r_z_stack = [
                    torch.cat([
                        r_z[:, :self.n_support],
                        r_z[:, self.n_support + i:self.n_support + i + 1]
                    ],
                              dim=1).view(1, -1, r_z.size(2))
                    for i in range(self.n_query)
                ]
                assert (r_z_stack[0].size(1) == self.n_way *
                        (self.n_support + 1))
                r_scores = self.forward_gnn(r_z_stack)
                scores = r_scores
                idx = self.call_count % self.num_ex
                self.meta_store_mean[idx] = x_mean
                self.meta_store_std[idx] = x2[:, :self.n_support, :].detach().reshape(
                    -1, self.feat_dim)
                self.call_count += 1
                return scores
        elif domain_shift:
            if is_feature:
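                # Domain-shift path on pre-extracted features: reuse the meta-prototypes
                # frozen during training (final_meta_prototype / final_meta_prototype_std)
                # and apply the same recovery network.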
                assert (x.size(1) == self.n_support + 15)  # expects 15 query samples per class
                x = x.view(-1, *x.size()[2:])
                x2 = x.view(self.n_way, -1, x.size(1))
                ### LOAD PROTOTYPES
                meta_prototype_mean = self.final_meta_prototype
                meta_prototype_std = self.final_meta_prototype_std
                x_mean = torch.mean(x2[:, :self.n_support, :],
                                    axis=(0, 1)).detach()
                x_std = x2[:, :self.n_support, :].std(axis=(0, 1)).detach()

                W_out_m = self.W_R(meta_prototype_mean, x_mean).cuda()
                V_out_m = self.V_R(torch.cat((meta_prototype_mean, x_mean)))
                NTN_out = W_out_m + V_out_m

                W_out_m_std = self.W_R_std(meta_prototype_std.cuda(),
                                           x_std.cuda()).cuda()
                V_out_m_std = self.V_R_std(
                    torch.cat(
                        (meta_prototype_std.cuda(), x_std.cuda())).cuda())
                NTN_out_std = W_out_m_std + V_out_m_std

                compare_input = self.tanh(torch.cat((NTN_out, NTN_out_std)))
                mult_ = F.relu(self.layer1(compare_input))
                mult_ = F.relu(self.layer2(mult_))
                mult_ = self.layer3(mult_)

                add_ = F.relu(self.layer1_add(compare_input))
                add_ = F.relu(self.layer2_add(add_))
                add_ = self.layer3_add(add_)

                recovered_x = torch.mul(x, mult_) + add_  # apply corrections to the original features
                r_z = self.fc(recovered_x)

                r_z = r_z.view(self.n_way, -1, r_z.size(1))
                r_z_mean = torch.mean(r_z[:, :self.n_support, :],
                                      axis=(0, 1),
                                      keepdim=True)
                r_z = r_z - r_z_mean
                r_z_norm = torch.norm(r_z, dim=2, keepdim=True)
                r_z = r_z / r_z_norm
                r_z_stack = [
                    torch.cat([
                        r_z[:, :self.n_support],
                        r_z[:, self.n_support + i:self.n_support + i + 1]
                    ],
                              dim=1).view(1, -1, r_z.size(2))
                    for i in range(self.n_query)
                ]
                assert (r_z_stack[0].size(1) == self.n_way *
                        (self.n_support + 1))
                r_scores = self.forward_gnn(r_z_stack)
                scores = r_scores
                return scores
            else:
                raise NotImplementedError(
                    'raw-image inputs are not supported in the domain-shift branch yet')

    def forward_gnn(self, zs):
        # gnn inp: n_q * n_way(n_s + 1) * f
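        # Concatenate the fixed per-node label encoding (self.support_label) to every
        # embedding before running the GNN.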
        nodes = torch.cat(
            [torch.cat([z, self.support_label], dim=2) for z in zs], dim=0)
        scores = self.gnn(nodes)

        # n_q * n_way(n_s + 1) * n_way -> (n_way * n_q) * n_way
        scores = scores.view(self.n_query, self.n_way,
                             self.n_support + 1, self.n_way)[:, :, -1].permute(
                                 1, 0, 2).contiguous().view(-1, self.n_way)
        return scores

    def set_forward_loss(self, x):
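        # Query labels follow the fixed episode layout: class index i repeated n_query times.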
        y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
        y_query = y_query.cuda()
        scores = self.set_forward(x)
        loss = self.loss_fn(scores, y_query)
        return loss
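

# Minimal training-loop sketch, assuming an episodic loader that yields x of shape
# [n_way, n_support + n_query, C, H, W]; `episodic_loader` and `optimizer` are
# illustrative names, not part of the model code above.
def train_epoch(model, episodic_loader, optimizer):
    model.train()
    for x, _ in episodic_loader:
        # the episode layout fixes the labels, so only the images are needed
        model.n_query = x.size(1) - model.n_support
        optimizer.zero_grad()
        loss = model.set_forward_loss(x.cuda())
        loss.backward()
        optimizer.step()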