예제 #1
0
	def forward(self, text, aa_info):
		'''
		Predict per-position outputs for a codon sequence.

		text:    token indices with "batch" and "seqlen" dims.
		aa_info: per-position context; position p is concatenated (along
		         "hiddenlen") onto the LSTM output of step p before self.linear.

		When self.teacher_force_prob == 1 the whole sequence goes through the
		LSTM in a single call, so step p consumes text[p]. Otherwise the LSTM is
		unrolled one step at a time; after step p the input for step p + 1 is
		either the ground-truth token text[p + 1] (with probability
		teacher_force_prob) or the masked argmax of the model's own output.

		Returns the output of self.linear with a "seqlen" dim.
		'''
		batch_size = text.shape["batch"]
		seq_len = text.shape["seqlen"]

		# Fresh zero state for every batch.
		h_0 = ntorch.zeros(batch_size, self.num_layers, self.hiddenlen,
			names=("batch", "layers", "hiddenlen")).to(self.device)
		c_0 = ntorch.zeros(batch_size, self.num_layers, self.hiddenlen,
			names=("batch", "layers", "hiddenlen")).to(self.device)

		if self.teacher_force_prob == 1:
			# Full teacher forcing: feed the whole ground-truth sequence at once.
			text_embedding = self.embedding(text)
			hidden_states, _ = self.LSTM(text_embedding, (h_0, c_0))
			output = self.linear_dropout(hidden_states)
			output = ntorch.cat([output, aa_info], dim="hiddenlen")
			return self.linear(output)

		# Partial teacher forcing: unroll the LSTM one position at a time.
		outputs = []
		model_input = text[{"seqlen": slice(0, 1)}]
		h_n, c_n = h_0, c_0
		for position in range(seq_len):
			text_embedding = self.embedding(model_input)
			hidden_states, (h_n, c_n) = self.LSTM(text_embedding, (h_n, c_n))

			output = self.linear_dropout(hidden_states)
			aa_info_subset = aa_info[{"seqlen": slice(position, position + 1)}]
			output = ntorch.cat([output, aa_info_subset], dim="hiddenlen")
			output = self.linear(output)
			outputs.append(output)

			# The last step has no successor, so no next input is needed.
			if position + 1 == seq_len:
				break

			if random.random() < self.teacher_force_prob:
				# Ground truth for the NEXT position, so step p always consumes
				# text[p] exactly like the full-teacher-forcing branch above.
				model_input = text[{"seqlen": slice(position + 1, position + 2)}]
			else:
				# Mask the prediction with the row of mask_tbl selected by the
				# current-position token (<start> at position 0).
				# NOTE(review): confirm mask_tbl is meant to be keyed on the
				# current token rather than the one being predicted.
				mask_targets = text[{"seqlen": slice(position, position + 1)}].clone()
				if position == 0:
					mask_targets[{"seqlen": 0}] = TEXT.vocab.stoi["<start>"]
				mask_bad_codons = ntorch.tensor(mask_tbl[mask_targets.values],
					names=("seqlen", "batch", "vocablen")).float()

				model_input = (output + mask_bad_codons).argmax("vocablen")

		return ntorch.cat(outputs, dim="seqlen")
예제 #2
0
	def forward(self, seq):
		'''
		Embed the amino-acid sequence and run it through the LSTM.

		The LSTM starts from an all-zero (h, c) pair with
		num_layers * num_directions rows. Only the per-position hidden states
		are returned; the final (h, c) pair is discarded.
		'''
		embedded = self.aa_embed(seq)
		state_rows = self.num_layers * self.num_directions

		def zero_state():
			# Allocate as (layers, batch, hiddenlen), then reorder batch-first.
			fresh = ntorch.zeros(state_rows, embedded.shape["batch"], self.hiddenlen,
				names=("layers", "batch", "hiddenlen")).to(self.device)
			return fresh.transpose("batch", "layers", "hiddenlen")

		outputs, (_final_h, _final_c) = self.LSTM(embedded, (zero_state(), zero_state()))
		return outputs
예제 #3
0
 def init_state(self, N):
     """Return a cached ``(h, c)`` pair of zero RNN states for batch size ``N``.

     Each tensor has dims ``("layers", "batch", "rnns")`` and sizes
     ``(self.nlayers, N, self.rnn_sz)``. The pair is reused across calls, and
     rebuilt when ``N`` changes or when ``self.lutx.weight`` has moved to a
     different device since the cache was built. Without the device check,
     calling ``.to(device)`` on the module would leave a stale cache behind.
     """
     device = self.lutx.weight.device
     if self._N != N or getattr(self, "_state_device", None) != device:
         self._N = N
         self._state_device = device
         self._state = tuple(
             ntorch.zeros(
                 self.nlayers, N, self.rnn_sz,
                 names=("layers", "batch", "rnns"),
             ).to(device)
             for _ in range(2)
         )
     return self._state
