Exemplo n.º 1
0
    def __init__(self, model_func, n_way, n_support):
        """Set up the loss, the 128-d projection head, the GNN and the labels.

        Args:
            model_func: forwarded unchanged to the parent initializer.
            n_way: forwarded to the parent; also used as the last dimension
                of the zero row appended to the support labels.
            n_support: forwarded to the parent. After the parent has run,
                ``self.n_support`` is replaced by ``round(self.n_support / 2)``.
        """
        super(GnnNet, self).__init__(model_func, n_way, n_support)

        # loss function
        self.loss_fn = nn.CrossEntropyLoss()
        self.first = True

        # metric function: map backbone features to a 128-d embedding.
        # The ``*_fw`` layer variants from ``backbone`` are used when
        # ``self.maml`` is truthy, the stock torch layers otherwise.
        embed_dim = 128
        if self.maml:
            projection = nn.Sequential(
                backbone.Linear_fw(self.feat_dim, embed_dim),
                backbone.BatchNorm1d_fw(embed_dim, track_running_stats=False))
        else:
            projection = nn.Sequential(
                nn.Linear(self.feat_dim, embed_dim),
                nn.BatchNorm1d(embed_dim, track_running_stats=False))
        self.fc = projection
        self.gnn = GNN_nl(embed_dim + self.n_way, 96, self.n_way)
        self.method = 'GnnNet'
        # Halve the support count (original note: "average across 2").
        self.n_support = round(self.n_support / 2)

        # Fixed label tensor for training the metric function. It is built
        # from the already-halved ``self.n_support``; the final shape is
        # 1 x n_way * (n_support + 1) x n_way.
        ways, shots = self.n_way, self.n_support
        class_ids = torch.from_numpy(np.repeat(range(ways), shots))
        class_ids = class_ids.unsqueeze(1)
        one_hot = torch.zeros(ways * shots, ways).scatter(1, class_ids, 1)
        one_hot = one_hot.view(ways, shots, ways)
        # One extra all-zero row per class is appended after the support rows.
        zero_row = torch.zeros(ways, 1, n_way)
        padded = torch.cat([one_hot, zero_row], dim=1)
        self.support_label = padded.view(1, -1, ways)
Exemplo n.º 2
0
    def __init__(self, model_func, n_way, n_support, approx=False,
                 jigsaw=False, lbda=0.0, rotation=False, tracking=False,
                 use_bn=True, pretrain=False):
        """Create the classifier head and the optional auxiliary heads.

        Args:
            model_func, n_way, n_support, use_bn, pretrain: forwarded to the
                parent initializer (together with ``change_way=False``).
            approx: stored as ``self.approx`` (first order approximation).
            jigsaw: when truthy, ``fc6``, ``fc7`` and ``classifier_jigsaw``
                (35 outputs) are created.
            lbda: stored as ``self.lbda``.
            rotation: when truthy, ``fc6``, ``fc7`` and
                ``classifier_rotation`` (4 outputs) are created. This runs
                after the jigsaw branch, so with both flags set the rotation
                versions of ``fc6``/``fc7`` replace the jigsaw ones.
            tracking: accepted but not used inside this initializer.
        """
        super(MAML, self).__init__(model_func, n_way, n_support, use_bn,
                                   pretrain, change_way=False)

        self.loss_fn = nn.CrossEntropyLoss()
        head = backbone.Linear_fw(self.feat_dim, n_way)
        head.bias.data.fill_(0)
        self.classifier = head

        self.n_task = 4
        self.task_update_num = 5
        self.train_lr = 0.01
        self.approx = approx  # first order approx.

        self.global_count = 0
        self.jigsaw = jigsaw
        self.rotation = rotation
        self.lbda = lbda

        def zero_bias_stack(*named_layers):
            # Chain the (name, module) pairs in order and zero the bias of
            # the first module in the chain.
            stack = nn.Sequential()
            for layer_name, layer in named_layers:
                stack.add_module(layer_name, layer)
            stack[0].bias.data.fill_(0)
            return stack

        if self.jigsaw:
            self.fc6 = zero_bias_stack(
                ('fc6_s1', backbone.Linear_fw(512, 512)),  # for resnet
                ('relu6_s1', nn.ReLU(inplace=True)),
                ('drop6_s1', nn.Dropout(p=0.5)))
            self.fc7 = zero_bias_stack(
                ('fc7', backbone.Linear_fw(9 * 512, 4096)),  # for resnet
                ('relu7', nn.ReLU(inplace=True)),
                ('drop7', nn.Dropout(p=0.5)))
            self.classifier_jigsaw = zero_bias_stack(
                ('fc8', backbone.Linear_fw(4096, 35)))

        if self.rotation:
            self.fc6 = zero_bias_stack(
                ('fc6_s1', backbone.Linear_fw(512, 512)),  # for resnet
                ('relu6_s1', nn.ReLU(inplace=True)),
                ('drop6_s1', nn.Dropout(p=0.5)))
            self.fc7 = zero_bias_stack(
                ('fc7', backbone.Linear_fw(512, 128)),  # for resnet
                ('relu7', nn.ReLU(inplace=True)),
                ('drop7', nn.Dropout(p=0.5)))
            self.classifier_rotation = zero_bias_stack(
                ('fc8', backbone.Linear_fw(128, 4)))
