Beispiel #1
0
    def forward(self, x_shot, x_query):
        """Compute prototype-based logits for a batch of few-shot episodes.

        Support and query images are flattened to a stack of single images,
        encoded together in one ``self.encoder`` pass, split apart again and
        reshaped to their original leading dimensions with a flattened
        feature vector. Prototypes are the mean of the support features over
        dim -2.

        Args:
            x_shot: support images; the last three dims are treated as one
                image (presumably (C, H, W) -- confirm with caller).
            x_query: query images with the same trailing image shape.

        Returns:
            Logits from ``utils.compute_logits(query, prototypes, ...)``.

        Raises:
            ValueError: if ``self.method`` is neither ``'cos'`` nor ``'sqr'``.
        """
        shot_shape = x_shot.shape[:-3]
        query_shape = x_query.shape[:-3]
        img_shape = x_shot.shape[-3:]

        x_shot = x_shot.view(-1, *img_shape)
        x_query = x_query.view(-1, *img_shape)
        # Single encoder pass over support + query, then split back apart.
        # split() (rather than x_tot[-len(x_query):]) stays correct when the
        # query set is empty, where -0 would select the whole tensor.
        x_tot = self.encoder(torch.cat([x_shot, x_query], dim=0))
        x_shot, x_query = x_tot.split([len(x_shot), len(x_query)])
        x_shot = x_shot.view(*shot_shape, -1)
        x_query = x_query.view(*query_shape, -1)

        if self.method == 'cos':
            # Cosine similarity == dot product of L2-normalised vectors.
            x_shot = x_shot.mean(dim=-2)
            x_shot = F.normalize(x_shot, dim=-1)
            x_query = F.normalize(x_query, dim=-1)
            metric = 'dot'
        elif self.method == 'sqr':
            x_shot = x_shot.mean(dim=-2)
            metric = 'sqr'
        else:
            raise ValueError(f"unknown method: {self.method!r}")

        logits = utils.compute_logits(x_query,
                                      x_shot,
                                      metric=metric,
                                      temp=self.temp)
        return logits
Beispiel #2
0
    def forward(self, x_shot, x_query, **kwargs):
        """Compute prototype-based logits for a batch of few-shot episodes.

        Support and query images are encoded together in one
        ``self.encoder`` pass. Prototypes are the mean support feature over
        dim -2, and queries are scored against them with
        ``utils.compute_logits``.

        Args:
            x_shot: support images; the last three dims are treated as one
                image.
            x_query: query images with the same trailing image shape.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            Logits, documented by the original code as
            ``[ep_per_batch, way * query, way]``.

        Raises:
            ValueError: if ``self.method`` is neither ``'cos'`` nor ``'sqr'``.
        """
        shot_shape = x_shot.shape[:-3]
        query_shape = x_query.shape[:-3]
        img_shape = x_shot.shape[-3:]

        x_shot = x_shot.view(-1, *img_shape)
        x_query = x_query.view(-1, *img_shape)
        x_tot = self.encoder(torch.cat([x_shot, x_query], dim=0))
        # split() stays correct for an empty query set (unlike [-0:]).
        x_shot, x_query = x_tot.split([len(x_shot), len(x_query)])
        x_shot = x_shot.view(*shot_shape, -1)
        x_query = x_query.view(*query_shape, -1)

        if self.method == 'cos':
            x_shot = x_shot.mean(dim=-2)
            x_shot = F.normalize(x_shot,
                                 dim=-1)  # [ep_per_batch, way, feature_len]
            x_query = F.normalize(
                x_query, dim=-1)  # [ep_per_batch, way * query, feature_len]
            metric = 'dot'
        elif self.method == 'sqr':
            x_shot = x_shot.mean(dim=-2)
            metric = 'sqr'
        else:
            raise ValueError(f"unknown method: {self.method!r}")

        logits = utils.compute_logits(
            x_query, x_shot, metric=metric,
            temp=self.temp)  # [ep_per_batch, way * query, way]
        return logits
