Esempio n. 1
0
    def _forward_gfsl(self, support_embs, query_embs, seen_proto):
        """Score query embeddings against seen and unseen class prototypes.

        Unseen prototypes are built from ``support_embs`` by reshaping it to
        ``(self.args.eval_shot, -1, dim)`` and averaging over the first axis.
        This assumes the support set is laid out shot-major, i.e. all classes
        for shot 0, then all classes for shot 1, and so on — TODO confirm
        against the sampler.

        Args:
            support_embs: support-set embeddings; last axis is the feature dim.
            query_embs: query embeddings passed to ``euclidean_metric``.
            seen_proto: precomputed prototypes of the seen classes.

        Returns:
            Tuple ``(logits_s, logits_u)``: ``euclidean_metric`` of the queries
            against ``seen_proto`` and against the unseen prototypes.
        """
        emb_dim = support_embs.shape[-1]
        per_shot = support_embs.reshape(self.args.eval_shot, -1, emb_dim)
        unseen_proto = per_shot.mean(dim=0)  # N x d

        seen_logits = euclidean_metric(query_embs, seen_proto)
        unseen_logits = euclidean_metric(query_embs, unseen_proto)
        return seen_logits, unseen_logits
Esempio n. 2
0
    def forward_proto(self, data_shot, data_query, way=None):
        """Encode support/query data and score queries against class prototypes.

        Prototypes are the encoded ``data_shot`` reshaped to
        ``(self.args.shot, way, -1)`` and averaged over the shot axis.

        Args:
            data_shot: support-set input for ``self.encoder``.
            data_query: query-set input for ``self.encoder``.
            way: number of classes; defaults to ``self.args.num_class``.

        Returns:
            Tuple ``(logits_dist, logits_sim)``. ``logits_dist`` is
            ``euclidean_metric(query, proto)``. ``logits_sim`` is the matrix
            product of the raw query features with the L2-normalised
            prototypes. Only the prototypes are normalised, not the queries.
        """
        n_way = self.args.num_class if way is None else way

        shot_feats = self.encoder(data_shot)
        proto = shot_feats.reshape(self.args.shot, n_way, -1).mean(dim=0)
        query = self.encoder(data_query)

        unit_proto = F.normalize(proto, p=2, dim=-1)
        logits_dist = euclidean_metric(query, proto)
        logits_sim = torch.mm(query, unit_proto.t())
        return logits_dist, logits_sim
Esempio n. 3
0
def do_pass(batches, shot, way, query, expressions, encoder):
    """Train ``model`` for one pass over few-shot episodes using mean prototypes.

    Each episode in ``batches`` is an iterable of ``(x, _)`` pairs.
    ``x[0]`` is used when ``encoder == 'v1'`` and ``x[1]`` when
    ``encoder == 'v2'``. The first ``shot * way`` selected items form the
    support set and the remainder form the query set. Prototypes are the
    support embeddings reshaped to ``(shot, way, -1)`` and averaged over the
    shot axis. Queries are scored with ``utils.euclidean_metric`` and trained
    with cross-entropy.

    Args:
        batches: iterable of episodes.
        shot: support examples per class.
        way: number of classes per episode.
        query: query examples per class. Labels are ``arange(way)`` repeated
            ``query`` times, so query items are presumably ordered
            class-cyclically — TODO confirm against the sampler.
        expressions: ``(model, optimizer)`` pair.
        encoder: ``'v1'`` or ``'v2'``; also forwarded to ``model``.

    Returns:
        float: the cross-entropy loss of the last episode processed.

    Raises:
        KeyError: if ``encoder`` is neither ``'v1'`` nor ``'v2'``.
        ValueError: if ``batches`` yields no episodes. Previously this path
            failed with an ``UnboundLocalError`` on ``loss``.
    """
    model, optimizer = expressions
    model.train()
    # Loop-invariant: which element of each sample feeds the chosen encoder.
    view_idx = {'v1': 0, 'v2': 1}[encoder]
    n_support = shot * way

    loss = None
    for batch in batches:
        data = [x[view_idx] for x, _ in batch]
        data_shot, data_query = data[:n_support], data[n_support:]

        proto = model(data_shot, encoder=encoder)
        proto = proto.reshape(shot, way, -1).mean(dim=0)

        # ignore original labels, reassign labels from 0
        label = torch.arange(way, dtype=torch.long, device=device).repeat(query)

        logits = utils.euclidean_metric(model(data_query, encoder=encoder),
                                        proto)
        loss = F.cross_entropy(logits, label)
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

    if loss is None:
        raise ValueError("do_pass received no batches")
    return loss.item()
Esempio n. 4
0
 def forward_many(self, data, seen_proto):
     """Encode ``data`` and return its ``euclidean_metric`` logits vs ``seen_proto``."""
     features = self.encoder(data)
     return euclidean_metric(features, seen_proto)