def _forward_gfsl(self, support_embs, query_embs, seen_proto):
    """Score query embeddings against seen and on-the-fly unseen prototypes.

    The unseen prototypes are obtained by reshaping ``support_embs`` to
    ``(self.args.eval_shot, -1, feat_dim)`` and averaging over the first
    (shot) axis.

    Args:
        support_embs: support-set embeddings whose last axis is the feature
            dimension. Presumably laid out shot-major (every class for shot 0,
            then every class for shot 1, ...) — TODO confirm against the
            episode sampler.
        query_embs: query embeddings to be scored.
        seen_proto: prototypes of the seen classes.

    Returns:
        ``(logits_s, logits_u)``: ``euclidean_metric`` scores of the queries
        against ``seen_proto`` and against the unseen prototypes, in that order.
    """
    feat_dim = support_embs.shape[-1]
    per_shot = support_embs.reshape(self.args.eval_shot, -1, feat_dim)
    # Mean over the shot axis -> one row per unseen class, width feat_dim.
    unseen_proto = per_shot.mean(dim=0)
    seen_logits = euclidean_metric(query_embs, seen_proto)
    unseen_logits = euclidean_metric(query_embs, unseen_proto)
    return seen_logits, unseen_logits
def forward_proto(self, data_shot, data_query, way=None):
    """Encode a few-shot episode and score queries against class prototypes.

    Args:
        data_shot: support inputs passed to ``self.encoder``. The encoded result
            is reshaped to ``(self.args.shot, way, -1)`` and averaged over the
            first axis, so the inputs are presumably ordered shot-major —
            TODO confirm against the caller.
        data_query: query inputs passed to ``self.encoder``.
        way: number of classes in the episode; ``None`` means
            ``self.args.num_class``.

    Returns:
        ``(logits_dist, logits_sim)``: ``euclidean_metric(query, proto)`` and
        the matrix product of the (unnormalized) query embeddings with the
        L2-normalized prototypes.
    """
    n_way = self.args.num_class if way is None else way
    support_feats = self.encoder(data_shot)
    proto = support_feats.reshape(self.args.shot, n_way, -1).mean(dim=0)
    query_feats = self.encoder(data_query)
    logits_dist = euclidean_metric(query_feats, proto)
    # Only the prototypes are normalized; the query side keeps its magnitude.
    unit_proto = F.normalize(proto, p=2, dim=-1)
    logits_sim = torch.mm(query_feats, unit_proto.t())
    return logits_dist, logits_sim
def do_pass(batches, shot, way, query, expressions, encoder):
    """Run one training pass over a sequence of prototypical-network episodes.

    For each episode, the first ``shot * way`` samples form the support set.
    They are encoded by ``model`` and averaged into ``way`` prototypes. The
    remaining samples are the queries; they are scored against the prototypes
    with ``utils.euclidean_metric`` and trained with cross-entropy. One
    optimizer step is taken per episode.

    Args:
        batches: iterable of episodes. Each episode is an iterable of
            ``(x, _)`` pairs, and ``x[0]`` / ``x[1]`` select the ``'v1'`` /
            ``'v2'`` view of a sample.
        shot: support samples per class.
        way: classes per episode.
        query: query samples per class.
        expressions: ``(model, optimizer)`` pair.
        encoder: ``'v1'`` or ``'v2'``. Selects the view taken from each
            sample and is forwarded as ``model(..., encoder=encoder)``.

    Returns:
        The loss of the last episode, as a Python float.

    Raises:
        KeyError: if ``encoder`` is neither ``'v1'`` nor ``'v2'``.
        ValueError: if ``batches`` yields no episodes (previously this
            surfaced as an ``UnboundLocalError`` on ``loss``).
    """
    model, optimizer = expressions
    model.train()

    # Loop-invariant work, hoisted out of the per-episode loop.
    view = {'v1': 0, 'v2': 1}[encoder]
    n_support = shot * way
    # Ignore the original labels and reassign them from 0. The label for query
    # row i is i % way, so queries are expected in class-cycling order, which
    # matches the (shot, way, -1) support layout used for the prototypes.
    label = torch.arange(way, dtype=torch.long, device=device).repeat(query)

    loss = None
    for batch in batches:
        data = [x[view] for x, _ in batch]
        data_shot, data_query = data[:n_support], data[n_support:]

        proto = model(data_shot, encoder=encoder)
        proto = proto.reshape(shot, way, -1).mean(dim=0)

        logits = utils.euclidean_metric(model(data_query, encoder=encoder), proto)
        loss = F.cross_entropy(logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if loss is None:
        raise ValueError("do_pass received no batches")
    return loss.item()
def forward_many(self, data, seen_proto):
    """Encode ``data`` with ``self.encoder`` and score it against ``seen_proto``.

    Returns:
        The result of ``euclidean_metric`` applied to the encoded ``data``
        and ``seen_proto``.
    """
    embeddings = self.encoder(data)
    return euclidean_metric(embeddings, seen_proto)