예제 #4
0
	def forward(self, seq):
		'''
		Build n-gram features for every position of `seq`.

		Position i receives sum(weight * ngram_dict[str(window)]) over the
		configured (weight, n, ngram_dict) triples. Here `window` is the
		amino-acid representation of the 2n + 1 tokens centred on i. The
		sequence is padded with self.longest_n "<pad>" tokens on both sides so
		that windows near the ends stay in range.

		Returns a ("seqlen", "batch", "vocablen") tensor on self.device whose
		"seqlen" size equals the unpadded input length.
		'''
		seq_len = seq.shape["seqlen"]        # UNPADDED length
		batch_size = seq.shape["batch"]
		pad = self.longest_n

		pad_token = self.text.vocab.stoi["<pad>"]
		additional_padding = ntorch.ones(batch_size, pad,
			names=("batch", "seqlen")).long().to(self.device)
		additional_padding *= pad_token

		seq = ntorch.cat([additional_padding, seq, additional_padding],
			dim="seqlen")

		amino_acids = self.codon_to_aa[seq.values]

		return_ar = ntorch.zeros(seq_len, batch_size, self.out_vocab,
			names=("seqlen", "batch", "vocablen"))

		# Move to numpy on the CPU for the python-level dictionary lookups.
		# Indexing below assumes the padded seq.values is (batch, seqlen) —
		# TODO confirm the dim order produced by ntorch.cat.
		amino_acids = amino_acids.detach().cpu().numpy()
		for batch_item in range(batch_size):
			# Real tokens occupy [pad, pad + seq_len) in the padded sequence.
			for seq_item in range(pad, pad + seq_len):
				for weight, n, ngram_dict in zip(self.weight_list,
						self.n_list, self.dict_list):
					# One amino-acid representation per row of the window.
					n_gram = amino_acids[batch_item, seq_item - n: seq_item + n + 1]

					# Index the output in unpadded coordinates.
					return_ar[{"seqlen": seq_item - pad,
						"batch": batch_item}] += weight * ngram_dict[str(n_gram)].float()

		return return_ar.to(self.device)
예제 #5
0
def evaluate(model, batches):
    """Evaluate ``model`` on ``batches`` with gradients disabled.

    Returns ``(total summed NLL / number of examples, accuracy in percent)``.

    Side effects:
    - puts ``model`` in eval mode;
    - when ``args.algo == "vae"``, accumulates and prints a
      label-by-``model.K`` table of summed ``e`` values;
    - when ``args.algo == "attn"`` and ``args.visualize_freq`` is set, saves
      an attention plot for the first example of every ``visualize_freq``-th
      batch.
    """
    model.eval()
    with torch.no_grad():
        loss_fn = ntorch.nn.NLLLoss(reduction='sum').spec('label')
        total_loss = 0
        total_num = 0
        num_correct = 0
        if args.algo == "vae":
            # Row i of ``identity`` is the one-hot vector for label i.
            identity = ntorch.NamedTensor(torch.eye(len(LABEL.vocab)), names=("index", "label"))
            q_sum = ntorch.zeros(len(LABEL.vocab), model.K, names=("label", "model"))
        for i, batch in enumerate(batches):
            log_probs, e = model.forward(
                batch.premise, batch.hypothesis)
            preds = log_probs.argmax('label')
            total_loss += loss_fn(log_probs, batch.label).item()
            num_correct += get_correct(preds, batch.label)
            total_num += len(batch)

            if args.algo == "vae":
                # Both accumulators were created on the CPU; keep them on the
                # batch's device (a no-op when it already matches).
                device = batch.label.values.device
                identity = identity.to(device)
                q_sum = q_sum.to(device)
                q_sum = q_sum + identity.index_select("index", batch.label).dot("batch", e)
            if args.algo == "attn" and args.visualize_freq and i % args.visualize_freq == 0:
                fname = './img/' + args.save_img
                visualize_attn(e[{'batch': 0}], batch.premise[{'batch': 0}], batch.hypothesis[{'batch': 0}], save_name=fname)

        if args.algo == "vae":
            q_sum = q_sum / q_sum.mean("model")
            # Print the data itself (``Tensor.values`` is a sparse-tensor method).
            print(q_sum.values.cpu().numpy())
        return total_loss / total_num, 100.0 * num_correct / total_num
