def forward(self, sent1, sent2, labels=None):
    """Notation straight from paper this time"""
    a_, b_ = self.input(sent1, sent2)

    # Attention
    F_a_ = self.attend(a_)
    F_b_ = self.attend(b_)
    e = F_a_.dot('hidden', F_b_)
    alpha = e.softmax(dim='seqlenA').dot('seqlenA', a_)
    beta = e.softmax(dim='seqlenB').dot('seqlenB', b_)

    # Comparison
    v1 = self.compare(ntorch.cat([a_, beta], 'embedding'))
    v2 = self.compare(ntorch.cat([b_, alpha], 'embedding'))

    # Aggregation
    v1 = v1.sum('seqlenA')
    v2 = v2.sum('seqlenB')
    output = self.aggregate(ntorch.cat([v1, v2], 'hidden'))

    if self.use_labels:
        assert labels is not None
        y = ntorch.tensor(labels.values.unsqueeze(1),
                          names=('batch', 'hidden')).cuda()
        output = self.labelled_output(
            ntorch.cat([output, y.float()], 'hidden'))

    y_hat = self.output(output)
    return y_hat, F_a_, F_b_
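# Minimal standalone sketch (illustration only, not part of the model above) of
# the named-dimension attention pattern used in this forward pass: dot over
# 'hidden' to score every pair of positions, softmax over one sequence
# dimension, then dot over that same dimension to take the weighted sum.
# All sizes below are made up for the example.
import torch
from namedtensor import ntorch

a = ntorch.tensor(torch.randn(2, 5, 8), names=('batch', 'seqlenA', 'hidden'))
b = ntorch.tensor(torch.randn(2, 7, 8), names=('batch', 'seqlenB', 'hidden'))

e = a.dot('hidden', b)                          # batch x seqlenA x seqlenB
beta = e.softmax('seqlenB').dot('seqlenB', b)   # b aligned to each position of a
alpha = e.softmax('seqlenA').dot('seqlenA', a)  # a aligned to each position of b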
def forward(self, text):
    # Concatenate the trainable and static embedding tables
    embeddings = ntorch.cat(
        [self.lut(text), self.static_lut(text)],
        'embedding').transpose('embedding', 'seqlen')
    # One max-over-time feature vector per convolution block
    feature_list = [
        conv_block(embeddings).relu().max("seqlen")[0]
        for conv_block in self.conv_blocks
    ]
    hidden = ntorch.cat(feature_list, "embedding")
    preds = self.proj(self.dropout(hidden)).log_softmax("classes")
    return preds
def forward(self, text, aa_info):
    '''Pass in context for the next amino acid'''
    # Reset the hidden and cell states for each new batch
    h_0 = ntorch.zeros(text.shape["batch"], self.num_layers, self.hiddenlen,
                       names=("batch", "layers", "hiddenlen")).to(self.device)
    c_0 = ntorch.zeros(text.shape["batch"], self.num_layers, self.hiddenlen,
                       names=("batch", "layers", "hiddenlen")).to(self.device)

    # If we should use all the sequence as input
    if self.teacher_force_prob == 1:
        text_embedding = self.embedding(text)
        hidden_states, (h_n, c_n) = self.LSTM(text_embedding, (h_0, c_0))
        output = self.linear_dropout(hidden_states)
        output = ntorch.cat([output, aa_info], dim="hiddenlen")
        output = self.linear(output)
    # If we should use some combination of teacher forcing
    else:
        outputs = []
        model_input = text[{"seqlen": slice(0, 1)}]
        h_n, c_n = h_0, c_0
        for position in range(text.shape["seqlen"]):
            text_embedding = self.embedding(model_input)
            hidden_states, (h_n, c_n) = self.LSTM(text_embedding, (h_n, c_n))
            output = self.linear_dropout(hidden_states)
            aa_info_subset = aa_info[{"seqlen": slice(position, position + 1)}]
            output = ntorch.cat([output, aa_info_subset], dim="hiddenlen")
            output = self.linear(output)
            outputs.append(output)

            # Define the next input: feed the gold codon with probability
            # teacher_force_prob, otherwise the model's own (masked) prediction
            if random.random() < self.teacher_force_prob:
                model_input = text[{"seqlen": slice(position, position + 1)}]
            else:
                # Mask the output logits using the codon mask for the target amino acid
                mask_targets = text[{"seqlen": slice(position, position + 1)}].clone()
                if position == 0:
                    mask_targets[{"seqlen": 0}] = TEXT.vocab.stoi["<start>"]
                mask_bad_codons = ntorch.tensor(
                    mask_tbl[mask_targets.values],
                    names=("seqlen", "batch", "vocablen")).float()
                model_input = (output + mask_bad_codons).argmax("vocablen")
        output = ntorch.cat(outputs, dim="seqlen")
    return output
def _shift_trg(self, trg):
    # Prepend <s> to every sequence and drop the final token, so the decoder
    # input is the gold target shifted right by one position
    start_of_sent = [[BOS_IND] * trg.shape['batch']]
    start_of_sent = ntorch.tensor(start_of_sent, names=('trgSeqlen', 'batch'))
    end_of_sent = trg[{'trgSeqlen': slice(0, trg.shape['trgSeqlen'] - 1)}]
    shifted = ntorch.cat((start_of_sent, end_of_sent), 'trgSeqlen')
    return shifted
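# Usage sketch (illustration only; BOS_IND and the token values are assumed,
# not taken from the original module). With a batch of two length-3 targets,
# the shift prepends <s> and drops the last token of each sequence.
from namedtensor import ntorch

BOS_IND = 2  # assumed index of <s> in the target vocabulary
trg = ntorch.tensor([[5, 7], [6, 8], [4, 4]], names=('trgSeqlen', 'batch'))

start = ntorch.tensor([[BOS_IND] * trg.shape['batch']],
                      names=('trgSeqlen', 'batch'))
rest = trg[{'trgSeqlen': slice(0, trg.shape['trgSeqlen'] - 1)}]
shifted = ntorch.cat((start, rest), 'trgSeqlen')
# shifted.values is [[2, 2], [5, 7], [6, 8]]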
def forward(self, seq):
    seq_len = seq.shape["seqlen"]
    batch_size = seq.shape["batch"]
    pad_token = self.text.vocab.stoi["<pad>"]

    # Pad both ends with <pad> so every position has a full n-gram context
    additional_padding = ntorch.ones(batch_size, self.longest_n,
                                     names=("batch", "seqlen")).long().to(self.device)
    additional_padding *= pad_token
    seq = ntorch.cat([additional_padding, seq, additional_padding], dim="seqlen")
    amino_acids = self.codon_to_aa[seq.values]

    return_ar = ntorch.zeros(seq_len, batch_size, self.out_vocab,
                             names=("seqlen", "batch", "vocablen"))

    # convert to numpy to leave GPU
    amino_acids = amino_acids.detach().cpu().numpy()
    for batch_item in range(batch_size):
        # start at n, end at seq_len - n
        for seq_item in range(self.longest_n, seq_len - self.longest_n):
            # Must iterate over all dictionaries
            for weight, n, ngram_dict in zip(self.weight_list, self.n_list,
                                             self.dict_list):
                # N gram is a 2d numpy array containing an amino acid embedding in each row
                n_gram = amino_acids[batch_item, seq_item - n: seq_item + n + 1]
                # note, we want to populate the return ar before padding!
                return_ar[{"seqlen": seq_item - self.longest_n,
                           "batch": batch_item}] += weight * ngram_dict[str(n_gram)].float()
    return return_ar.to(self.device)
def forward(self, x):
    # Densely connected layers: each layer sees the concatenation of the
    # input and every previous layer's output along 'h'
    inputs = [x]
    for layer in self.layers:
        output = layer(ntorch.cat(inputs, 'h'))
        inputs.append(output)
    return inputs[-1]
def make_n_gram_dict(train_iter, n, amino_acid_conversion, TEXT, AA_LABEL):
    '''
    Helper function to create a frequency default dictionary.

    Args:
        train_iter: Training bucket iterator
        n: Number of amino acids to each side of the AA (e.g. 0 is unigram, 1 is trigram)
        amino_acid_conversion: index table converting the codon index to an AA index
        TEXT: torchtext field for the vocab of nucleotides
        AA_LABEL: torchtext field for the amino acids
    Returns:
        default_dict: dictionary mapping a sequence of amino acids to a
            probability distribution over codons

    TODO: Make this faster
    '''
    default_obj = lambda: torch.tensor(np.zeros(len(TEXT.vocab.stoi)))
    default_dict = defaultdict(default_obj)
    with torch.no_grad():
        ident_mat = np.eye(len(TEXT.vocab.stoi))
        ident_mat_aa = np.eye(len(AA_LABEL.vocab))
        for i, batch in enumerate(train_iter):
            # Use this to find all indices that aren't padding
            seq_len = batch.sequence.shape["seqlen"]
            batch_size = batch.sequence.shape["batch"]

            # Pad amino acids and seq with the <pad> token
            pad_token = TEXT.vocab.stoi["<pad>"]
            additional_padding = ntorch.ones(batch_size, n,
                                             names=("batch", "seqlen")).long()
            additional_padding *= pad_token
            seq = ntorch.cat([additional_padding, batch.sequence, additional_padding],
                             dim="seqlen")

            amino_acids = amino_acid_conversion[seq.values].detach().cpu().numpy()

            # Note: we should assert that start and pad are treated the same,
            # because at test time we presumably narrow the start for the AA
            if i == 0:
                assert (amino_acids[0, n] == amino_acids[0, 0]).all()

            seq = seq.detach().cpu().numpy()
            for batch_item in range(batch_size):
                # start at n, end at seq_len - n
                for seq_item in range(n, seq_len - n):
                    # Middle token is a discrete number representing the codon (0 to 66)
                    middle_token = seq[batch_item, seq_item]
                    # N gram is a 2d numpy array containing an amino acid embedding in each row
                    n_gram = amino_acids[batch_item, seq_item - n: seq_item + n + 1]
                    default_dict[str(n_gram)][middle_token] += 1

    # Normalize counts into per-context codon probabilities
    for key in default_dict:
        default_dict[key] /= default_dict[key].sum()
    return default_dict
def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            data = NamedTensor(data, ("batch", "ch", "height", "width"))
            recon_batch, normal = model(data)
            test_loss += loss_function(recon_batch, data, normal).item()
            if i == 0:
                n = min(data.size("batch"), 8)
                group = [
                    data.narrow("batch", 0, n),
                    recon_batch.split(x=("ch", "height", "width"),
                                      height=28, width=28).narrow("batch", 0, n),
                ]
                comparison = ntorch.cat(group, "batch")
                save_image(
                    comparison.values.cpu(),
                    "results/reconstruction_" + str(epoch) + ".png",
                    nrow=n,
                )
    test_loss /= len(test_loader.dataset)
    print("====> Test set loss: {:.4f}".format(test_loss))
def forward(self, text):
    # Sum word embeddings over the sequence (continuous bag of words) from
    # both the trainable and static tables, then apply a two-layer MLP
    embeddings = ntorch.cat([self.lut(text).sum('seqlen'),
                             self.static_lut(text).sum('seqlen')], 'embedding')
    embeddings = self.dropout(embeddings)
    hidden = self.proj1(embeddings).relu()
    preds = self.proj2(hidden).log_softmax('classes')
    return preds
def decode_one_step(self, t, output_seq, score, state, enc_out):
    if self.attention:
        def attend(x_t):
            alpha = enc_out.dot("rnnOutput", x_t).softmax("srcSeqlen")
            context = alpha.dot("srcSeqlen", enc_out)
            return context

    h, c = state[-1]
    next_input = output_seq[{"trgSeqlen": slice(t, t + 1)}].long()
    x_t, (h, c) = self.decoder(self.out_embedding(next_input), (h, c))
    if self.attention:
        fc = self.fc(ntorch.cat([attend(x_t), x_t], dim="rnnOutput"))
    else:
        fc = self.fc(x_t)
    fc = fc.sum("trgSeqlen")
    fc = fc.log_softmax("outVocab")
    state = x_t, (h, c)

    # Expand the beam: keep the k highest-scoring candidate tokens
    k = 100
    _, argmax = fc.topk("outVocab", k)

    import copy
    lst = []
    for i in range(k):
        # Deep-copy so each candidate carries its own sequence and score
        output_seq = copy.deepcopy(output_seq)
        score = copy.deepcopy(score)
        output_seq[{"trgSeqlen": t + 1}] = argmax[{"outVocab": i}]
        next_token = output_seq[{"trgSeqlen": slice(t + 1, t + 2)}].long()
        indices = next_token.sum("trgSeqlen").rename("batch", "indices")
        batch_indices = ntorch.tensor(
            torch.tensor(np.arange(fc.shape["batch"]), device=device),
            ("batchIndices"))
        # Look up the log-probability of the chosen token for every batch element
        newsc = fc.index_select("outVocab", indices).index_select(
            "indices", batch_indices).get("batchIndices", 0)
        score[{"trgSeqlen": t + 1}] = newsc
        assert output_seq[{"trgSeqlen": t + 1}].long() == next_token.sum("trgSeqlen")
        lst.append((output_seq, score, state))
    return lst
def log_pc(self, rnn_o, ec, tc, vc):
    # rnn_o = f(y<t); condition on soft attention
    # rW_a = g(r, a) = r[a]; already concatenated if input feed
    return self.Wcopy(
        rnn_o if self.inputfeed
        else self.Wif(ntorch.cat([rnn_o, ec, tc, vc], "rnns"))
    ).log_softmax("copy")
def intra_attn_layer(self, x, seqlen_dimname):
    temp_dim = seqlen_dimname + 'temp'
    x_hidden1 = self.feedforward_intra_attn(x)
    x_hidden2 = x_hidden1.rename(seqlen_dimname, temp_dim)
    intra_attn = x_hidden1.dot('hidden', x_hidden2).softmax(temp_dim)
    # Note: the distance bias is only added to the stored attention below;
    # the alignment itself uses the pre-bias weights
    self.intra_attn = intra_attn + self.get_distance_bias_matrix(
        intra_attn.shape[seqlen_dimname], seqlen_dimname, temp_dim)
    x_aligned = intra_attn.dot(temp_dim, x)
    return ntorch.cat([x, x_aligned], 'embedding')
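# Sketch of a clamped distance-bias helper in the spirit of
# get_distance_bias_matrix (hypothetical standalone version; the original
# method is not shown here). Relative distances i - j are clamped at a cap and
# index into a learned bias vector, mirroring the construction in align() below.
import torch
from namedtensor import ntorch

def distance_bias_matrix(seqlen, dim1, dim2, bias, cap=10):
    # bias: 1-D tensor of length 2 * cap + 1, e.g. an nn.Parameter
    dist = torch.tensor([[i - j for j in range(seqlen)] for i in range(seqlen)])
    dist = torch.clamp(dist, -cap, cap)
    # Look up the bias for each clamped relative distance and name the dims
    return ntorch.tensor(bias[dist + cap], names=(dim1, dim2))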
def visualize_g(model, test_iter, place_cells, hd_cells, offset=0, limit=50):
    """Visualize 25 cells in the G layer of the model (applied to the LSTM output)."""
    model.eval()
    G, P = None, None
    c = 0

    # Get batches up to limit as samples
    for traj in test_iter:
        cs, hs, ego_vel, c0, h0, xs = get_batch(traj, place_cells, hd_cells, pos=True)
        if c > limit:
            break
        zs, gs, ys = model(ego_vel, c0, h0)
        if G is None:
            G = gs.cpu()
            P = xs.cpu()
        else:
            G = ntorch.cat((G, gs.cpu()), "batch")
            P = ntorch.cat((P, xs.cpu()), "batch")
        del ego_vel, cs, xs, hs, zs, ys, gs, h0, c0
        torch.cuda.empty_cache()
        c += 1

    pts = P.stack(("t", "batch"), "pts")
    G = G.stack(("t", "batch"), "pts")
    xs, ys = [pts.get("ax", i).values.detach().numpy() for i in [0, 1]]

    # Plot a 5x5 grid of cell activations, starting at offset
    axs = plt.subplots(5, 5, figsize=(50, 50))[1]
    axs = axs.flatten()
    for i, ax in enumerate(axs):
        acts = G.get("placecell", offset + i).values.detach().numpy()
        res = stats.binned_statistic_2d(xs, ys, acts, bins=20, statistic="mean")[0]
        ax.imshow(res, cmap="jet")
        ax.axis("off")
    plt.show()
def forward(self, text, aa_info):
    '''Pass in context for the next amino acid'''
    # Reset the hidden and cell states for each new batch
    h_0 = ntorch.zeros(text.shape["batch"], self.num_layers, self.hiddenlen,
                       names=("batch", "layers", "hiddenlen")).to(self.device)
    c_0 = ntorch.zeros(text.shape["batch"], self.num_layers, self.hiddenlen,
                       names=("batch", "layers", "hiddenlen")).to(self.device)

    # If we should use all the sequence as input
    if self.teacher_force_prob == 1:
        text_embedding = self.embedding(text)
        hidden_states, (h_n, c_n) = self.LSTM(text_embedding, (h_0, c_0))
        output = self.linear_dropout(hidden_states)
        output = ntorch.cat([output, aa_info], dim="hiddenlen")
        output = self.linear(output)
    # If we should use some combination of teacher forcing
    else:
        outputs = []
        model_input = text[{"seqlen": slice(0, 1)}]
        h_n, c_n = h_0, c_0
        for position in range(text.shape["seqlen"]):
            text_embedding = self.embedding(model_input)
            hidden_states, (h_n, c_n) = self.LSTM(text_embedding, (h_n, c_n))
            output = self.linear_dropout(hidden_states)
            aa_info_subset = aa_info[{"seqlen": slice(position, position + 1)}]
            output = ntorch.cat([output, aa_info_subset], dim="hiddenlen")
            output = self.linear(output)
            outputs.append(output)

            # Define next input...
            if random.random() < self.teacher_force_prob:
                model_input = text[{"seqlen": slice(position, position + 1)}]
            else:
                # TODO: Should we be masking this output?
                model_input = output.argmax("vocablen")
        output = ntorch.cat(outputs, dim="seqlen")
    return output
def forward(self, x):
    x = self.embedding(x).transpose("h", "slen")
    x_list = [
        conv_block(x).relu().max("slen")[0]
        for conv_block in self.conv_blocks
    ]
    out = ntorch.cat(x_list, "h")
    feature_extracted = out
    out = self.fc(self.dropout(out)).softmax("classes")
    return out, feature_extracted
def forward(self, premise, hypothesis, target):
    # Encode premise and hypothesis with separate LSTMs, keeping the output
    # at the final time step of each
    prem = self.embedding(premise)
    prem = self.lstm_prem(prem)[0][{'seqlen': -1}]
    hyp = self.embedding(hypothesis)
    hyp = self.lstm_hyp(hyp)[0][{'seqlen': -1}]
    tar = ntorch.tensor(target.values.reshape(1, -1).to(torch.float32),
                        names=('embedding', 'batch'))
    flat = ntorch.cat([hyp, prem, tar], 'embedding')
    out = self.linear(flat).log_softmax('logprob')
    return out
def log_pv_x(self, e, t):
    query = self.queryproj(
        ntorch.cat(
            [
                self.lute(e).rename("e", "et"),
                self.lutt(t).rename("t", "et"),
            ],
            "et",
        ))
    log_pv = self.vproj(query).log_softmax("v")
    return log_pv
def forward(self, seq, c0, h0):
    cells = ntorch.cat([h0, c0], ["hdcell", "placecell"], name="cells")
    initial_state = (self.init_cell(cells), self.init_state(cells))
    out, _ = self.rnn(seq, initial_state)
    g = F.dropout(
        self.g(out).transpose("batch", "g", "t").values,
        0.5, self.training)
    g = ntorch.tensor(g, names=("batch", "g", "t"))
    return self.head(g), self.place(g), g
def forward(self, chars, masks, last_actions):
    x = self.encoder(chars, masks)
    x = self.pooling(x)
    la = self.action_embedding(last_actions)
    x = ntorch.cat([x, la], 'h')
    x = self.fc(x).relu()
    x = self.action_decoder(x)
    if self.is_value_net:
        # Drop to the raw tensor API to take a log-softmax over the 'value' dim
        x = x._new(F.log_softmax(x._tensor, dim=x._schema.get('value')))
    return x
def forward(self, a, b, y=None):
    a_bar, b_bar = self.input(a, b)

    # ATTEND
    F_a = self.f(a_bar)
    F_b = self.f(b_bar)
    e_mat = F_a.dot('hidden', F_b)
    alpha = e_mat.softmax(dim='aSeqlen').dot('aSeqlen', a_bar)
    beta = e_mat.softmax(dim='bSeqlen').dot('bSeqlen', b_bar)

    # COMPARE AND AGGREGATE
    v1 = self.g(ntorch.cat([a_bar, beta], 'embedding')).sum('aSeqlen')
    v2 = self.g(ntorch.cat([b_bar, alpha], 'embedding')).sum('bSeqlen')

    # NOTE: currently adds a log softmax layer after the linear, so use NLLLoss
    out = self.h(ntorch.cat([v1, v2], 'hidden'))

    if self.use_labels:
        y = ntorch.tensor(y.values.unsqueeze(1), names=('batch', 'hidden'))
        out = self.y_combine(ntorch.cat([out, y.float()], 'hidden'))

    yhat = self.final(out)
    return yhat
def forward(self, a, b, show_attn=False):
    """
    The inputs are vectors, for now:
        a: batch x seqlenA x embedding
        b: batch x seqlenB x embedding
    """
    a = self.embedding(a).rename("seqlen", "seqlenA")
    b = self.embedding(b).rename("seqlen", "seqlenB")

    if args.intra_sentence:
        # we ignore the distance bias term because we are lazy
        a_p = a.dot("embedding", a.rename("seqlenA", "sl")).softmax("sl").dot(
            "sl", a.rename("seqlenA", "sl"))
        b_p = b.dot("embedding", b.rename("seqlenB", "sl")).softmax("sl").dot(
            "sl", b.rename("seqlenB", "sl"))
        a = ntorch.cat((a, a_p), "embedding")
        b = ntorch.cat((b, b_p), "embedding")

    a = self.F(a).relu()
    b = self.F(b).relu()

    attns_alpha = a.dot("embedding", b).softmax("seqlenA")
    attns_beta = b.dot("embedding", a).softmax("seqlenB")
    if show_attn:
        return attns_alpha, attns_beta

    alpha = attns_alpha.dot("seqlenA", a)
    beta = attns_beta.dot("seqlenB", b)

    v1 = self.G(ntorch.cat((a, beta), "embedding")).relu().sum("seqlenA")
    v2 = self.G(ntorch.cat((b, alpha), "embedding")).relu().sum("seqlenB")

    y = self.H(ntorch.cat((v1, v2), "embedding"))
    return y
def forward(self, premise, hypothesis):
    premise = self.embedding(premise)
    premise = self.embedding_projection(premise).rename(
        'seqlen', 'seqlenPremise')
    hypothesis = self.embedding(hypothesis)
    hypothesis = self.embedding_projection(hypothesis).rename(
        'seqlen', 'seqlenHypo')

    if self.intra_attn:
        premise = self.intra_attn_layer(premise, 'seqlenPremise')
        hypothesis = self.intra_attn_layer(hypothesis, 'seqlenHypo')

    premise_mask = (premise != 0).float()
    hypothesis_mask = (hypothesis != 0).float()

    # attend: normalize the weights over the dimension being summed out
    premise_hidden = self.feedforward_attn(premise)
    hypothesis_hidden = self.feedforward_attn(hypothesis)
    self.attn = premise_hidden.dot('hidden', hypothesis_hidden)
    alpha = self.attn.softmax('seqlenPremise').dot('seqlenPremise', premise)
    beta = self.attn.softmax('seqlenHypo').dot('seqlenHypo', hypothesis)

    # mask
    alpha = alpha * hypothesis_mask
    beta = beta * premise_mask

    # compare
    hypothesis_comparison = self.feedforward_aligned(
        ntorch.cat([alpha, hypothesis], 'embedding')).sum('seqlenHypo')
    premise_comparison = self.feedforward_aligned(
        ntorch.cat([beta, premise], 'embedding')).sum('seqlenPremise')

    # aggregate
    agg = ntorch.cat([premise_comparison, hypothesis_comparison], 'hidden')
    agg = self.feedforward_agg(agg)
    agg = self.final_linear(agg)
    return agg
def forward(self, x, s, x_info, r, r_info, ue, ue_info, ut, ut_info, v2d):
    emb = self.lutx(x)
    N = emb.shape["batch"]
    T = emb.shape["time"]

    e = self.lute(r[0]).rename("e", "r")
    t = self.lutt(r[1]).rename("t", "r")
    v = self.lutv(r[2]).rename("v", "r")
    # r: R x N x Er, Wa r: R x N x H
    r = self.Wa(ntorch.cat([e, t, v], dim="r").tanh())

    if not self.inputfeed:
        # rnn_o: T x N x H
        rnn_o, s = self.rnn(emb, s, x_info.lengths)
        # ea: T x N x R
        _, ea, ec = attn(rnn_o, r, r_info.mask)
        if self.noattn:
            ec = r.mean("els").repeat("time", ec.shape["time"])
        self.ea = ea
        out = self.Wc(ntorch.cat([rnn_o, ec], "rnns")).tanh()
    else:
        out = []
        ect = NamedTensor(
            torch.zeros(N, self.r_emb_sz).to(emb.values.device),
            names=("batch", "rnns"),
        )
        for t in range(T):
            inp = ntorch.cat(
                [emb.get("time", t), ect.rename("rnns", "x")],
                "x").repeat("time", 1)
            rnn_o, s = self.rnn(inp, s)
            rnn_o = rnn_o.get("time", 0)
            _, eat, ect = attn(rnn_o, r, r_info.mask)
            out.append(ntorch.cat([rnn_o, ect], "rnns"))
        out = self.Wc(ntorch.stack(out, "time")).tanh()

    # return unnormalized vocab
    return self.proj(self.drop(out)), s
def forward(self, src, trg, teacher_forcing=None, beam_width=1, beam_len=3,
            num_candidates=1):
    if beam_width > 1:
        return self.beam(src, trg, beam_width, beam_len, num_candidates)
    if teacher_forcing is None:
        teacher_forcing = self.teacher_forcing

    # get src encoding
    hidden = self.encoder(src)

    # initialize outputs
    output_tokens = [trg[{'trgSeqlen': slice(0, 1)}]]
    output_distributions = []
    attn = []

    # make predictions
    for t in range(trg.shape['trgSeqlen'] - 1):
        # predict next word
        if random.random() < teacher_forcing:
            inp = trg[{'trgSeqlen': slice(t, t + 1)}]
            out, hidden = self.decoder(inp, hidden)
        else:
            out, hidden = self.decoder(output_tokens[t], hidden)
        out = out.log_softmax('logit')

        # store output
        if 'attn' in hidden:
            attn.append(hidden['attn'])
        output_distributions.append(out)
        _, top1 = out.max("logit")
        output_tokens.append(top1)

    # format predictions
    return (ntorch.cat(output_distributions, dim='trgSeqlen'),
            ntorch.cat(attn, dim='trgSeqlen'))
def forward(self, x):
    x = self.embedding(x).transpose("h", "seqlen")
    x_list = [
        conv_block(x).relu().max("seqlen")[0]
        for conv_block in self.conv_blocks
    ]
    out = ntorch.cat(x_list, "features")
    out = self.fc(self.dropout(out)).log_softmax("classes")
    return out
def forward(self, chars, masks):
    chars_emb = self.char_embedding(chars)
    chars_emb = chars_emb.stack(('charEmb', 'stateLoc'), 'inFeatures')
    x = ntorch.cat([chars_emb, masks], 'inFeatures')
    e = self.column_encoding(x)
    e = e.stack(('strLen', 'expression'), 'h')
    h = self.MLP(e)
    return h
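# Small sketch (illustration only) of the namedtensor stack used above: it
# merges two named dimensions into a single new one, here turning a
# (batch, charEmb, stateLoc) tensor into (batch, inFeatures). Sizes are
# made up for the example.
import torch
from namedtensor import ntorch

chars_emb = ntorch.tensor(torch.randn(4, 16, 3),
                          names=('batch', 'charEmb', 'stateLoc'))
flat = chars_emb.stack(('charEmb', 'stateLoc'), 'inFeatures')
# flat.shape['inFeatures'] == 16 * 3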
def forward(self, x):
    # x: (batch, seqlen)
    x = x.augment(self.embedding, "h") \
         .transpose("h", "seqlen")
    x_list = [
        x.op(conv_block, F.relu).max("seqlen")[0]
        for conv_block in self.conv_blocks
    ]
    out = ntorch.cat(x_list, "h")
    feature_extracted = out
    drop = lambda t: F.dropout(t, p=0.5, training=self.training)
    out = out.op(drop, self.fc, classes="h") \
             .softmax("classes")
    return out, feature_extracted
def forward(self, x):
    x = self.embedding(x).transpose("h", "seqlen")
    x_list = [
        conv_block(x).relu().max("seqlen")[0]
        for conv_block in self.convs
    ]
    out = ntorch.cat(x_list, "h")
    out = self.fc(self.dropout(out)).softmax("classes")
    return out
def forward(self, trg, hidden):
    # get hidden state
    src = hidden['src']
    rnn_state = hidden['rnn_state'] if 'rnn_state' in hidden else None

    # run net
    x = self.embedding(trg)
    x = self.dropout(x)
    if rnn_state is not None:
        x, rnn_state = self.rnn(x, rnn_state)
    else:
        x, rnn_state = self.rnn(x)

    # attend over the source encodings and mix the context back in
    attn = x.dot('lstm', src).softmax('srcSeqlen')
    context = attn.dot('srcSeqlen', src)
    x = self.out(ntorch.cat([context, x], dim='lstm'))

    # create new hidden state
    hidden = {'src': src, 'rnn_state': rnn_state, 'attn': attn}
    return x, hidden
def align(self, s):
    intra_s = self.intra_layers(s)
    intra_s_ = intra_s.values.transpose(0, 1)  # formatting
    batch_seq = torch.bmm(intra_s_, intra_s_.transpose(1, 2))
    batches = batch_seq.shape[0]
    seqlen = batch_seq.shape[1]

    align_matrix = torch.tensor([[(i - j) for j in range(seqlen)]
                                 for i in range(seqlen)])
    align_matrix = torch.clamp(align_matrix, -self.cap, self.cap)
    align_matrix = align_matrix.unsqueeze(0).expand(
        batches, seqlen, seqlen)  # batch * seqlen * seqlen
    align_matrix_b = self.bias[align_matrix + self.cap]

    weights = torch.softmax(align_matrix_b + batch_seq, dim=2)
    s_ = torch.matmul(weights, s.values.transpose(0, 1))
    s_ = ntorch.tensor(s_, ('batch', 'seqlen', 'embedding'))
    return ntorch.cat([s, s_], 'embedding')