Beispiel #3
0
    def forward(self, x_shot, x_query, x_pseudo):
        """Compute logits with prototypes built from support + pseudo samples.

        Support, query and pseudo images are encoded in one ``self.encoder``
        pass. Pseudo features are concatenated with the support features
        along dim -2 and the result is averaged over that dim to form the
        prototypes, so pseudo samples count the same as real support samples.

        Args:
            x_shot: support images; the last three dims are one image.
            x_query: query images with the same trailing image shape.
            x_pseudo: pseudo-sample images with the same trailing image
                shape. Its leading dims must match ``x_shot``'s except
                dim -2 for the concatenation to succeed.

        Returns:
            Logits from ``utils.compute_logits(query, prototypes, ...)``.

        Raises:
            ValueError: if ``self.method`` is neither ``'cos'`` nor ``'sqr'``.
        """
        shot_shape = x_shot.shape[:-3]
        query_shape = x_query.shape[:-3]
        pseudo_shape = x_pseudo.shape[:-3]
        img_shape = x_shot.shape[-3:]

        x_shot = x_shot.view(-1, *img_shape)
        x_query = x_query.view(-1, *img_shape)
        x_pseudo = x_pseudo.view(-1, *img_shape)
        x_tot = self.encoder(torch.cat([x_shot, x_query, x_pseudo], dim=0))
        x_shot, x_query, x_pseudo = x_tot.split(
            [len(x_shot), len(x_query), len(x_pseudo)])
        x_shot = x_shot.view(*shot_shape, -1)
        x_query = x_query.view(*query_shape, -1)
        x_pseudo = x_pseudo.view(*pseudo_shape, -1)

        # Prototype = mean over the combined pseudo + support samples.
        x_shot = torch.cat([x_pseudo, x_shot], dim=-2)
        x_shot = x_shot.mean(dim=-2)

        if self.method == 'cos':
            x_shot = F.normalize(x_shot, dim=-1)
            x_query = F.normalize(x_query, dim=-1)
            metric_ = 'dot'
        elif self.method == 'sqr':
            metric_ = 'sqr'
        else:
            raise ValueError(f"unknown method: {self.method!r}")

        logits = utils.compute_logits(x_query,
                                      x_shot,
                                      metric=metric_,
                                      temp=self.temp)

        return logits