예제 #6
0
	def forward(self, text, aa_info):
		'''
		Predict per-position outputs for a codon sequence.

		text:    token indices with "batch" and "seqlen" dims.
		aa_info: per-position context; position p is concatenated (along
		         "hiddenlen") onto the LSTM output of step p before self.linear.

		When self.teacher_force_prob == 1 the whole sequence goes through the
		LSTM in a single call, so step p consumes text[p]. Otherwise the LSTM is
		unrolled one step at a time; after step p the input for step p + 1 is
		either the ground-truth token text[p + 1] (with probability
		teacher_force_prob) or the argmax of the model's own output.

		Returns the output of self.linear with a "seqlen" dim.
		'''
		batch_size = text.shape["batch"]
		seq_len = text.shape["seqlen"]

		# Fresh zero state for every batch.
		h_0 = ntorch.zeros(batch_size, self.num_layers, self.hiddenlen,
			names=("batch", "layers", "hiddenlen")).to(self.device)
		c_0 = ntorch.zeros(batch_size, self.num_layers, self.hiddenlen,
			names=("batch", "layers", "hiddenlen")).to(self.device)

		if self.teacher_force_prob == 1:
			# Full teacher forcing: feed the whole ground-truth sequence at once.
			text_embedding = self.embedding(text)
			hidden_states, _ = self.LSTM(text_embedding, (h_0, c_0))
			output = self.linear_dropout(hidden_states)
			output = ntorch.cat([output, aa_info], dim="hiddenlen")
			return self.linear(output)

		# Partial teacher forcing: unroll the LSTM one position at a time.
		outputs = []
		model_input = text[{"seqlen": slice(0, 1)}]
		h_n, c_n = h_0, c_0
		for position in range(seq_len):
			text_embedding = self.embedding(model_input)
			hidden_states, (h_n, c_n) = self.LSTM(text_embedding, (h_n, c_n))

			output = self.linear_dropout(hidden_states)
			aa_info_subset = aa_info[{"seqlen": slice(position, position + 1)}]
			output = ntorch.cat([output, aa_info_subset], dim="hiddenlen")
			output = self.linear(output)
			outputs.append(output)

			# The last step has no successor, so no next input is needed.
			if position + 1 == seq_len:
				break

			if random.random() < self.teacher_force_prob:
				# Ground truth for the NEXT position, so step p always consumes
				# text[p] exactly like the full-teacher-forcing branch above.
				model_input = text[{"seqlen": slice(position + 1, position + 2)}]
			else:
				# TODO: Should we be masking this output?
				model_input = output.argmax("vocablen")

		return ntorch.cat(outputs, dim="seqlen")
예제 #7
0
	def forward(self, seq):
		'''
		Embed the sequence with its first position forced to self.start_index,
		then run it through the LSTM from an all-zero (h, c) state.

		Returns the per-position LSTM hidden states; the final (h, c) pair is
		discarded.
		'''
		# Overwrite the first token on a copy so the caller's tensor is untouched.
		seq = seq.clone()
		seq[{"seqlen": 0}] = self.start_index

		embedded = self.aa_embed(seq)
		state_rows = self.num_layers * self.num_directions

		def zero_state():
			# Allocate as (layers, batch, hiddenlen), then reorder batch-first.
			fresh = ntorch.zeros(state_rows, embedded.shape["batch"], self.hiddenlen,
				names=("layers", "batch", "hiddenlen")).to(self.device)
			return fresh.transpose("batch", "layers", "hiddenlen")

		outputs, (_final_h, _final_c) = self.LSTM(embedded, (zero_state(), zero_state()))
		return outputs
예제 #8
0
파일: vae.py 프로젝트: wfus/namedtensor
def loss_function(recon_x, x, var_posterior):
    """VAE loss: summed binary cross-entropy plus KL to a standard normal.

    ``x`` has its ("ch", "height", "width") dims stacked into one dim ``h``
    before being compared to ``recon_x``. The KL term is
    ``KL(var_posterior || Normal(0, 1))`` summed over all elements.
    """
    flat_x = x.stack(h=("ch", "height", "width"))
    reconstruction = recon_x.reduce2(
        flat_x,
        lambda pred, target: F.binary_cross_entropy(pred, target, reduction="sum"),
        ("batch", "x"),
    )

    standard_normal = ndistributions.Normal(
        ntorch.zeros(dict(batch=1, z=1)),
        ntorch.ones(dict(batch=1, z=1)),
    )
    divergence = ndistributions.kl_divergence(var_posterior, standard_normal).sum()

    return reconstruction + divergence
예제 #9
0
 def pe(self):
     """Build the sinusoidal positional-encoding table.

     Returns a ``(MAX_LEN, self.d_model)`` NamedTensor with dims
     ``(self.dim_length, self.dim_hidden)``. Column ``2i`` holds
     ``sin(pos * 10000 ** (-2i / d_model))`` and column ``2i + 1`` holds the
     matching cosine. ``self.d_model`` must be even: for an odd value the
     cosine half would index one column past the end.
     """
     pe = ntorch.zeros(
         MAX_LEN, self.d_model, names=(self.dim_length, self.dim_hidden)
     )
     position = ntorch.arange(0, MAX_LEN, names=self.dim_length).float()
     shift = ntorch.arange(0, self.d_model, 2, names=self.dim_hidden)
     div_term = ntorch.exp(
         shift.float() * -(math.log(10000.0) / self.d_model)
     )
     # position and div_term carry different dim names, so the product
     # broadcasts to a (length, hidden / 2) table.
     val = ntorch.mul(position, div_term)
     pe[{self.dim_hidden: shift}] = val.sin()
     pe[{self.dim_hidden: shift + 1}] = val.cos()
     return pe
