Code example #1
    def forward(self, inputs, mask):
        input_vecs = self.embed(inputs.long())
        outputs, _ = model_utils.get_rnn_vecs(input_vecs,
                                              mask,
                                              self.lstm,
                                              bidir=True)
        sent_vec = self.attn(inputs, mask, outputs, self.weight_embed)
        return sent_vec
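Every example in this section calls a project-local helper, `model_utils.get_rnn_vecs`, whose definition is not shown; the snippets also presuppose the usual `import torch`, `import torch.nn as nn`, and `import torch.nn.functional as F`. For orientation only, here is a minimal sketch of what such a helper typically does (pack the padded batch, run the RNN, unpack back to a padded tensor); the signature is taken from the calls above, but the body is an assumption, not the project's implementation:

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    def get_rnn_vecs(input_vecs, mask, cell, bidir=False, initial_state=None):
        # hypothetical stand-in: sequence lengths come from the binary mask
        lengths = mask.sum(1).long().cpu()
        packed = pack_padded_sequence(input_vecs, lengths, batch_first=True,
                                      enforce_sorted=False)
        outputs, state = cell(packed, initial_state)
        # `bidir` is assumed informational here; direction handling lives in
        # how `cell` was constructed (e.g. nn.LSTM(..., bidirectional=True))
        outputs, _ = pad_packed_sequence(outputs, batch_first=True,
                                         total_length=mask.size(1))
        return outputs, state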
Code example #2
    def forward(self, inputs, mask, temp=None):
        input_vecs = self.dropout(self.embed(inputs.long()))
        outputs, _ = model_utils.get_rnn_vecs(input_vecs,
                                              mask,
                                              self.lstm,
                                              bidir=True)
        # masked mean pooling: zero padded steps, divide by true length
        outputs = self.dropout(outputs) * mask.unsqueeze(-1)
        sent_vec = outputs.sum(1) / mask.sum(1, keepdim=True)
        return input_vecs, sent_vec
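The pooling here (and in code examples #3, #5, and #6 below) is a masked mean: `mask.unsqueeze(-1)` zeroes the padded positions and the sum is divided by each sentence's true token count, so padding never leaks into the sentence vector. A small self-contained check of that arithmetic, with illustrative sizes:

    import torch

    outputs = torch.randn(2, 4, 8)            # batch x seq_len x hidden
    mask = torch.tensor([[1., 1., 1., 0.],    # sentence 1: 3 real tokens
                         [1., 1., 0., 0.]])   # sentence 2: 2 real tokens
    masked = outputs * mask.unsqueeze(-1)     # zero out the padded steps
    sent_vec = masked.sum(1) / mask.sum(1, keepdim=True)
    assert torch.allclose(sent_vec[0], outputs[0, :3].mean(0))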
Code example #3
    def get_vecs(self, inputs, mask):
        input_vecs = self.embed(inputs.long())
        outputs, _ = model_utils.get_rnn_vecs(input_vecs,
                                              mask,
                                              self.lstm,
                                              bidir=True)
        outputs = outputs * mask.unsqueeze(-1)
        sent_vec = outputs.sum(1) / mask.sum(1, keepdim=True)
        return sent_vec, outputs
Code example #4
    def get_vecs(self, inputs, mask):
        input_vecs = self.embed(inputs.long())
        outputs, _ = model_utils.get_rnn_vecs(input_vecs,
                                              mask,
                                              self.lstm,
                                              bidir=True)
        outputs = outputs * mask.unsqueeze(-1)
        sent_vec = self.attn(inputs, mask, outputs, self.weight_embed)
        return sent_vec, outputs
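Code examples #1 and #4 pool with `self.attn(inputs, mask, outputs, self.weight_embed)`, both defined elsewhere in the class. One plausible reading of that signature is an attention pooler that scores each position by looking the token id up in a scalar weight embedding (an `nn.Embedding(vocab_size, 1)`), masks and normalizes the scores, and returns the weighted sum of the LSTM states. The module below only illustrates that reading; it is not the project's `attn`:

    import torch
    import torch.nn as nn

    class TokenWeightAttnPool(nn.Module):
        # hypothetical pooler matching the attn(inputs, mask, outputs,
        # weight_embed) call above
        def forward(self, inputs, mask, outputs, weight_embed):
            scores = weight_embed(inputs.long()).squeeze(-1)  # batch x seq_len
            scores = scores.masked_fill(mask == 0, float('-inf'))
            probs = torch.softmax(scores, -1)    # padding gets zero weight
            return (outputs * probs.unsqueeze(-1)).sum(1)     # batch x hidden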
Code example #5
    def forward(self, inputs, mask):
        input_vecs = self.embed(inputs.long())
        outputs, _ = model_utils.get_rnn_vecs(input_vecs,
                                              mask,
                                              self.lstm,
                                              bidir=True)
        outputs = outputs * mask.unsqueeze(-1)
        sent_vec = outputs.sum(1) / mask.sum(1, keepdim=True)
        return sent_vec
Code example #6
    def forward(self, inputs, mask, temp=None):
        input_vecs = self.dropout(self.embed(inputs.long()))
        # replace each word vector with a concatenation of latent-code mixtures
        lc_vecs = []
        for proj, emb in zip(self.lc, self.lc_embed):
            prob = F.softmax(proj(input_vecs), -1)
            lc_vecs.append(torch.matmul(prob, emb.weight))
        input_vecs = self.dropout(torch.cat(lc_vecs, -1))
        outputs, _ = model_utils.get_rnn_vecs(input_vecs,
                                              mask,
                                              self.lstm,
                                              bidir=True)
        outputs = self.dropout(outputs) * mask.unsqueeze(-1)
        sent_vec = outputs.sum(1) / mask.sum(1, keepdim=True)
        return input_vecs, sent_vec
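The loop above replaces each word vector with a concatenation of latent-code mixtures: each `proj` maps the word vector to a distribution over one code inventory, and `torch.matmul(prob, emb.weight)` is a soft embedding lookup, i.e. the probability-weighted average of that inventory's embeddings. A shape-level sketch of a single code, with illustrative sizes:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    input_vecs = torch.randn(2, 5, 100)      # batch x seq_len x embed_dim
    proj = nn.Linear(100, 16)                # 16 codes in this inventory
    emb = nn.Embedding(16, 32)               # one 32-dim vector per code

    prob = F.softmax(proj(input_vecs), -1)   # batch x seq_len x 16
    lc_vec = torch.matmul(prob, emb.weight)  # batch x seq_len x 32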
Code example #7
    def pred(self, init_state, tgts, tgts_mask):
        bs, sl = tgts_mask.size()
        inp_tgts = self.drop_word(tgts)
        tgts_embed = self.embed(inp_tgts.long())

        output_seq, _ = model_utils.get_rnn_vecs(tgts_embed,
                                                 tgts_mask,
                                                 self.cell,
                                                 initial_state=init_state)
        # batch size x seq len x vocab size; the final step is dropped since
        # predictions are scored against targets shifted one position left
        if not self.tie_weight:
            pred = self.hid2vocab(output_seq)[:, :-1, :]
        else:
            pred = torch.matmul(output_seq, self.embed.weight.t())[:, :-1, :]
        return pred
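With `tie_weight` set, the decoder reuses the input embedding matrix as the output projection instead of a separate `hid2vocab` layer, which requires the RNN hidden size to equal the embedding size. A quick shape check of the tied path, with illustrative sizes:

    import torch
    import torch.nn as nn

    embed = nn.Embedding(10000, 512)         # vocab x embed_dim
    output_seq = torch.randn(2, 7, 512)      # batch x seq_len x hidden
    logits = torch.matmul(output_seq, embed.weight.t())
    assert logits.shape == (2, 7, 10000)     # batch x seq_len x vocab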
Code example #8
    def pred(self, yvecs, zvecs, tgts, tgts_mask):
        bs, sl = tgts_mask.size()
        tgts_embed = self.dropout(self.embed(tgts.long()))
        ex_input_vecs = zvecs.unsqueeze(1).expand(-1, sl, -1)
        ex_output_vecs = yvecs.unsqueeze(1).expand(-1, sl, -1)

        input_vecs = torch.cat([tgts_embed, ex_input_vecs], -1)
        ori_output_seq, _ = model_utils.get_rnn_vecs(input_vecs,
                                                     tgts_mask,
                                                     self.cell,
                                                     bidir=False,
                                                     initial_state=None)
        output_seq = torch.cat([ori_output_seq, ex_output_vecs], -1)
        # batch size x seq len x vocab size
        pred = self.hid2vocab(self.dropout(output_seq))[:, :-1, :]
        return pred, input_vecs
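`unsqueeze(1).expand(-1, sl, -1)` broadcasts each per-sentence latent across all `sl` timesteps as a view (no copy), so every decoder input step sees `[word embedding; z]`, while `y` is appended to the hidden states just before the vocabulary projection. A minimal illustration of the broadcast, with illustrative sizes:

    import torch

    zvecs = torch.randn(2, 32)                      # batch x latent_dim
    sl = 7
    ex_z = zvecs.unsqueeze(1).expand(-1, sl, -1)    # batch x sl x latent_dim
    tgts_embed = torch.randn(2, sl, 300)
    input_vecs = torch.cat([tgts_embed, ex_z], -1)  # batch x sl x 332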
Code example #9
    def forward(self, yvecs, zvecs, tgts, tgts_mask, *args, **kwargs):
        bs, sl = tgts_mask.size()
        ex_input_vecs = zvecs.unsqueeze(1).expand(-1, sl, -1)
        ex_output_vecs = yvecs.unsqueeze(1).expand(-1, sl, -1)

        ori_output_seq, _ = model_utils.get_rnn_vecs(ex_input_vecs, tgts_mask,
                                                     self.cell)
        output_seq = torch.cat([ori_output_seq, ex_output_vecs], -1)
        # batch size x seq len x vocab size
        pred = self.hid2vocab(output_seq)[:, :-1, :]

        batch_size, seq_len, vocab_size = pred.size()

        pred = pred.contiguous().view(batch_size * seq_len, vocab_size)
        logloss = F.cross_entropy(pred,
                                  tgts[:, 1:].contiguous().view(-1).long(),
                                  reduction='none')

        logloss = (logloss.view(batch_size, seq_len) *
                   tgts_mask[:, 1:]).sum(-1) / tgts_mask[:, 1:].sum(-1)
        return logloss.mean()
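The loss flattens the logits, scores the targets shifted one position left (step t predicts token t+1, which is also why `pred` drops its last step), and averages per sentence over real tokens only before taking the batch mean. The same masked averaging in isolation, with illustrative sizes:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 6, 100)          # batch x (seq_len - 1) x vocab
    targets = torch.randint(100, (2, 6))
    mask = torch.tensor([[1., 1., 1., 1., 0., 0.],
                         [1., 1., 1., 0., 0., 0.]])

    flat = F.cross_entropy(logits.reshape(-1, 100), targets.reshape(-1),
                           reduction='none').view(2, 6)
    per_sent = (flat * mask).sum(-1) / mask.sum(-1)  # mean over real tokens
    loss = per_sent.mean()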
Code example #10
    def pred(self, yvecs, zvecs, tgts, tgts_mask):
        bs, sl = tgts_mask.size()
        tgts_embed = self.dropout(self.embed(tgts.long()))
        init_vecs = self.latent2init(torch.cat([yvecs, zvecs], -1))

        if isinstance(self.cell, nn.LSTM):
            # split into (h0, c0) and add the leading (num_layers,)
            # dimension that nn.LSTM expects for its initial state
            init_vecs = tuple(
                h.unsqueeze(0).contiguous()
                for h in torch.chunk(init_vecs, 2, -1))

        ori_output_seq, _ = model_utils.get_rnn_vecs(tgts_embed,
                                                     tgts_mask,
                                                     self.cell,
                                                     bidir=False,
                                                     initial_state=init_vecs)
        # batch size x seq len x vocab size
        pred = self.hid2vocab(self.dropout(ori_output_seq))[:, :-1, :]
        return pred, torch.cat(
            [tgts_embed, zvecs.unsqueeze(1).expand(-1, sl, -1)], -1)
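`latent2init` maps the concatenated `y` and `z` latents to the decoder's initial state; when the cell is an LSTM the projection is split into `h0` and `c0` halves. A shape-level sketch of that wiring, with illustrative sizes:

    import torch
    import torch.nn as nn

    latent2init = nn.Linear(64, 2 * 512)     # latents -> h0 and c0 halves
    latents = torch.randn(2, 64)             # batch x (y_dim + z_dim)
    h0, c0 = [h.unsqueeze(0).contiguous()    # each 1 x batch x 512
              for h in torch.chunk(latent2init(latents), 2, -1)]

    lstm = nn.LSTM(300, 512, batch_first=True)
    out, _ = lstm(torch.randn(2, 7, 300), (h0, c0))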