def bienc(self, input, enc):
    # Forward pass over the input sequence.
    context, enc_h, M = enc(input)
    # Reverse pass. For a DNC encoder, carry the forward pass's memory M into
    # the reverse pass so the memory is also bidirectional; running the plain
    # reverse pass first and then re-running it (as before) wasted a full
    # encoder pass.
    if isinstance(self.encoder, dnc.DNC):
        context_rev, enc_h_rev, M_rev = enc(util.flip(input, 0), M)
    else:
        context_rev, enc_h_rev, M_rev = enc(util.flip(input, 0))
    '''
    if self.encoder.layers == 2:
        h_out = ((self.bd_h(torch.cat([context[0], enc_h_rev[0][0]], 1)),
                  self.bd_h(torch.cat([context[0], enc_h_rev[0][1]], 1))),
                 (self.bd_h(torch.cat([context[0], enc_h_rev[0][0]], 1)),
                  self.bd_h(torch.cat([context[0], enc_h_rev[0][1]], 1))))
    elif self.encoder.layers == 1:
        h_out = (self.bd_h(torch.cat([context[0], enc_h_rev[0][0]], 1)),
                 self.bd_h(torch.cat([context[0], enc_h_rev[0][1]], 1)))
    '''
    # Project the concatenated forward/backward contexts back to model size.
    context_out = self.bd_context(
        torch.cat((util.flip(context, dim=0), context_rev), 2)
        .view(-1, 2 * context.size(2)))
    if isinstance(self.encoder, nse.NSE):
        # NSE keeps an explicit memory; fuse the two directions' memories too.
        M_out = self.bd_m(
            torch.cat((util.flip(M, dim=1), M_rev), 2)
            .view(-1, 2 * M.size(2))).view(*M.size())
    else:
        M_out = M_rev
    return context_out.view(*context.size()), enc_h_rev, M_out
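# NOTE: `util.flip` is assumed here to reverse a tensor along a dimension.
# On the pre-0.4 PyTorch this code targets (no torch.flip, Variable still in
# use), a typical helper looks like the sketch below. This is an assumption
# about util.flip's behavior, not the project's actual implementation:
#
#     def flip(x, dim=0):
#         # Reverse x along `dim` via an index_select over reversed indices.
#         idx = torch.arange(x.size(dim) - 1, -1, -1).long()
#         return x.index_select(dim, idx)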
def nse_n2n(self, input):
    if self.brnn:
        context, enc_h, enc_M = self.bienc(input[0][0], self.nse_enc)
    else:
        context, enc_h, enc_M = self.nse_enc(util.flip(input[0][0]))
    # Padding mask over the source: True where the token index is 0 (pad).
    mask = input[0][0].t().eq(0).detach()
    # Memory-network style input (A) and output (C) memories, both derived
    # from the encoder's NSE memory.
    M = self.embed_A(enc_M)
    C = self.embed_C(enc_M)
    emb_out = self.embed_out(input[1][:-1])
    # u = Variable(emb_out.data.new(*emb_out.size()
    #              [1:]).zero_() + .1, requires_grad=True)
    u = enc_h[0][0]
    outputs = []
    for w in emb_out.split(1):
        # squeeze(0), not a bare squeeze(): squeeze() would also drop the
        # batch dimension when the batch size is 1 and break the cat below.
        dec_in = torch.cat((w.squeeze(0), u), 1)
        u = self.n2n_cat_feed(dec_in)
        out, U, O = self.decoder(u, M, C, mask)
        outputs += [out[0]]
    return torch.stack(outputs)
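# NOTE (assumption): `embed_A`, `embed_C`, `embed_out`, and `n2n_cat_feed`
# are defined elsewhere in this class. In an end-to-end memory network
# setup they would plausibly be declared in __init__ along these lines;
# a hypothetical sketch only, not the repository's actual definitions:
#
#     self.embed_A = nn.Linear(mem_dim, hidden_dim)       # input memory projection
#     self.embed_C = nn.Linear(mem_dim, hidden_dim)       # output memory projection
#     self.embed_out = nn.Embedding(vocab_size, emb_dim)  # target-side embeddings
#     self.n2n_cat_feed = nn.Linear(emb_dim + hidden_dim, hidden_dim)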
def nse_nse(self, input):
    if self.brnn:
        context, enc_h, enc_M = self.bienc(input[0][0], self.nse_enc)
    else:
        context, enc_h, enc_M = self.nse_enc(util.flip(input[0][0]))
    emb_out = self.embed_out(input[1][:-1])
    # Detach the encoder memory so gradients from the decoder's reads and
    # writes do not flow back into the encoder.
    dec_M = enc_M.detach()
    mask = util.flip(input[0][0]).transpose(0, 1).eq(0).detach()
    # self.update_queue(M)
    outputs, _, _ = self.decoder(emb_out, enc_h, (dec_M, mask), None)
    # self.update_queue(M, mask)
    return outputs
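# NOTE: the mask constructions above assume pad index 0 and (seq_len, batch)
# layout. A minimal sketch of what `.t().eq(0)` yields on toy data:
#
#     src = torch.LongTensor([[4, 7], [5, 0], [0, 0]])  # (seq_len=3, batch=2)
#     src.t().eq(0)
#     # -> [[False, False, True],   batch-major: True marks padding positions
#     #     [False, True,  True]]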
def nse_dnc(self, input):
    if self.brnn:
        context, enc_h, enc_M = self.bienc(input[0][0], self.nse_enc)
    else:
        context, enc_h, enc_M = self.nse_enc(util.flip(input[0][0]))
    emb_out = self.embed_out(input[1][:-1])
    init_M = self.decoder.make_init_M(emb_out.size(1))
    out, dec_hidden, mem = self.decoder(emb_out, enc_h, init_M, context)
    return out
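# NOTE (assumption): from the indexing in these methods, `input` appears to
# be a nested batch of the form ((src, ...), tgt), with src and tgt as
# (seq_len, batch) LongTensors; input[1][:-1] drops the final target token
# for teacher forcing. A hypothetical call, purely for illustration:
#
#     src = torch.LongTensor(7, 4).random_(1, vocab_size)  # no pads in this toy batch
#     tgt = torch.LongTensor(9, 4).random_(1, vocab_size)
#     out = model.nse_dnc(((src, None), tgt))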
def nse_lstm(self, input):
    if self.brnn:
        context, enc_h, M = self.bienc(input[0][0], self.nse_enc)
    else:
        context, enc_h, M = self.nse_enc(util.flip(input[0][0]))
    # Repack the NSE per-layer (h, c) tuples into the stacked
    # (num_layers, batch, dim) tensors the LSTM decoder expects.
    hidden = (torch.stack((enc_h[0][0], enc_h[1][0])),
              torch.stack((enc_h[0][1], enc_h[1][1])))
    init_output = self.make_init_hidden(enc_h[0][0], 1)[0]
    out, dec_hidden, _attn = self.decoder(
        input[1][:-1], hidden, context, init_output)
    return out
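# NOTE: the repacking in nse_lstm hard-codes two layers. An equivalent,
# layer-count-agnostic form (a sketch, assuming enc_h is a sequence of
# per-layer (h, c) pairs) would be:
#
#     hidden = tuple(torch.stack(states) for states in zip(*enc_h))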