Example 1
    def forward(self, bundle, model, device):
        batch = bundle_part_to_batch(bundle)
        batch = tuple(t.to(device) for t in batch)
        hop_loss, ans_loss, semantics = model(*batch) # Shape of semantics: [num_para, hidden_size]
        num_additional_nodes = len(bundle.additional_nodes)

        if num_additional_nodes > 0:
            # Right-pad the variable-length additional nodes to a common length,
            # then encode them with the same model to get their semantic vectors.
            max_length_additional = max(len(x) for x in bundle.additional_nodes)
            ids = torch.zeros((num_additional_nodes, max_length_additional), dtype=torch.long, device=device)
            segment_ids = torch.zeros((num_additional_nodes, max_length_additional), dtype=torch.long, device=device)
            input_mask = torch.zeros((num_additional_nodes, max_length_additional), dtype=torch.long, device=device)
            for i in range(num_additional_nodes):
                length = len(bundle.additional_nodes[i])
                ids[i, :length] = torch.tensor(bundle.additional_nodes[i], dtype=torch.long)
                input_mask[i, :length] = 1
            additional_semantics = model(ids, segment_ids, input_mask)

            semantics = torch.cat((semantics, additional_semantics), dim = 0)

        assert semantics.size()[0] == bundle.adj.size()[0]
        
        if bundle.question_type == 0: # Wh-
            pred = self.gcn(bundle.adj.to(device), semantics)
            ce = torch.nn.CrossEntropyLoss()
            final_loss = ce(pred.unsqueeze(0), torch.tensor([bundle.answer_id], dtype=torch.long, device=device))
        else:
            x, y, ans = bundle.answer_id
            ans = torch.tensor(ans, dtype=torch.float, device=device)
            diff_sem = semantics[x] - semantics[y]
            classifier = self.both_net if bundle.question_type == 1 else self.select_net
            final_loss = 0.2 * torch.nn.functional.binary_cross_entropy_with_logits(classifier(diff_sem).squeeze(-1), ans)
        return hop_loss, ans_loss, final_loss
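
A minimal, self-contained sketch of the padding pattern used in the `num_additional_nodes > 0` branch above: variable-length token-id lists are right-padded into a fixed-size id matrix with a matching 0/1 input mask. The token ids below are made-up placeholders, not values from the real data.

import torch

# Hypothetical stand-in for bundle.additional_nodes: token-id lists of different lengths.
additional_nodes = [[101, 2054, 102], [101, 2003, 1996, 3007, 102]]

max_length = max(len(x) for x in additional_nodes)
ids = torch.zeros((len(additional_nodes), max_length), dtype=torch.long)
input_mask = torch.zeros((len(additional_nodes), max_length), dtype=torch.long)
for i, node in enumerate(additional_nodes):
    ids[i, :len(node)] = torch.tensor(node, dtype=torch.long)
    input_mask[i, :len(node)] = 1

print(ids)         # rows right-padded with zeros
print(input_mask)  # 1 marks real tokens, 0 marks padding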
Example 2
def gen():
    for batch_num in range(num_batch):
        l, r = batch_num * batch_size, min((batch_num + 1) * batch_size, n)
        yield bundle_part_to_batch(all_bundle, l, r)
Example 3
    def gen():
        for batch_num in range(num_batch):
            # l and r are the start and end positions of this batch slice.
            l, r = batch_num * batch_size, min((batch_num + 1) * batch_size, n)
            yield bundle_part_to_batch(all_bundle, l, r)
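
The generator above slices `all_bundle` into consecutive `[l, r)` windows of `batch_size` items, with the last window possibly shorter. Below is a small runnable sketch of the same slicing arithmetic, using a plain list in place of `all_bundle` and `bundle_part_to_batch`.

from math import ceil

items = list(range(10))           # stand-in for all_bundle
batch_size = 4
n = len(items)
num_batch = ceil(n / batch_size)  # the last batch may be smaller than batch_size

def gen():
    for batch_num in range(num_batch):
        l, r = batch_num * batch_size, min((batch_num + 1) * batch_size, n)
        yield items[l:r]          # the real code calls bundle_part_to_batch(all_bundle, l, r)

print(list(gen()))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]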
Example 4
    def forward(self, bundle, model, device):
        batch = bundle_part_to_batch(bundle)
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            hop_loss, ans_loss, semantics = model(*batch)
        attention_mask = batch[2]
        # Shape of semantics: [num_para, seq_len, hidden_size]
        num_additional_nodes = len(bundle.additional_nodes)

        if num_additional_nodes > 0:
            max_length_additional = max(
                [len(x) for x in bundle.additional_nodes])
            ids = torch.zeros(
                (num_additional_nodes, max_length_additional),
                dtype=torch.long,
                device=device,
            )
            segment_ids = torch.zeros(
                (num_additional_nodes, max_length_additional),
                dtype=torch.long,
                device=device,
            )
            input_mask = torch.zeros(
                (num_additional_nodes, max_length_additional),
                dtype=torch.long,
                device=device,
            )
            for i in range(num_additional_nodes):
                length = len(bundle.additional_nodes[i])
                ids[i, :length] = torch.tensor(bundle.additional_nodes[i],
                                               dtype=torch.long)
                input_mask[i, :length] = 1
            additional_attention_mask = input_mask
            with torch.no_grad():
                additional_semantics = model(ids, segment_ids, input_mask)

            # Pad whichever tensor is shorter along the sequence dimension so the
            # original and additional semantics (and their masks) can be concatenated.
            if semantics.shape[1] > additional_semantics.shape[1]:
                zero_shape = list(additional_semantics.shape)
                zero_shape[1] = semantics.shape[1] - additional_semantics.shape[1]
                additional_semantics = torch.cat(
                    (additional_semantics, torch.zeros(zero_shape).to(device)),
                    dim=1)
                additional_attention_mask = torch.cat(
                    (additional_attention_mask,
                     torch.zeros(zero_shape[:-1], dtype=torch.long).to(device)),
                    dim=1)
            elif semantics.shape[1] < additional_semantics.shape[1]:
                zero_shape = list(semantics.shape)
                zero_shape[1] = additional_semantics.shape[1] - semantics.shape[1]
                semantics = torch.cat(
                    (semantics, torch.zeros(zero_shape).to(device)), dim=1)
                attention_mask = torch.cat(
                    (attention_mask,
                     torch.zeros(zero_shape[:-1], dtype=torch.long).to(device)),
                    dim=1)
            semantics = torch.cat((semantics, additional_semantics), dim=0)
            attention_mask = torch.cat(
                (attention_mask, additional_attention_mask), dim=0)
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = (1.0 - attention_mask) * -10000.0
            attention_mask = attention_mask.to(
                dtype=torch.float32)  # fp16 compatibility

        assert semantics.size()[0] == bundle.adj.size()[0]
        assert semantics.shape[0] == attention_mask.shape[0]

        if bundle.question_type == 0:  # Wh-
            if self.module_type == "mlp":
                pred = self.gcn(semantics[:, 0]).squeeze(-1)
            elif self.module_type == "gcn":
                pred = self.gcn(bundle.adj.to(device), semantics[:, 0])
            elif self.module_type == "xattn":
                pred = self.gcn(bundle.adj.to(device), semantics,
                                attention_mask)
            else:
                raise NotImplementedError
            ce = torch.nn.CrossEntropyLoss()
            final_loss = ce(
                pred.unsqueeze(0),
                torch.tensor([bundle.answer_id],
                             dtype=torch.long,
                             device=device),
            )
        else:
            x, y, ans = bundle.answer_id
            ans = torch.tensor(ans, dtype=torch.float, device=device)
            diff_sem = semantics[x][0] - semantics[y][0]
            classifier = self.both_net if bundle.question_type == 1 else self.select_net
            final_loss = 0.2 * torch.nn.functional.binary_cross_entropy_with_logits(
                classifier(diff_sem).squeeze(-1), ans.to(device))
        return hop_loss, ans_loss, final_loss
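
The `attention_mask` post-processing in the example above follows the convention used in BERT-style transformer implementations: a 0/1 padding mask is broadcast to shape `[num_para, 1, 1, seq_len]` and turned into an additive mask with large negative values at padded positions. A self-contained sketch of just that transform, with a toy mask:

import torch

# Toy 0/1 padding mask for two paragraphs of length 4 (1 = real token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.long)

extended = attention_mask.unsqueeze(1).unsqueeze(2).to(torch.float32)
extended = (1.0 - extended) * -10000.0  # 0.0 for real tokens, -10000.0 for padding

print(extended.shape)  # torch.Size([2, 1, 1, 4])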