예제 #10
0
파일: sol_clean.py 프로젝트: emtseng/cs6741
    def __init__(self, vocab, num_classes, padding_idx):
        """Set up the per-token class-score table, class bias, and loss.

        ``self.lut`` maps each of the ``len(vocab.itos)`` tokens to
        ``num_classes`` scores along a "classes" dim. ``self.bias`` is a
        NamedTensor over "classes" whose storage is the registered parameter
        ``self.bias_param``, so the optimizer updates it. ``self.loss_fn`` is a
        summed NLL loss reduced over ("batch", "classes").
        """
        super(LR, self).__init__()
        self.lut = ntorch.nn.Embedding(
            len(vocab.itos), num_classes, padding_idx=padding_idx,
        ).augment('classes')

        # Named view for arithmetic, backed by a registered Parameter.
        self.bias = ntorch.zeros(num_classes, names=("classes", ))
        self.bias_param = nn.Parameter(self.bias._tensor)
        self.bias._tensor = self.bias_param

        summed_nll = ntorch.nn.NLLLoss(reduction='sum')
        self.loss_fn = summed_nll.reduce(('batch', 'classes'))
예제 #11
0
    def beam(self, src, trg, k, beam_len, num_candidates):
        """Beam search over whole-batch hypotheses.

        A hypothesis is a tensor of token indices with a 'trgSeqlen' dim,
        seeded with ``trg[{'trgSeqlen': slice(0, 1)}]``. At every step each of
        the ``k`` best hypotheses (all of them at step 0) is extended with its
        ``k`` most likely next tokens. An extension whose new token equals
        ``EOS_IND`` for some batch element is also recorded in ``end``. There it
        keeps its score only where EOS was emitted, and ``-inf`` elsewhere.

        Args:
            src: source batch passed to ``self.encoder``.
            trg: target batch; only its first position (and its length when
                ``beam_len`` is falsy) is used.
            k: beam width.
            beam_len: number of decoding steps; falls back to
                ``trg.shape['trgSeqlen'] - 1`` when falsy.
            num_candidates: number of best hypotheses returned.

        Returns:
            ``(candidates, attn)``. ``candidates`` holds the per-step
            log-distributions of the best hypotheses, stacked along a new
            'candidates' dim. ``attn`` is the 'attn' entry of the last decoder
            state, or ``None`` when that state has no 'attn' entry.
        """
        batch_size = src.shape['batch']
        # map a hypothesis to its per-step log-distribution over words
        out_dists = HypothesisMap(device=self.device)
        # map a hypothesis to its score
        scores = HypothesisMap(
            keys=[trg[{'trgSeqlen': slice(0, 1)}]],
            vals=[ntorch.zeros(batch_size, names='batch')],
            device=self.device)
        # special buffer for hypotheses that just emitted <EOS>
        end = HypothesisMap(device=self.device)
        attn = []
        EOS_IND = 3

        enc_hidden = self.encoder(src)
        dec_hidden = enc_hidden

        for step in range(beam_len or trg.shape['trgSeqlen'] - 1):
            new_scores = HypothesisMap(device=self.device)
            hyps = scores.get_topk(k) if step > 0 else scores
            for hyp, score in hyps.items():
                # Decode this hypothesis's whole prefix from the encoder state.
                # A single shared recurrent state would leak one hypothesis's
                # history into the next. Assumes the decoder accepts a
                # multi-step 'trgSeqlen' input and does not mutate
                # ``enc_hidden`` — TODO confirm.
                out, dec_hidden = self.decoder(hyp, enc_hidden)
                out = out.log_softmax('logit')
                topk = out[{'trgSeqlen': slice(step, step + 1)}].topk('logit', k)

                for i in range(k):
                    # Clone so the masking below does not write into ``topk``.
                    pred_prob = topk[0][{'logit': i, 'trgSeqlen': -1}].clone()
                    pred = topk[1][{'logit': i}]
                    new_hyp = ntorch.cat([hyp, pred], 'trgSeqlen')

                    # ``out`` already covers every step of the extended prefix.
                    out_dists[new_hyp] = out

                    last_tok = pred[{'trgSeqlen': -1}]
                    if torch.any((last_tok == EOS_IND).values):
                        end_score = score + pred_prob
                        end_score.masked_fill_(last_tok != EOS_IND, -float('inf'))
                        end[new_hyp] = end_score
                        pred_prob.masked_fill_(last_tok == EOS_IND, -float('inf'))
                    new_scores[new_hyp] = score + pred_prob
            scores = new_scores

        for hyp, score in end.items():
            scores[hyp] = score
        best = scores.get_topk(num_candidates).keys
        out = [out_dists[key] for key in best]

        # store attention from the last decoder call, if the state exposes it
        if 'attn' in dec_hidden:
            attn.append(dec_hidden['attn'])
        attn_out = ntorch.cat(attn, dim='trgSeqlen') if attn else None

        return ntorch.stack(out, 'candidates'), attn_out
