Example No. 1
 def forward(self, src_emb, batch):
     # encoding
     src_mask_inv = (batch["src_mask"] == 0).to(device)  # [batch, seq]
     # src_mask_inv = batch["src_mask_inv"]
     # print('src_emb:', src_emb.size())
     # assert len(src_emb.size()) == 3, 'Invalid input size:{}, should be 3!!'.format(' x '.join([str(itm) for itm in src_emb.size()]))
     structure_emb = self.relation_embedding(
         batch["wr"].to(device))  # [batch, seq, seq, s_dim]
     assert len(structure_emb.size()) == 4, \
         'Invalid input size: {}, should be 4!'.format(
             ' x '.join([str(itm) for itm in structure_emb.size()]))
     assert src_emb.size(1) == structure_emb.size(1) \
         and src_emb.size(1) == structure_emb.size(2)
     # print('structure size:', structure_emb.size())
     # print("input_size", src_emb.size())
     if self.use_pe:
         src_emb = self.position_encoder(src_emb.transpose(0, 1)).transpose(
             0, 1)  # [batch, seq, dim]
     # print('src_emb_pos:', src_emb.size())
     src_emb = self.encoder(src_emb,
                            src_key_padding_mask=src_mask_inv,
                            structure=structure_emb)  # [batch, seq, dim]
     # exit()
     assert has_nan(src_emb) is False
     return src_emb
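The assertion above relies on a has_nan helper that is not shown in this example; a minimal sketch of what such a utility could look like (hypothetical, not necessarily the repository's own implementation):

    import torch

    def has_nan(tensor):
        # True if any element of the tensor is NaN
        return torch.isnan(tensor).any().item()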
Example No. 2
 def forward(self, batch):
     # encoding
     src_mask_inv = (batch["src_mask"] == 0).to(device)  # [batch, seq]
     # batch["src_mask_inv"] = src_mask_inv
     src_emb = self.embedding(batch["src"].to(device))  # [batch, seq, dim]
     # print("input_size", history_emb.size())
     src_emb = self.position_encoder(src_emb.transpose(0, 1)).transpose(
         0, 1)  # [batch, seq, dim]
     src_emb = self.encoder(
         src_emb, src_key_padding_mask=src_mask_inv)  # [batch, seq, dim]
     assert has_nan(src_emb) is False
     return src_emb
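For reference, src_key_padding_mask follows the usual PyTorch convention of being True at padding positions; a toy illustration of how src_mask and its inversion might be built (tensor values here are made up):

    import torch

    src = torch.tensor([[12, 7, 3, 0, 0]])   # [batch, seq]; 0 assumed to be the padding id
    src_mask = (src != 0).float()            # 1.0 at real tokens, 0.0 at padding
    src_mask_inv = src_mask == 0             # True at padded positions, as passed above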
Example No. 3
 def forward(self, batch):
     # encoding
     src_mask_inv = batch["con_mask"] == 0  # [batch, seq]
     batch["con_mask_inv"] = src_mask_inv
     src_emb = self.embedding(batch["con"])  # [batch, seq, dim]
     structure_emb = self.relation_embedding(
         batch["rel"])  # [batch, seq, seq, s_dim]
     # print('structure size:', structure_emb.size())
     # print("input_size", src_emb.size())
     if self.use_pe:
         src_emb = self.position_encoder(src_emb.transpose(0, 1)).transpose(
             0, 1)  # [batch, seq, dim]
     src_emb = self.encoder(src_emb,
                            src_key_padding_mask=src_mask_inv,
                            structure=structure_emb)  # [batch, seq, dim]
     batch["con_emb"] = src_emb
     assert has_nan(src_emb) is False
     return src_emb
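The structure embedding here is a plain lookup over pairwise relation ids; a small sketch under the assumption that relation_embedding is an nn.Embedding (names and sizes are illustrative):

    import torch
    import torch.nn as nn

    num_relations, s_dim = 16, 64                        # illustrative sizes
    relation_embedding = nn.Embedding(num_relations, s_dim)
    rel = torch.randint(0, num_relations, (2, 10, 10))   # [batch, seq, seq] relation ids
    structure_emb = relation_embedding(rel)              # [batch, seq, seq, s_dim]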
Example No. 4
 def forward(self, src_emb, src_length, relations):
     # src_emb:  [batch, seq, dim]
     # print("input_size", src_emb.size())
     bsz, max_seq_len, hid_size = src_emb.size()
     # print('max_seq_len:', max_seq_len)
     src_mask = torch.from_numpy(len_to_mask(src_length, max_seq_len)).to(
         src_emb.device)
     src_mask_inv = src_mask == 0
     relations = relations.squeeze(1)
     structure_emb = self.relation_embedding(
         relations)  # [batch, seq, seq, s_dim]
     # print('structure size:', structure_emb.size())
     if self.use_pe:
         src_emb = self.position_encoder(src_emb.transpose(0, 1)).transpose(
             0, 1)  # [batch, seq, dim]
     src_emb = self.encoder(src_emb,
                            src_key_padding_mask=src_mask_inv,
                            structure=structure_emb)  # [batch, seq, dim]
     assert has_nan(src_emb) is False
     return src_emb
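len_to_mask is assumed to turn per-example lengths into a NumPy padding mask; a possible minimal implementation (hypothetical, only meant to match how it is used above):

    import numpy as np

    def len_to_mask(lengths, max_len):
        # mask[i, j] = 1.0 for j < lengths[i], else 0.0
        mask = np.zeros((len(lengths), max_len), dtype=np.float32)
        for i, length in enumerate(lengths):
            mask[i, :length] = 1.0
        return mask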
Example No. 5
    def decode(
        self,
        sent_memory_emb,
        graph_memory_emb,
        sent_memory_mask,
        graph_memory_mask,
        beamsize,
        max_step,
    ):  # [batch, seq, dim]
        batch_size, sent_memory_seq, s_dim = list(sent_memory_emb.shape)
        if graph_memory_emb is not None:
            _, graph_memory_seq, g_dim = list(graph_memory_emb.shape)
        beam = [[BeamInstance(ids=[self.BOS], neg_logp=0.0, is_finish=False)]
                for i in range(batch_size)]
        cur_beamsize = 1
        for step in range(max_step):
            cur_seq = step + 1
            target_input = [[beam[i][j].ids for j in range(cur_beamsize)]
                            for i in range(batch_size)]  # [batch, beam, cur_seq]
            target_input = torch.tensor(target_input).to(device).view(
                batch_size * cur_beamsize, cur_seq)  # [batch*beam, cur_seq]
            target_emb = self.dec_word_embedding(target_input).transpose(
                0, 1)  # [cur_seq, batch*beam, dim]
            target_emb = self.position_encoder(target_emb).transpose(
                0, 1)  # [batch*beam, cur_seq, dim]

            cur_sent_memory_emb = (
                sent_memory_emb.unsqueeze(dim=1)
                .repeat(1, cur_beamsize, 1, 1)
                .view(batch_size * cur_beamsize, sent_memory_seq, s_dim)
            )  # [batch*beam, sent_memory_seq, dim]
            cur_sent_memory_mask_inv = (
                sent_memory_mask.unsqueeze(dim=1)
                .repeat(1, cur_beamsize, 1)
                .view(batch_size * cur_beamsize, sent_memory_seq) == 0
            )  # [batch*beam, sent_memory_seq]
            if graph_memory_emb is not None:
                cur_graph_memory_emb = (
                    graph_memory_emb.unsqueeze(dim=1)
                    .repeat(1, cur_beamsize, 1, 1)
                    .view(batch_size * cur_beamsize, graph_memory_seq, g_dim)
                )  # [batch*beam, graph_memory_seq, dim]
                cur_graph_memory_mask_inv = (
                    graph_memory_mask.unsqueeze(dim=1)
                    .repeat(1, cur_beamsize, 1)
                    .view(batch_size * cur_beamsize, graph_memory_seq) == 0
                )  # [batch*beam, graph_memory_seq]
            else:
                cur_graph_memory_emb = None
                cur_graph_memory_mask_inv = None
            cur_triu_mask = torch.triu(torch.ones(cur_seq, cur_seq).to(device),
                                       diagonal=1)  # [cur_seq, cur_seq]
            cur_triu_mask = cur_triu_mask.repeat(batch_size * cur_beamsize, 1,
                                                 1)
            cur_triu_mask.masked_fill_(cur_triu_mask == 1, -1e20)

            target_emb = self.decoder(
                target_emb,
                cur_sent_memory_emb,
                cur_graph_memory_emb,
                tgt_mask=cur_triu_mask,
                tgt_key_padding_mask=None,
                sent_memory_key_padding_mask=cur_sent_memory_mask_inv,
                graph_memory_key_padding_mask=cur_graph_memory_mask_inv,
            )  # [batch*beam, cur_seq, dim]

            assert has_nan(target_emb) is False

            # generating step outputs
            logits = self.projector(target_emb[:, -1, :]).view(
                batch_size, cur_beamsize,
                self.word_vocab_size)  # [batch, beam, vocab]
            logits = F.log_softmax(logits, dim=2)
            indices = logits.topk(
                beamsize + 1, dim=2)[1].cpu().numpy()  # [batch, beam, topk]

            all_finish = True
            next_beam = []
            for i in range(batch_size):
                cands = []
                for j in range(cur_beamsize):
                    if beam[i][j].is_finish:
                        # flag 0 sorts first, keeping finished hypotheses in front
                        cands.append((0, beam[i][j].neg_logp, self.EOS, j))
                    else:
                        # flag 1 marks unfinished hypotheses
                        for nid in indices[i, j]:
                            neg_logp = beam[i][j].neg_logp - logits[i, j, nid].item()
                            cands.append((1, neg_logp, int(nid), j))
                assert len(cands) >= beamsize
                assert sum([x[0] == 0 for x in cands]) <= beamsize, cands
                cands.sort()

                next_beam.append([])
                for _, neg_logp, nid, j in cands[:beamsize]:
                    is_finish = beam[i][j].is_finish or nid == self.EOS
                    all_finish &= is_finish
                    next_instance = BeamInstance(
                        ids=beam[i][j].ids + [nid],
                        neg_logp=neg_logp,
                        is_finish=is_finish)
                    next_beam[-1].append(next_instance)

            # preparing for the next loop
            if all_finish:
                break

            beam = next_beam
            cur_beamsize = beamsize

        best = []
        for i in range(batch_size):
            indices = np.argsort([x.get_logp_norm(self.EOS) for x in beam[i]])
            j = indices[0]
            best.append(beam[i][j].get_ids(self.EOS))
        return best
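decode relies on a BeamInstance container with get_ids and get_logp_norm; one way it could be defined, consistent with how it is used above (an assumption, not the original class):

    class BeamInstance:
        def __init__(self, ids, neg_logp, is_finish):
            self.ids = ids                # token ids, starting with BOS
            self.neg_logp = neg_logp      # accumulated negative log-probability
            self.is_finish = is_finish

        def get_ids(self, eos):
            # strip the leading BOS and everything from the first EOS onward
            ids = self.ids[1:]
            return ids[:ids.index(eos)] if eos in ids else ids

        def get_logp_norm(self, eos):
            # length-normalized score; lower is better, matching np.argsort above
            return self.neg_logp / max(len(self.get_ids(eos)), 1)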
Example No. 6
    def forward(self, batch):
        # encoding
        sent_mask_inv = batch["src_mask"] == 0  # [batch, seq]
        graph_mask_inv = batch["con_mask"] == 0
        sent_mem = self.word_encoder(batch)
        graph_mem = self.graph_encoder(
            batch) if self.graph_encoder is not None else None
        # decoding, batch['target'] includes both <bos> and <eos>
        if batch["tgt_ref"] is not None:
            target_input = batch["tgt_input"]  # [batch, trg_seq]
            target_ref = batch["tgt_ref"]  # [batch, trg_seq]
            bsz, trg_seq = target_input.size()
            triangle_mask = torch.triu(torch.ones(trg_seq, trg_seq).to(device),
                                       diagonal=1)  # [trg_seq, trg_seq]
            triangle_mask.masked_fill_(triangle_mask == 1, -1e20)
            triangle_mask = triangle_mask.repeat(bsz, 1, 1)
            target_mask_inv = batch["tgt_mask"] == 0  # [batch, trg_seq]

            target_emb = self.dec_word_embedding(target_input).transpose(
                0, 1)  # [trg_seq, batch, dim]
            target_emb = self.position_encoder(target_emb).transpose(
                0, 1)  # [batch, trg_seq, dim]
            # print('tgt_emb_size:', target_emb.size())
            # print('memory_size:', combined_emb.size())
            target_emb = self.decoder(
                target_emb,
                sent_mem,
                graph_mem,
                tgt_mask=triangle_mask,
                tgt_key_padding_mask=target_mask_inv,
                sent_memory_key_padding_mask=sent_mask_inv,
                graph_memory_key_padding_mask=graph_mask_inv,
            )  # [batch, trg_seq, dim]

            assert has_nan(target_emb) is False
            # generating outputs
            logits = self.projector(target_emb)  # [batch, trg_seq, vocab]
            preds = logits.argmax(dim=2)  # [batch, trg_seq]
            loss = F.cross_entropy(
                logits.contiguous().view(-1, self.word_vocab_size),
                target_ref.contiguous().view(-1),
                ignore_index=0,
            )
            train_right = ((preds == target_ref).float() *
                           batch["tgt_mask"]).sum()
            train_total = batch["tgt_mask"].sum()

            return {
                "preds": preds,
                "loss": loss,
                "counts": (train_right, train_total),
                "selected_kn": None,
                "trg_selected_kn": None,
            }
        else:
            return {
                "sent_memory_emb": sent_mem,
                "sent_memory_mask": batch["src_mask"],
                "graph_memory_emb": graph_mem,
                "graph_memory_mask": batch["con_mask"],
                "selected_kn": None,
                "trg_selected_kn": None,
            }
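A tiny worked example of the additive causal mask built with torch.triu in the training branch (sizes are illustrative):

    import torch

    trg_seq = 4
    triangle_mask = torch.triu(torch.ones(trg_seq, trg_seq), diagonal=1)
    triangle_mask.masked_fill_(triangle_mask == 1, -1e20)
    # row i now holds -1e20 at columns j > i, so position i cannot attend to future tokens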
Example No. 7
    def inference(
        self,
        sent_memory_emb,
        graph_memory_emb,
        sent_memory_mask,
        graph_memory_mask,
        max_step,
        use_sampling=False,
    ):
        batch_size, sent_memory_seq, dim = list(sent_memory_emb.shape)
        _, graph_memory_seq, _ = list(graph_memory_emb.shape)

        sent_memory_mask_inv = sent_memory_mask == 0  # [batch, sent_memory_seq]
        graph_memory_mask_inv = graph_memory_mask == 0  # [batch, graph_memory_seq]

        target_ids = [[self.BOS
                       for i in range(batch_size)]]  # [target_seq, batch]
        target_mask = [[1.0] for i in range(batch_size)]  # [batch, target_seq]
        target_prob = []  # [target_seq, batch]
        is_finish = [False for _ in range(batch_size)]
        rows = torch.arange(batch_size).to(device)
        for step in range(max_step):
            cur_seq = step + 1
            cur_emb = self.dec_word_embedding(
                torch.tensor(target_ids).to(device))  # [cur_seq, batch, dim]
            cur_emb = self.position_encoder(cur_emb).transpose(
                0, 1)  # [batch, cur_seq, dim]

            cur_mask = torch.tensor(target_mask).to(device)
            cur_mask_inv = cur_mask == 0.0  # [batch, cur_seq]
            cur_triu_mask = torch.triu(torch.ones(cur_seq, cur_seq).to(device),
                                       diagonal=1)  # [cur_seq, cur_seq]
            cur_triu_mask.masked_fill_(cur_triu_mask == 1, -1e20)

            cur_emb = self.decoder(
                cur_emb,
                sent_memory_emb,  # [batch, sent_len, dim]
                graph_memory_emb,  # [batch, graph_len, dim]
                tgt_mask=cur_triu_mask,
                tgt_key_padding_mask=cur_mask_inv,
                sent_memory_key_padding_mask=sent_memory_mask_inv,
                graph_memory_key_padding_mask=graph_memory_mask_inv,
            )  # [batch, cur_seq, dim]

            assert has_nan(cur_emb) is False

            # break after the first time when all items are finished
            if all(is_finish) or step == max_step - 1:
                cur_len = cur_mask.sum(dim=1).long()
                target_vec = universal_sentence_embedding(
                    cur_emb, cur_mask, cur_len)
                break

            # generating step outputs
            logits = self.projector(cur_emb[:, -1, :]).view(
                batch_size, self.word_vocab_size)  # [batch, vocab]
            if use_sampling is False:
                indices = logits.argmax(dim=1)  # [batch]
            else:
                indices = Categorical(logits=logits).sample()  # [batch]

            prob = F.softmax(logits, dim=1)[rows, indices]  # [batch]
            target_prob.append(prob)
            indices = indices.cpu().tolist()
            target_ids.append(indices)
            for i in range(batch_size):
                # mask is 0.0 if the sequence was already finished at the previous step
                target_mask[i].append(0.0 if is_finish[i] else 1.0)

            for i in range(batch_size):
                is_finish[i] |= indices[i] == self.EOS

        target_ids = list(map(list, zip(*target_ids[1:])))  # [batch, target_seq]
        target_mask = torch.tensor(
            [x[1:] for x in target_mask]).to(device)  # [batch, target_seq]
        target_prob = torch.stack(target_prob, dim=1)  # [batch, target_seq]
        return target_vec, target_ids, target_prob, target_mask
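universal_sentence_embedding is used here to pool the decoder states into a single vector; a minimal masked-mean sketch that matches the shapes above (assumed behaviour, not necessarily the original helper):

    import torch

    def universal_sentence_embedding(hidden, mask, lengths):
        # hidden: [batch, seq, dim], mask: [batch, seq], lengths: [batch]
        summed = (hidden * mask.unsqueeze(-1)).sum(dim=1)   # sum over unmasked positions
        return summed / lengths.clamp(min=1).unsqueeze(-1)  # masked mean, [batch, dim]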