Beispiel #4
0
    def forward(self, x_shot, x_query, x_pseudo):
        """Compute logits using masked pseudo-sample features and a second encoder.

        Pipeline, as visible in this method:
          1. Flatten support / query / pseudo images and encode them in one
             ``self.encoder`` pass. The encoder's per-sample output shape is
             kept (not flattened to a vector).
          2. Run ``self.aggregator`` on the support and pseudo features,
             concatenate the results along dim -3, and pass them through
             ``self.masking_model`` to obtain a mask.
          3. Multiply only the pseudo features by the mask.
          4. Encode the support, query and (masked) pseudo features again
             with ``self.universal``.
          5. Concatenate support and pseudo features along dim 2, average
             over that dim, flatten to ``(*shot_shape[:2], -1)`` prototypes,
             and score the queries against them with ``utils.compute_logits``.

        Args:
            x_shot: support images; the trailing 3 dims are one image.
                Presumably (batch, way, shot, C, H, W) -- TODO confirm.
            x_query: query images with the same trailing image shape.
            x_pseudo: pseudo-sample images with the same trailing image shape.

        Returns:
            Logits from ``utils.compute_logits(query, prototypes, ...)``.

        NOTE(review): if ``self.method`` is neither 'cos' nor 'sqr',
        ``metric_`` is never assigned and the ``compute_logits`` call raises
        ``UnboundLocalError``.
        """
        shot_shape = x_shot.shape[:-3]
        query_shape = x_query.shape[:-3]
        pseudo_shape = x_pseudo.shape[:-3]
        img_shape = x_shot.shape[-3:]

        # Stage 1: flatten every set to a stack of single images and encode
        # them together in one pass.
        x_shot = x_shot.view(-1, *img_shape)
        x_query = x_query.view(-1, *img_shape)
        x_pseudo = x_pseudo.view(-1, *img_shape)
        x_tot = self.encoder(torch.cat([x_shot, x_query, x_pseudo], dim=0))
        # Split the encoder output back into shot / query / pseudo rows.
        x_shot, x_query, x_pseudo = x_tot[:len(x_shot)], x_tot[
            len(x_shot):len(x_shot) + len(x_query)], x_tot[len(x_shot) +
                                                           len(x_query):]
        # Restore leading dims but keep the encoder's full per-sample shape.
        x_shot = x_shot.view(*shot_shape, *x_tot.shape[1:])
        x_query = x_query.view(*query_shape, *x_tot.shape[1:])
        x_pseudo = x_pseudo.view(*pseudo_shape, *x_tot.shape[1:])

        # Stage 2: aggregate support and pseudo features, join them, and
        # predict a mask. The meaning of dim -3 in the aggregator output
        # depends on self.aggregator -- TODO confirm.
        a_shot = self.aggregator(x_shot)
        a_pseudo = self.aggregator(x_pseudo)
        total = torch.cat((a_shot, a_pseudo), dim=-3)
        # Flatten the first two dims so the masking model sees one batch.
        batch_shape = total.shape[:2]
        feat_shape = total.shape[2:]
        total = total.view(-1, *feat_shape)
        mask = self.masking_model(total)
        # Restore the first two dims, then insert a size-1 dim at position 2.
        # Presumably this broadcasts the mask over the pseudo-sample axis of
        # x_pseudo -- verify against the masking model's output shape.
        mask = mask.view(*batch_shape, *mask.shape[1:]).unsqueeze(dim=2)

        # Stage 3: only the pseudo features are masked.
        x_pseudo = torch.mul(x_pseudo, mask)
        # Stage 4: re-flatten using the trailing 3 dims of the encoder
        # features. This assumes the encoder output is 3-D per sample
        # (e.g. a feature map) -- TODO confirm.
        img_shape = x_query.shape[-3:]
        x_shot = x_shot.view(-1, *img_shape)
        x_query = x_query.view(-1, *img_shape)
        x_pseudo = x_pseudo.view(-1, *img_shape)
        x_tot = self.universal(torch.cat([x_shot, x_query, x_pseudo], dim=0))
        x_shot, x_query, x_pseudo = x_tot[:len(x_shot)], x_tot[
            len(x_shot):len(x_shot) + len(x_query)], x_tot[len(x_shot) +
                                                           len(x_query):]
        x_shot = x_shot.view(*shot_shape, *x_tot.shape[1:])
        x_query = x_query.view(*query_shape, *x_tot.shape[1:])
        x_pseudo = x_pseudo.view(*pseudo_shape, *x_tot.shape[1:])

        # Stage 5: prototypes = mean over the concatenated support + pseudo
        # samples (dim 2; presumably the per-class sample axis), flattened to
        # one feature vector per (shot_shape[0], shot_shape[1]) entry.
        x_shot_c = torch.cat([x_shot, x_pseudo],
                             dim=2).mean(2).view(*shot_shape[:2], -1)
        x_query = x_query.view(*query_shape, -1)
        x_shot = x_shot_c

        if self.method == 'cos':
            # Cosine similarity == dot product of L2-normalised vectors.
            x_shot = F.normalize(x_shot, dim=-1)
            x_query = F.normalize(x_query, dim=-1)
            metric_ = 'dot'
        elif self.method == 'sqr':
            metric_ = 'sqr'

        logits = utils.compute_logits(x_query,
                                      x_shot,
                                      metric=metric_,
                                      temp=self.temp)

        return logits
Beispiel #5
0
    def forward(self, x_shot, y_shot, x_query, y_query):
        """Compute prototype cross-entropy loss and accuracy for one episode.

        Support and query samples are encoded together in one
        ``self.encoder`` pass. Each class prototype is the mean encoded
        support sample with that label. Labels are assumed to be integers in
        ``0 .. y_shot.max()``; a class with no support samples yields a NaN
        prototype, exactly as before.

        Args:
            x_shot: support inputs, batched along dim 0.
            y_shot: integer labels for ``x_shot`` (tensor or array-like).
            x_query: query inputs, batched along dim 0.
            y_query: labels for ``x_query``, passed to ``F.cross_entropy``.

        Returns:
            Tuple ``(loss, acc)`` from ``F.cross_entropy`` and
            ``utils.compute_acc`` on the cosine logits.
        """
        x_all = self.encoder(torch.cat([x_shot, x_query], dim=0))
        # Slice by len(x_shot) so an empty query set is handled correctly.
        x_shot, x_query = x_all[:len(x_shot)], x_all[len(x_shot):]

        n_way = int(y_shot.max()) + 1
        # .long() truncates toward zero like int(); the mask replaces the
        # per-class Python loop that collected indices one label at a time.
        labels = torch.as_tensor(y_shot, device=x_shot.device).long()
        proto = torch.stack(
            [x_shot[labels == c].mean(dim=0) for c in range(n_way)])

        logits = utils.compute_logits(x_query,
                                      proto,
                                      metric='cos',
                                      temp=self.temp)
        loss = F.cross_entropy(logits, y_query)
        acc = utils.compute_acc(logits, y_query)

        return loss, acc
