def forward(self, inputs, mask):
    input_vecs = self.embed(inputs.long())
    outputs, _ = model_utils.get_rnn_vecs(
        input_vecs, mask, self.lstm, bidir=True)
    sent_vec = self.attn(inputs, mask, outputs, self.weight_embed)
    return sent_vec
def forward(self, inputs, mask, temp=None):
    input_vecs = self.dropout(self.embed(inputs.long()))
    outputs, _ = model_utils.get_rnn_vecs(
        input_vecs, mask, self.lstm, bidir=True)
    # Zero out padded positions, then mean-pool over the valid time steps.
    outputs = self.dropout(outputs) * mask.unsqueeze(-1)
    sent_vec = outputs.sum(1) / mask.sum(1, keepdim=True)
    return input_vecs, sent_vec
def get_vecs(self, inputs, mask):
    input_vecs = self.embed(inputs.long())
    outputs, _ = model_utils.get_rnn_vecs(
        input_vecs, mask, self.lstm, bidir=True)
    outputs = outputs * mask.unsqueeze(-1)
    sent_vec = outputs.sum(1) / mask.sum(1, keepdim=True)
    return sent_vec, outputs
def get_vecs(self, inputs, mask):
    input_vecs = self.embed(inputs.long())
    outputs, _ = model_utils.get_rnn_vecs(
        input_vecs, mask, self.lstm, bidir=True)
    outputs = outputs * mask.unsqueeze(-1)
    sent_vec = self.attn(inputs, mask, outputs, self.weight_embed)
    return sent_vec, outputs
def forward(self, inputs, mask):
    input_vecs = self.embed(inputs.long())
    outputs, _ = model_utils.get_rnn_vecs(
        input_vecs, mask, self.lstm, bidir=True)
    outputs = outputs * mask.unsqueeze(-1)
    sent_vec = outputs.sum(1) / mask.sum(1, keepdim=True)
    return sent_vec
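# A minimal, self-contained sketch of the masked mean pooling used by the
# encoders above, with dummy tensors standing in for the BiLSTM outputs.
# All sizes here are illustrative assumptions, not model configuration.
import torch

def masked_mean_pool(outputs, mask):
    # outputs: batch x seq_len x hidden; mask: batch x seq_len (1=token, 0=pad)
    outputs = outputs * mask.unsqueeze(-1)             # zero padded positions
    return outputs.sum(1) / mask.sum(1, keepdim=True)  # average valid steps

outputs = torch.randn(2, 4, 8)
mask = torch.tensor([[1., 1., 1., 0.], [1., 1., 0., 0.]])
sent_vec = masked_mean_pool(outputs, mask)  # shape: 2 x 8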
def forward(self, inputs, mask, temp=None):
    input_vecs = self.dropout(self.embed(inputs.long()))
    # Replace each word vector with a soft (softmax-weighted) mixture of
    # code-book embeddings, one mixture per latent-code projection.
    lc_vecs = []
    for proj, emb in zip(self.lc, self.lc_embed):
        prob = F.softmax(proj(input_vecs), -1)
        lc_vecs.append(torch.matmul(prob, emb.weight))
    input_vecs = self.dropout(torch.cat(lc_vecs, -1))
    outputs, _ = model_utils.get_rnn_vecs(
        input_vecs, mask, self.lstm, bidir=True)
    outputs = self.dropout(outputs) * mask.unsqueeze(-1)
    sent_vec = outputs.sum(1) / mask.sum(1, keepdim=True)
    return input_vecs, sent_vec
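# Sketch of one step of the soft code-book lookup above: the projection
# scores the code-book entries and the result is a probability-weighted
# mixture of their embeddings. The layer sizes are illustrative assumptions;
# proj and emb stand in for one (self.lc, self.lc_embed) pair.
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, n_codes, code_dim = 8, 10, 6
proj = nn.Linear(hidden, n_codes)
emb = nn.Embedding(n_codes, code_dim)

x = torch.randn(2, 4, hidden)               # batch x seq_len x hidden
prob = F.softmax(proj(x), -1)               # batch x seq_len x n_codes
soft_code = torch.matmul(prob, emb.weight)  # batch x seq_len x code_dim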
def pred(self, init_state, tgts, tgts_mask):
    bs, sl = tgts_mask.size()
    # Word dropout on the teacher-forced decoder inputs.
    inp_tgts = self.drop_word(tgts)
    tgts_embed = self.embed(inp_tgts.long())
    output_seq, _ = model_utils.get_rnn_vecs(
        tgts_embed, tgts_mask, self.cell, initial_state=init_state)
    # batch size x seq len x vocab size
    if not self.tie_weight:
        pred = self.hid2vocab(output_seq)[:, :-1, :]
    else:
        # Tied weights: reuse the input embedding matrix as the output
        # projection.
        pred = torch.matmul(output_seq, self.embed.weight.t())[:, :-1, :]
    return pred
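# Sketch of the tie_weight branch above: the decoder reuses the
# (vocab x dim) input embedding matrix as its output projection, so the
# vocabulary logits come from a matmul with its transpose. Dimensions are
# illustrative assumptions.
import torch
import torch.nn as nn

vocab, dim = 100, 16
embed = nn.Embedding(vocab, dim)
output_seq = torch.randn(2, 5, dim)                  # batch x seq_len x dim
logits = torch.matmul(output_seq, embed.weight.t())  # batch x seq_len x vocab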
def pred(self, yvecs, zvecs, tgts, tgts_mask):
    bs, sl = tgts_mask.size()
    tgts_embed = self.dropout(self.embed(tgts.long()))
    # Broadcast the sentence-level latent vectors across all time steps:
    # z is concatenated to the decoder inputs, y to the decoder outputs.
    ex_input_vecs = zvecs.unsqueeze(1).expand(-1, sl, -1)
    ex_output_vecs = yvecs.unsqueeze(1).expand(-1, sl, -1)
    input_vecs = torch.cat([tgts_embed, ex_input_vecs], -1)
    ori_output_seq, _ = model_utils.get_rnn_vecs(
        input_vecs, tgts_mask, self.cell, bidir=False, initial_state=None)
    output_seq = torch.cat([ori_output_seq, ex_output_vecs], -1)
    # batch size x seq len x vocab size
    pred = self.hid2vocab(self.dropout(output_seq))[:, :-1, :]
    return pred, input_vecs
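# Sketch of the latent broadcasting used above: unsqueeze(1) adds a time
# axis and expand repeats the sentence-level vector across all sl steps
# without copying memory. Sizes are illustrative assumptions.
import torch

batch, sl, latent = 2, 5, 6
zvecs = torch.randn(batch, latent)
ex_input_vecs = zvecs.unsqueeze(1).expand(-1, sl, -1)  # batch x sl x latent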
def forward(self, yvecs, zvecs, tgts, tgts_mask, *args, **kwargs):
    bs, sl = tgts_mask.size()
    # The decoder input at every step is the z latent alone (no word
    # embeddings); y is concatenated to the decoder outputs.
    ex_input_vecs = zvecs.unsqueeze(1).expand(-1, sl, -1)
    ex_output_vecs = yvecs.unsqueeze(1).expand(-1, sl, -1)
    ori_output_seq, _ = model_utils.get_rnn_vecs(
        ex_input_vecs, tgts_mask, self.cell)
    output_seq = torch.cat([ori_output_seq, ex_output_vecs], -1)
    # batch size x seq len x vocab size
    pred = self.hid2vocab(output_seq)[:, :-1, :]
    batch_size, seq_len, vocab_size = pred.size()
    pred = pred.contiguous().view(batch_size * seq_len, vocab_size)
    # Per-token cross entropy; reduction='none' replaces the deprecated
    # reduce=False.
    logloss = F.cross_entropy(
        pred, tgts[:, 1:].contiguous().view(-1).long(), reduction='none')
    # Mask out padding, average per sentence, then over the batch.
    logloss = (logloss.view(batch_size, seq_len)
               * tgts_mask[:, 1:]).sum(-1) / tgts_mask[:, 1:].sum(-1)
    return logloss.mean()
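# Self-contained sketch of the masked sequence loss above: per-token cross
# entropy with reduction='none', padding zeroed out by the mask, averaged
# per sentence and then over the batch. Shapes are illustrative assumptions
# (pred already flattened, targets already shifted).
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 4, 10
pred = torch.randn(batch * seq_len, vocab)
tgts = torch.randint(vocab, (batch, seq_len))
tgts_mask = torch.tensor([[1., 1., 1., 0.], [1., 1., 0., 0.]])

logloss = F.cross_entropy(pred, tgts.view(-1), reduction='none')
logloss = (logloss.view(batch, seq_len) * tgts_mask).sum(-1) / tgts_mask.sum(-1)
loss = logloss.mean()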
def pred(self, yvecs, zvecs, tgts, tgts_mask):
    bs, sl = tgts_mask.size()
    tgts_embed = self.dropout(self.embed(tgts.long()))
    # Project the concatenated latents to the decoder's initial state.
    init_vecs = self.latent2init(torch.cat([yvecs, zvecs], -1))
    if isinstance(self.cell, nn.LSTM):
        # An LSTM needs both a hidden and a cell state: split the
        # projection in half for (h0, c0), each 1 x batch x hidden.
        init_vecs = tuple(
            h.unsqueeze(0).contiguous()
            for h in torch.chunk(init_vecs, 2, -1))
    input_vecs = tgts_embed
    ori_output_seq, _ = model_utils.get_rnn_vecs(
        input_vecs, tgts_mask, self.cell, bidir=False,
        initial_state=init_vecs)
    output_seq = ori_output_seq
    # batch size x seq len x vocab size
    pred = self.hid2vocab(self.dropout(output_seq))[:, :-1, :]
    return pred, torch.cat(
        [tgts_embed, zvecs.unsqueeze(1).expand(-1, sl, -1)], -1)
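# Sketch of the LSTM initial-state construction above: project the latent
# vector to 2 * hidden, chunk it into (h0, c0), and add the leading
# num_layers dimension nn.LSTM expects. Sizes are illustrative assumptions.
import torch
import torch.nn as nn

batch, latent, hidden = 2, 12, 8
latent2init = nn.Linear(latent, 2 * hidden)
init_vecs = latent2init(torch.randn(batch, latent))
h0, c0 = (h.unsqueeze(0).contiguous() for h in torch.chunk(init_vecs, 2, -1))
# h0, c0: 1 x batch x hidden, usable as nn.LSTM's initial (h_0, c_0).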