예제 #12
0
    def trajectories(self, N=100, dt=0.02):
        """Simulate ``N`` trajectories inside ``self.scene`` and subsample them.

        The motion is integrated for ``n = int(T / dt)`` steps:
        - rotation velocities are drawn from ``Normal(mean_rotation, std_dev_rotation)``;
        - forward speeds are drawn from ``Rayleigh(std_dev_forward)``;
        - the initial speed is ``v0``, the initial heading is uniform in [0, 2*pi),
          and the initial position comes from ``self.scene.random(N)``.

        At each step the closest wall is queried. Where ``dist < perimeter`` and
        ``|phi| < pi/2``:
        - ``sign(phi) * (pi/2 - |phi|)`` is added to that step's rotation;
        - the speed is the previous speed scaled by ``(1 - velocity_reduction)``.
        Elsewhere the freshly drawn speed is used.

        ``trajectory_length`` evenly spaced step indices in ``[0, n - 2]`` are
        kept.

        Args:
            N: number of trajectories (the "sample" dim).
            dt: integration time step.

        Returns:
            ``(xs, hd, vel, xs0, hd0)``, each transposed so "sample" comes first:
            - xs:  positions at the kept steps, dims (sample, t, ax).
            - hd:  heading angles (atan2 of the direction vector) one step after
                   each kept step, dims (sample, t, hd).
            - vel: stacked (speed, cos(dphi), sin(dphi)) with
                   ``dphi = rotation * dt``, dims (sample, t, input).
            - xs0: a NEW random position from ``self.scene.random`` — not the
                   simulated ``positions[t=0]`` (see NOTE below), dims (sample, ax).
            - hd0: heading angle at step 0, dims (sample, hd).
        """
        perimeter = self.params['perimeter']
        T = self.params["T"]
        n = int(T / dt)
        mu, sigma, b = [
            self.params[i]
            for i in ["mean_rotation", "std_dev_rotation", "std_dev_forward"]
        ]

        # Pre-draw all per-step noise; both arrays are (n, N).
        rotation_velocities = torch.tensor(
            np.random.normal(mu, sigma, size=(n, N))).float()
        forward_velocities = torch.tensor(np.random.rayleigh(
            b, size=(n, N))).float()

        positions = ntorch.zeros((n, 2, N), names=("t", "ax", "sample"))
        vs = torch.zeros((n, N))
        # ``angles`` aliases ``rotation_velocities``: the wall corrections
        # below are written into the drawn rotations in place.
        angles = rotation_velocities
        directions = torch.zeros((n, 2, N))

        vs[0] = self.params["v0"]
        theta = torch.rand(N) * 2 * np.pi
        directions[0] = unit_vector(theta)
        positions[{
            "t": 0
        }] = ntorch.tensor(self.scene.random(N), names=("sample", "ax"))

        for i in range(1, n):
            # Query the closest wall from the previous position and heading.
            dist, phi = self.scene.closestWall(positions[{
                "t": i - 1
            }].values, directions[i - 1])
            wall = (dist < perimeter) & (phi.abs() < np.pi / 2)
            angle_correction = torch.where(
                wall,
                phi.sign() * (np.pi / 2 - phi.abs()), torch.zeros_like(phi))
            angles[i] += angle_correction

            # Near a wall: slow down from the previous speed; else use the drawn speed.
            vs[i] = torch.where(
                wall,
                (1 - self.params["velocity_reduction"]) * (vs[i - 1]),
                forward_velocities[i],
            )
            # Euler step: move along the previous heading at the current speed.
            positions[{
                "t": i
            }] = (positions[{
                "t": i - 1
            }] + directions[i - 1] * vs[i] * dt)

            # Rotate each sample's direction by its own rotation matrix
            # (``mat`` indexed [i, j, sample] per the einsum).
            mat = rotation_matrix(angles[i] * dt)
            directions[i] = torch.einsum("ijk,jk->ik", mat, directions[i - 1])

        # Evenly spaced indices up to n - 2 so that ``idx + 1`` stays in range.
        idx = np.round(np.linspace(
            0, n - 2, self.params["trajectory_length"])).astype(int)
        # idx = np.array(sorted(np.random.choice(np.arange(n), size=self.params["trajectory_length"], replace=False)))

        dphis = ntorch.tensor(angles[idx] * dt, names=("t", "sample"))
        velocities = ntorch.tensor(vs[idx], names=("t", "sample"))
        vel = ntorch.stack((velocities, dphis.cos(), dphis.sin()), "input")

        xs = ntorch.tensor(positions.values[idx], names=("t", "ax", "sample"))
        # NOTE(review): xs0 is drawn independently of the simulated start
        # positions[t=0] (the commented-out line below). Confirm this is intended.
        # xs0 = positions[{'t': 0}]
        xs0 = ntorch.tensor(self.scene.random(n=N), names=("sample", "ax"))

        # Heading angle at every step; hd0 is step 0, hd is one step after each kept index.
        hd = torch.atan2(directions[:, 1], directions[:, 0])
        hd0 = ntorch.tensor(hd[0][None], names=("hd", "sample"))
        hd = ntorch.tensor(hd[idx + 1][None], names=("hd", "t", "sample"))

        xs = xs.transpose('sample', 't', 'ax')
        hd = hd.transpose('sample', 't', 'hd')
        vel = vel.transpose('sample', 't', 'input')
        xs0 = xs0.transpose('sample', 'ax')
        hd0 = hd0.transpose('sample', 'hd')

        return xs, hd, vel, xs0, hd0