Beispiel #6
0
 def forward(self, x):
     """Score ``x`` against the stored prototypes.

     Delegates to ``utils.compute_logits``, passing the instance's
     ``proto``, ``metric`` and ``temp`` attributes positionally.
     """
     proto, metric, temp = self.proto, self.metric, self.temp
     return utils.compute_logits(x, proto, metric, temp)
Beispiel #7
0
    def forward_unlabel(self, x_shot, x_unlabel, x_query):
        """Compute query logits against cluster centres refined by unlabelled data.

        1. Encode support, unlabelled and query images in one
           ``self.encoder`` pass.
        2. Softly assign each unlabelled sample to a class by softmax over
           its logits against the mean support feature per class.
        3. Give support samples one-hot assignments to their class.
           Labels come from position: class ``i`` owns support slots
           ``i * n_shot .. (i + 1) * n_shot - 1``.
        4. Each cluster centre is the assignment-weighted mean of all support
           and unlabelled features. Queries are then scored against these
           centres.

        Args:
            x_shot: support images with leading shape ``(B, n_way, n_shot)``
                (required by the unpacking below); the last 3 dims are one
                image.
            x_unlabel: unlabelled images; after encoding, their features are
                concatenated with the support features along dim 1, so their
                leading shape is presumably ``(B, n_unlabel)`` -- confirm.
            x_query: query images with the same trailing image shape.

        Returns:
            Logits from ``utils.compute_logits(query, cluster_centres, ...)``.

        Raises:
            ValueError: if ``self.method`` is neither ``'cos'`` nor ``'sqr'``.
        """
        if self.method == 'cos':
            metric = 'dot'
        elif self.method == 'sqr':
            metric = 'sqr'
        else:
            raise ValueError(f"unknown method: {self.method!r}")

        shot_shape = x_shot.shape[:-3]
        unlabel_shape = x_unlabel.shape[:-3]
        query_shape = x_query.shape[:-3]
        img_shape = x_shot.shape[-3:]

        x_shot = x_shot.view(-1, *img_shape)
        x_unlabel = x_unlabel.view(-1, *img_shape)
        x_query = x_query.view(-1, *img_shape)
        x_tot = self.encoder(torch.cat([x_shot, x_unlabel, x_query], dim=0))
        # split() stays correct for an empty query set (unlike [-0:]).
        x_shot, x_unlabel, x_query = x_tot.split(
            [len(x_shot), len(x_unlabel), len(x_query)])
        x_shot = x_shot.view(*shot_shape, -1)
        x_unlabel = x_unlabel.view(*unlabel_shape, -1)
        x_query = x_query.view(*query_shape, -1)

        # Soft class assignment of unlabelled samples w.r.t. the support
        # prototypes (mean over the shot dim).
        x_shot_tmp = x_shot.mean(dim=-2)
        x_unlabel_tmp = x_unlabel
        if self.method == 'cos':
            x_shot_tmp = F.normalize(x_shot_tmp, dim=-1)
            x_unlabel_tmp = F.normalize(x_unlabel, dim=-1)
        logits = utils.compute_logits(x_unlabel_tmp,
                                      x_shot_tmp,
                                      metric=metric,
                                      temp=self.temp)
        prob_unlabel = F.softmax(logits, 2)

        B, n_way, n_shot = shot_shape
        # One-hot assignments for the support samples, built on the same
        # device/dtype as prob_unlabel (previously hard-coded to .cuda(),
        # which failed on CPU and on non-default GPUs).
        shot_labels = torch.arange(
            n_way, device=prob_unlabel.device).repeat_interleave(n_shot)
        prob_shot = F.one_hot(shot_labels, n_way).to(prob_unlabel.dtype)
        prob_shot = prob_shot.repeat(B, 1, 1)
        x_shot = x_shot.view(B, n_way * n_shot, -1)

        x_all = torch.cat((x_shot, x_unlabel), 1)
        prob_all = torch.cat((prob_shot, prob_unlabel), 1)

        # Normalise each class's weights to sum to 1 over all samples, then
        # take the weighted mean of the features per class.
        prob_sum = torch.sum(prob_all, dim=1, keepdim=True)
        prob = prob_all / prob_sum
        cluster_center = torch.sum(x_all.unsqueeze(2) * prob.unsqueeze(3),
                                   dim=1)

        if self.method == 'cos':
            cluster_center = F.normalize(cluster_center, dim=-1)
            x_query = F.normalize(x_query, dim=-1)

        logits = utils.compute_logits(x_query,
                                      cluster_center,
                                      metric=metric,
                                      temp=self.temp)
        return logits