コード例 #1
0
    def __init__(self, input_size, hidden_size, loss_type='mse'):
        """Build the relation head: two conv blocks followed by two linear layers.

        Args:
            input_size: indexable with at least three ints. Index 0 is used as
                the channel count, indices 1 and 2 as spatial extents
                (presumably (C, H, W) -- confirm against callers).
            hidden_size: width of the hidden fully connected layer.
            loss_type: stored as ``self.loss_type``; not used in this method.
        """
        super().__init__()
        self.loss_type = loss_type

        channels, height, width = input_size[0], input_size[1], input_size[2]

        # When using Resnet, the conv map without avgpooling is 7x7, so the
        # blocks need padding to still be able to pool.
        if height < 10 and width < 10:
            padding = 1
        else:
            padding = 0

        # First block consumes the channel-wise concatenation of two feature maps.
        self.layer1 = RelationConvBlock(channels * 2, channels, padding=padding)
        self.layer2 = RelationConvBlock(channels, channels, padding=padding)

        def _shrink(size):
            # Spatial size after both blocks: each step applies
            # (s - 2 + 2 * padding) / 2, truncated toward zero by int().
            once = int((size - 2 + 2 * padding) / 2)
            return int((once - 2 + 2 * padding) / 2)

        flat_dim = channels * _shrink(height) * _shrink(width)

        # Fast-weight layers are used when the class-level ``maml`` flag is set.
        linear_cls = backbone.Linear_fw if self.maml else nn.Linear
        self.fc1 = linear_cls(flat_dim, hidden_size)
        self.fc2 = linear_cls(hidden_size, 1)
コード例 #2
0
    def __init__(self,
                 model_func,
                 n_way,
                 n_support,
                 tf_path=None,
                 approx=False,
                 dropout_method='none',
                 dropout_rate=0.,
                 dropout_schedule='constant'):
        """Set up the MAML learner: fast-weight head, DropGrad and Adam optimizer.

        Args:
            model_func, n_way, n_support, tf_path: forwarded to the parent
                constructor (with ``change_way=False``).
            approx: stored as ``self.approx``; flags the first-order approximation.
            dropout_method, dropout_rate, dropout_schedule: forwarded positionally
                to ``DropGrad``.
        """
        super().__init__(model_func, n_way, n_support,
                         tf_path=tf_path, change_way=False)

        self.loss_fn = nn.CrossEntropyLoss()

        # Fast-weight linear classifier over the backbone features; bias starts at 0.
        self.classifier = backbone.Linear_fw(self.feat_dim, n_way)
        self.classifier.bias.data.zero_()

        # Fixed training hyperparameters.
        self.batch_size = 4
        self.task_update_num = 5
        self.train_lr = 0.01
        self.approx = approx  # first-order approximation flag

        self.dropout = DropGrad(dropout_method, dropout_rate, dropout_schedule)
        # Built last so it sees every parameter registered above.
        self.optimizer = torch.optim.Adam(self.parameters())
コード例 #3
0
    def __init__(self, model_func, n_way, n_support, tf_path=None):
        """Set up the GNN few-shot head and its fixed support-label tensor.

        Args:
            model_func: forwarded to the parent constructor (presumably builds the
                feature extractor that defines ``self.feat_dim`` -- confirm).
            n_way: number of classes per episode.
            n_support: number of support samples per class.
            tf_path: forwarded to the parent constructor.
        """
        super(GnnNet, self).__init__(model_func,
                                     n_way,
                                     n_support,
                                     tf_path=tf_path)

        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function: project features to 128-d, then batch-norm. The
        # fast-weight variants are used when either maml flag is set.
        if self.maml or self.maml_adain:
            self.fc = nn.Sequential(
                backbone.Linear_fw(self.feat_dim, 128),
                backbone.BatchNorm1d_fw(128, track_running_stats=False))
        else:
            self.fc = nn.Sequential(
                nn.Linear(self.feat_dim, 128),
                nn.BatchNorm1d(128, track_running_stats=False))
        self.gnn = GNN_nl(128 + self.n_way, 96, self.n_way)
        self.method = 'GnnNet'

        # Fixed label tensor of shape 1 x n_way*(n_support + 1) x n_way: for each
        # class, n_support one-hot rows followed by one all-zero row.
        # torch.arange yields int64 on every platform, which scatter requires.
        # np.repeat(range(...)) follows NumPy's default int, which is int32 on
        # Windows with NumPy < 2, and made scatter raise there.
        label_idx = torch.arange(self.n_way).repeat_interleave(
            self.n_support).unsqueeze(1)
        support_label = torch.zeros(self.n_way * self.n_support,
                                    self.n_way).scatter(1, label_idx, 1)
        support_label = support_label.view(self.n_way, self.n_support,
                                           self.n_way)
        support_label = torch.cat(
            [support_label,
             torch.zeros(self.n_way, 1, self.n_way)], dim=1)
        self.support_label = support_label.view(1, -1, self.n_way)
コード例 #4
0
    def __init__(self, input_size, output_size):
        """Create one linear classifier layer with a zero-initialised bias.

        ``backbone.Linear_fw`` is used when the class-level ``maml`` flag is
        truthy, otherwise ``nn.Linear``.

        Args:
            input_size: number of input features.
            output_size: number of output units.
        """
        super().__init__()

        layer_cls = backbone.Linear_fw if self.maml else nn.Linear
        self.classifier = layer_cls(input_size, output_size)
        # Both layer types get the same zero bias.
        self.classifier.bias.data.fill_(0)
コード例 #5
0
ファイル: gnnnet.py プロジェクト: Haoqing-Wang/CDFSL-ATA
    def __init__(self, model_func, n_way, n_support):
        """Set up the GNN few-shot head.

        Args:
            model_func, n_way, n_support: forwarded to the parent constructor.
        """
        super().__init__(model_func, n_way, n_support)
        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function: 128-d projection plus batch-norm. Fast-weight layers
        # are used when the ``FWT`` flag is set.
        if self.FWT:
            linear_cls, norm_cls = backbone.Linear_fw, backbone.BatchNorm1d_fw
        else:
            linear_cls, norm_cls = nn.Linear, nn.BatchNorm1d
        projection = linear_cls(self.feat_dim, 128)
        normalization = norm_cls(128, track_running_stats=False)
        self.fc = nn.Sequential(projection, normalization)

        self.gnn = GNN_nl(128 + self.n_way, 96, self.n_way)
        self.method = 'GnnNet'