예제 #13
0
    def forward(self,
                source,
                target=None,
                teacher_forcing=1.,
                max_length=20,
                encode_only=False):
        """Encode ``source`` and decode with a main and a syntax decoder.

        At each step the main decoder's log-probabilities are combined with
        ``self.syntax_decoder``'s log-probabilities by addition. The main decoder
        optionally uses dot-product attention over the encoder outputs.

        Args:
            source: source token batch.
            target: optional gold target. When given, its 'trgSeqlen' size
                overrides ``max_length``, it is used for teacher forcing, and
                the gold token (not the sample) is scored at each step.
            teacher_forcing: probability of feeding the gold token as the next
                decoder input (only has an effect when ``target`` is given).
            max_length: number of output positions when ``target`` is None.
            encode_only: if True, return ``(score, out, (h, c), output_seq)``
                right after encoding, with zero-filled ``score``/``output_seq``.

        Returns:
            ``(output_seq, output_dists, score)``. Position 0 of each is left at
            zero; decoding starts from the ``<s>`` token.
        """
        # ``target`` is a tensor: test for presence, not tensor truthiness.
        has_target = target is not None
        if has_target:
            max_length = target.shape["trgSeqlen"]
        x = self.in_embedding(source)
        out, (h, c) = self.encoder(x)
        # Concatenate the first two "layers" slices of the encoder state along
        # "rnnOutput" to form the decoder's initial state.
        h = ntorch.cat((h[{"layers": slice(0, 1)}], h[{"layers": slice(1, 2)}]),
                       dim="rnnOutput")
        c = ntorch.cat((c[{"layers": slice(0, 1)}], c[{"layers": slice(1, 2)}]),
                       dim="rnnOutput")

        if self.attention:

            def attend(x_t):
                # Dot-product attention of the decoder output over encoder outputs.
                alpha = out.dot("rnnOutput", x_t).softmax("srcSeqlen")
                context = alpha.dot("srcSeqlen", out)
                return context

        batch_size = source.shape["batch"]
        output_dists = ntorch.zeros(
            (batch_size, max_length, self.out_vocab_size),
            names=("batch", "trgSeqlen", "outVocab"),
            device=device)
        # NOTE: position 0 stays 0 rather than holding the SOS index.
        output_seq = ntorch.zeros((batch_size, max_length),
                                  names=("batch", "trgSeqlen"),
                                  device=device)

        score = ntorch.zeros((batch_size, max_length),
                             names=("batch", "trgSeqlen"),
                             device=device)

        if encode_only:
            return score, out, (h, c), output_seq

        for t in range(max_length - 1):
            if t == 0:
                # always start with SOS token
                next_input = ntorch.ones((batch_size, 1),
                                         names=("batch", "trgSeqlen"),
                                         device=device).long()
                next_input *= EN.vocab.stoi["<s>"]
            elif np.random.random() < teacher_forcing and has_target:
                # we will force
                next_input = target[{"trgSeqlen": slice(t, t + 1)}]
            else:
                next_input = sample

            x_t, (h, c) = self.decoder(self.out_embedding(next_input), (h, c))

            if t == 0:
                syntax_out, (s_h, s_c) = self.syntax_decoder(
                    self.out_embedding(next_input))
            else:
                syntax_out, (s_h, s_c) = self.syntax_decoder(
                    self.out_embedding(next_input), (s_h, s_c))

            if self.attention:
                fc = self.fc(ntorch.cat([attend(x_t), x_t], dim="rnnOutput"))
            else:
                fc = self.fc(x_t)

            s_fc = self.syntax_fc(syntax_out).sum("trgSeqlen")
            s_fc = s_fc.log_softmax("outVocab")

            dist = ntorch.distributions.Categorical(logits=fc,
                                                    dim_logit="outVocab")
            sample = dist.sample()

            fc = fc.sum("trgSeqlen")

            # Token scored at this step: gold when a target is given, else the sample.
            next_token = (target[{"trgSeqlen": slice(t + 1, t + 2)}]
                          if has_target else sample)

            # Add the syntax decoder's log-probabilities to the main decoder's.
            fc = fc.log_softmax("outVocab") + s_fc

            # Score of each batch element's own token: fc[b, token_b]. The
            # previous index_select chain read column 0 of the
            # (batch, indices) matrix, i.e. token_0 for every b.
            tokens = next_token.sum("trgSeqlen")
            log_probs = fc.transpose("batch", "outVocab").values
            picked = log_probs.gather(
                1, tokens.values.long().unsqueeze(1)).squeeze(1)
            newsc = ntorch.NamedTensor(picked, names=("batch",))

            score[{"trgSeqlen": t + 1}] = newsc

            output_seq[{"trgSeqlen": t + 1}] = tokens
            output_dists[{"trgSeqlen": t + 1}] = fc

        return output_seq, output_dists, score