Exemplo n.º 3
0
 def __init__(self, num_classes, im_size):
     """Pick the ResNet10 encoder matching ``im_size`` and add a linear head.

     Args:
         num_classes: output width of the linear classifier.
         im_size: selects the encoder; only 32 and 84 are handled.

     Raises:
         ValueError: if ``im_size`` is neither 32 nor 84.
     """
     super(ResNetClassifier, self).__init__()
     # The literal 3 is passed straight to the backbone constructor
     # (presumably the number of input channels -- TODO confirm).
     if im_size == 32:
         self.encoder = backbone.ResNet10_32(3)
     elif im_size == 84:
         self.encoder = backbone.ResNet10_84(3)
     else:
         # Name the offending value and the accepted ones so that a
         # misconfiguration can be diagnosed from the traceback alone.
         raise ValueError(
             'Unsupported im_size {!r}; expected 32 or 84'.format(im_size))
     # 512 is the classifier input width (presumably the encoder's output
     # feature size -- verify against backbone).
     self.classifier = backbone.Linear_fw(512, num_classes)
Exemplo n.º 4
0
    def __init__(self, feat_dim, n_way, update_step, approx=True, lr=0.01):
        """Build a two-layer feature block followed by a linear classifier.

        Args:
            feat_dim: input width, stored as ``self.in_dim``.
            n_way: output width of the classifier.
            update_step: stored as ``self.update_step``.
            approx: stored as ``self.approx``.
            lr: stored as ``self.lr``.
        """
        super(MAMLBlock, self).__init__()

        # Layer widths: the hidden layer is half the input (truncated to an
        # int) and the output is twice the hidden width.
        half_width = int(feat_dim / 2)
        self.in_dim = feat_dim
        self.hidden_dim = half_width
        self.out_dim = half_width * 2

        # Settings kept verbatim from the arguments.
        self.n_way = n_way
        self.update_step = update_step
        self.approx = approx
        self.lr = lr

        self.loss_fn = nn.CrossEntropyLoss()

        # in_dim -> hidden_dim -> ReLU -> out_dim
        layers = [
            backbone.Linear_fw(self.in_dim, self.hidden_dim),
            nn.ReLU(),
            backbone.Linear_fw(self.hidden_dim, self.out_dim),
        ]
        self.feature = nn.Sequential(*layers)

        # Classifier head with its bias reset to zero.
        head = backbone.Linear_fw(self.out_dim, self.n_way)
        head.bias.data.fill_(0)
        self.classifier = head
Exemplo n.º 5
0
    def __init__(self, model_func, lr=0.001, dropp=0.1, num_train_class=64):
        """Create the feature extractor, loss, dropout and classifier head.

        Args:
            model_func: zero-argument callable; its result is stored as
                ``self.feature`` and must expose ``final_feat_dim``.
            lr: stored as ``self.train_lr``.
            dropp: drop probability handed to ``nn.Dropout``.
            num_train_class: output width of the classifier.
        """
        # NOTE(review): no base-class initializer is called here. If the
        # enclosing class derives from nn.Module, assigning modules before
        # Module.__init__() has run would fail -- confirm against the class.
        encoder = model_func()
        self.feature = encoder
        self.loss_fn = nn.CrossEntropyLoss().cuda()
        self.dropout = nn.Dropout(p=dropp)

        # Linear head on top of the encoder features, bias reset to zero.
        head = backbone.Linear_fw(encoder.final_feat_dim, num_train_class)
        head.bias.data.fill_(0)
        self.classifier = head

        self.train_lr = lr
Exemplo n.º 6
0
    def __init__(self, model_func, n_way, n_support, approx=False):
        """Create the loss, the classifier head and the fixed settings.

        Args:
            model_func, n_way, n_support: forwarded to the parent
                initializer together with ``change_way=False``.
            approx: stored as ``self.approx`` (first order approximation).
        """
        super(MAML, self).__init__(model_func, n_way, n_support,
                                   change_way=False)

        self.loss_fn = nn.CrossEntropyLoss()

        # Linear head over the parent's ``feat_dim``, bias reset to zero.
        head = backbone.Linear_fw(self.feat_dim, n_way)
        head.bias.data.fill_(0)
        self.classifier = head

        # Hard-coded settings.
        self.n_task = 4
        self.task_update_num = 5
        self.train_lr = 0.01
        self.approx = approx  # first order approx.