예제 #14
0
파일: crnnlma.py 프로젝트: justinchiu/genie
    def forward(self,
                x,
                s,
                x_info,
                r,
                r_info,
                vt,
                ue,
                ue_info,
                ut,
                ut_info,
                v2d,
                vt2d,
                y=None):
        """Run the conditional RNN LM over ``x`` with attention over records.

        Attention is computed separately over entity (``ue``) and type (``ut``)
        embeddings. The two attention distributions are summed in log space,
        exponentiated, and used to read from the value grid ``v2dx``.

        Returns:
            ``(self.proj(self.drop(out)), s)`` — unnormalized vocabulary scores
            and the final RNN state.

        Side effects: stores attention diagnostics on ``self.ea``, ``self.ta``
        and ``self.a``:
        - without input feeding: the entity/type log-weights and the combined
          weights;
        - with input feeding: per-step lists of detached weights.
        """
        emb = self.lutx(x)
        N = emb.shape["batch"]
        T = emb.shape["time"]

        e = self.lute(r[0]).rename("e", "r")
        t = self.lutt(r[1]).rename("t", "r")
        v = self.lutv(r[2]).rename("v", "r")
        # r: R x N x Er, Wa r: R x N x H
        r = self.War(ntorch.cat([e, t, v], dim="r").tanh())
        eA = self.Wae(self.lute(ue))
        tA = self.Wat(self.lutt(ut))
        if self.v2d:
            v2dx = self.lutv(v2d.stack(
                ("t", "e"), "els")).chop("els", ("t", "e"),
                                         t=v2d.shape["t"]).rename("v", "rnns")
        else:
            # vt2dx: chop with vt2d's own "t" size, since this branch
            # otherwise does not use v2d.
            v2dx = self.lutx(vt2d.stack(
                ("t", "e"),
                "time")).chop("time", ("t", "e"),
                              t=vt2d.shape["t"]).rename("x", "rnns")

        if not self.inputfeed:
            # rnn_o: T x N x H
            rnn_o, s = self.rnn(emb, s, x_info.lengths)
            # ea: T x N x R
            log_e, ea_E, ec_E = attn(rnn_o, eA, ue_info.mask)
            log_t, ea_T, ec_T = attn(rnn_o, tA, ut_info.mask)
            le = log_e.rename("els", "e")
            lt = log_t.rename("els", "t")
            self.ea = le
            self.ta = lt
            aw = (le + lt).exp()
            self.a = aw
            if self.noattn:
                # Record mean repeated over time instead of attention over
                # values. (Previously this read ``ec`` before assigning it
                # and was overwritten unconditionally afterwards.)
                ec = r.mean("els").repeat("time", T)
            else:
                ec = aw.dot(("t", "e"), v2dx)
            # Concatenate entity and type contexts (plus value context unless
            # self.noattnvalues) onto the RNN output.
            out = (self.Wc_nov(ntorch.cat([rnn_o, ec_E, ec_T], "rnns"))
                   if self.noattnvalues else self.Wc(
                       ntorch.cat([rnn_o, ec, ec_E, ec_T], "rnns"))).tanh()
        else:
            out = []
            self.ea = []
            self.ta = []
            self.a = []
            ec_ETt = ntorch.zeros(N, self.r_emb_sz,
                                  names=("batch",
                                         "rnns")).to(emb.values.device)
            # ``step`` rather than ``t`` to avoid shadowing the type embedding.
            for step in range(T):
                # Feed the previous step's entity + type context with the token.
                ec_ETt = ec_ETt.rename("rnns", "x")
                inp = ntorch.cat([emb.get("time", step), ec_ETt],
                                 "x").repeat("time", 1)
                rnn_o, s = self.rnn(inp, s)
                rnn_o = rnn_o.get("time", 0)
                log_e, ea_Et, ec_Et = attn(rnn_o, eA, ue_info.mask)
                log_t, ea_Tt, ec_Tt = attn(rnn_o, tA, ut_info.mask)
                ec_ETt = ec_Et + ec_Tt
                le = log_e.rename("els", "e")
                lt = log_t.rename("els", "t")
                aw = (le + lt).exp()
                ect = aw.dot(("t", "e"), v2dx)
                out.append(ntorch.cat([rnn_o, ect, ec_Et, ec_Tt], "rnns"))
                self.ea.append(ea_Et.detach())
                self.ta.append(ea_Tt.detach())
                self.a.append(aw.detach())
            out = self.Wc(ntorch.stack(out, "time")).tanh()

        # return unnormalized vocab
        return self.proj(self.drop(out)), s