Exemplo n.º 7
0
    def __init__(self, **kwargs):
        """Configure the model from keyword arguments.

        The keys ``n_way``, ``n_shot``, ``task_update_num`` and
        ``use_support_stats`` are removed from ``kwargs`` and stored on the
        instance; everything that remains (including ``x_dim``) is passed to
        the parent initializer.

        Args:
            **kwargs: must contain the four keys above plus ``x_dim``, a
                sequence of ints giving the per-sample input shape.

        Raises:
            KeyError: if one of the required keys is missing.
        """
        self.n_way = kwargs.pop('n_way')
        self.n_support = kwargs.pop('n_shot')
        self.task_update_num = kwargs.pop('task_update_num')
        self.use_support_stats = kwargs.pop('use_support_stats')
        super(MAML, self).__init__(**kwargs)
        self.feature = self.encoder

        # Infer the flattened feature width by pushing an all-zero batch of
        # two samples through the encoder. ``list(...)`` lets ``x_dim`` be a
        # tuple or torch.Size as well as a list, and ``no_grad`` avoids
        # building an autograd graph for this throw-away pass.
        # (Batch size 2 rather than 1 is presumably for batch-norm layers
        # in training mode -- TODO confirm.)
        probe = torch.zeros([2] + list(kwargs['x_dim']))
        with torch.no_grad():
            self.feat_dim = self.encoder(probe).view(2, -1).shape[-1]

        # Unreduced loss: one value per sample.
        self.loss_fn = nn.CrossEntropyLoss(reduction='none')
        self.classifier = backbone.Linear_fw(self.feat_dim, self.n_way)
        self.classifier.bias.data.fill_(0)

        self.train_lr = 0.01
        self.approx = False
Exemplo n.º 8
0
    def __init__(self, model_func, n_way, n_support):
        """Build the metric head, the comparison/recovery nets and the state.

        Args:
            model_func: forwarded unchanged to the parent initializer.
            n_way: forwarded to the parent; also used as the last dimension
                of the zero row appended to the support labels.
            n_support: forwarded to the parent.

        The last statement moves the instance to CUDA via ``self.cuda()``.
        """
        super(DampNet, self).__init__(model_func, n_way, n_support)

        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function: project backbone features to ``gnn_dim``.
        # The ``*_fw`` layer variants from ``backbone`` are used when
        # ``self.maml`` is truthy, the stock torch layers otherwise.
        gnn_dim = 128
        self.gnn_dim = gnn_dim
        if self.maml:
            projection = nn.Sequential(
                backbone.Linear_fw(self.feat_dim, gnn_dim),
                backbone.BatchNorm1d_fw(gnn_dim, track_running_stats=False))
        else:
            projection = nn.Sequential(
                nn.Linear(self.feat_dim, gnn_dim),
                nn.BatchNorm1d(gnn_dim, track_running_stats=False))
        self.fc = projection
        self.gnn = GNN_nl(gnn_dim + self.n_way, 96, self.n_way)
        self.method = 'DampNet'

        self.num_ex = 20  # original note: "making change to 50?"

        # Comparison / recovery network: a bias-free bilinear layer plus a
        # linear layer over doubled features, once for the mean path and
        # once for the std path. These four are moved to CUDA right away.
        feat_dim = self.feat_dim
        rel_dim = 300
        self.W_R = nn.Bilinear(feat_dim, feat_dim, rel_dim, bias=False).cuda()
        self.V_R = nn.Linear(feat_dim * 2, rel_dim).cuda()
        self.W_R_std = nn.Bilinear(feat_dim, feat_dim, rel_dim,
                                   bias=False).cuda()
        self.V_R_std = nn.Linear(feat_dim * 2, rel_dim).cuda()

        # Two parallel three-layer MLPs mapping 2 * rel_dim -> feat_dim.
        mlp_dim = 500
        self.tanh = nn.Tanh()
        self.layer1 = nn.Linear(rel_dim * 2, mlp_dim)
        self.layer2 = nn.Linear(mlp_dim, mlp_dim)
        self.layer3 = nn.Linear(mlp_dim, feat_dim)
        self.layer1_add = nn.Linear(rel_dim * 2, mlp_dim)
        self.layer2_add = nn.Linear(mlp_dim, mlp_dim)
        self.layer3_add = nn.Linear(mlp_dim, feat_dim)

        # Zero-initialised state tensors.
        # NOTE(review): these are plain attributes, not registered buffers;
        # if the class is an nn.Module, the final ``self.cuda()`` would not
        # move them -- confirm this is intended.
        self.final_meta_prototype = torch.zeros(feat_dim)
        self.final_meta_prototype_std = torch.zeros(feat_dim)
        self.final_meta_prototypes_initialized = False
        # Original note: replace the first and second dims (5, 100) with
        # the desired sizes.
        self.final_all_feats = torch.zeros(
            5, 100, self.n_way * self.n_support, feat_dim)

        self.call_count = 150  # original note: "if restart"
        self.first = True

        # Fixed label tensor for training the metric function; the final
        # shape is 1 x n_way * (n_support + 1) x n_way.
        ways, shots = self.n_way, self.n_support
        class_ids = torch.from_numpy(np.repeat(range(ways), shots))
        class_ids = class_ids.unsqueeze(1)
        one_hot = torch.zeros(ways * shots, ways).scatter(1, class_ids, 1)
        one_hot = one_hot.view(ways, shots, ways)
        # One extra all-zero row per class is appended after the support rows.
        zero_row = torch.zeros(ways, 1, n_way)
        padded = torch.cat([one_hot, zero_row], dim=1)
        self.support_label = padded.view(1, -1, ways)

        self.cuda()