예제 #15
0
    def pa0(self, emb_x, s, x_info, emb_e, ue_info, emb_t, ut_info, v2dx):
        """Run the RNN over ``emb_x`` and attend over entities, types, and values.

        Without input feeding the RNN runs over the whole sequence at once.
        With ``self.inputfeed`` it is stepped one time step at a time. The sum
        of the previous step's entity and type contexts (zeros at step 0) is
        concatenated onto each step's input.

        Value attention is the sum of the entity and type log-weights,
        exponentiated. It is used to read from ``v2dx``.

        Returns:
            ``(log_ea, ea, ec, log_ta, ta, tc, log_va, va, vc, output, s)``:
            - ``log_*a``, ``*a`` and ``*c`` are the log-weights, weights and
              contexts from ``attn`` (entity, type, value);
            - ``output`` is the RNN output (no input feeding) or the stacked
              ``self.Wif`` features (input feeding);
            - ``s`` is the final RNN state.
        """
        T = emb_x.shape["time"]
        N = emb_x.shape["batch"]

        if not self.inputfeed:
            # rnn_o: T x N x H
            rnn_o, s = self.rnn(emb_x, s, x_info.lengths)
            # ea: T x N x R
            log_ea, ea, ec = attn(rnn_o, emb_e, ue_info.mask)
            log_ta, ta, tc = attn(rnn_o, emb_t, ut_info.mask)
            if self.noattn:
                # NOTE(review): ``r`` is neither a parameter nor assigned in this
                # method, so this raises NameError unless a module-level ``r``
                # exists. TODO: decide which record summary to use here.
                ec = r.mean("els").repeat("time", ec.shape["time"])
            log_ea = log_ea.rename("els", "e")
            log_ta = log_ta.rename("els", "t")
            log_va = log_ea + log_ta
            va = log_va.exp()
            vc = va.dot(("t", "e"), v2dx)

            ea = ea.rename("els", "e")
            ta = ta.rename("els", "t")

            output = rnn_o
        else:
            out = []
            per_step = []
            etc_t = ntorch.zeros(
                N, self.r_emb_sz, names=("batch", "rnns")
            ).to(emb_x.values.device)
            for step in range(T):
                etc_t = etc_t.rename("rnns", "x")
                inp = ntorch.cat([emb_x.get("time", step), etc_t], "x").repeat("time", 1)
                rnn_o, s = self.rnn(inp, s)
                rnn_o = rnn_o.get("time", 0)
                log_ea_t, ea_t, ec_t = attn(rnn_o, emb_e, ue_info.mask)
                log_ta_t, ta_t, tc_t = attn(rnn_o, emb_t, ut_info.mask)
                # Context fed back at the next step; without this update the
                # initial zeros would be fed at every step.
                etc_t = ec_t + tc_t
                log_ea_t = log_ea_t.rename("els", "e")
                log_ta_t = log_ta_t.rename("els", "t")
                log_va_t = log_ea_t + log_ta_t
                va_t = log_va_t.exp()
                vc_t = va_t.dot(("t", "e"), v2dx)
                out.append(
                    self.Wif(ntorch.cat([rnn_o, vc_t, ec_t, tc_t], "rnns"))
                )
                per_step.append((log_ea_t, ea_t, ec_t,
                                 log_ta_t, ta_t, tc_t,
                                 log_va_t, va_t, vc_t))

            output = ntorch.stack(out, "time")

            # Transpose the per-step tuples and stack each series over "time".
            (log_ea, ea, ec,
             log_ta, ta, tc,
             log_va, va, vc) = [ntorch.stack(list(series), "time")
                                for series in zip(*per_step)]

            ea = ea.rename("els", "e")
            ta = ta.rename("els", "t")

        return log_ea, ea, ec, log_ta, ta, tc, log_va, va, vc, output, s