def forward(self, text, aa_info):
    ''' Pass in context for the next amino acid '''
    # Reset for each new batch...
    h_0 = ntorch.zeros(text.shape["batch"], self.num_layers, self.hiddenlen,
                       names=("batch", "layers", "hiddenlen")).to(self.device)
    c_0 = ntorch.zeros(text.shape["batch"], self.num_layers, self.hiddenlen,
                       names=("batch", "layers", "hiddenlen")).to(self.device)

    # If we should use all the sequence as input
    if self.teacher_force_prob == 1:
        text_embedding = self.embedding(text)
        hidden_states, (h_n, c_n) = self.LSTM(text_embedding, (h_0, c_0))
        output = self.linear_dropout(hidden_states)
        output = ntorch.cat([output, aa_info], dim="hiddenlen")
        output = self.linear(output)

    # If we should use some combination of teacher forcing
    else:
        # Use for teacher forcing...
        outputs = []
        model_input = text[{"seqlen": slice(0, 1)}]
        h_n, c_n = h_0, c_0
        for position in range(text.shape["seqlen"]):
            text_embedding = self.embedding(model_input)
            hidden_states, (h_n, c_n) = self.LSTM(text_embedding, (h_n, c_n))
            output = self.linear_dropout(hidden_states)
            aa_info_subset = aa_info[{"seqlen": slice(position, position + 1)}]
            output = ntorch.cat([output, aa_info_subset], dim="hiddenlen")
            output = self.linear(output)
            outputs.append(output)

            # Define next input...
            if random.random() < self.teacher_force_prob:
                model_input = text[{"seqlen": slice(position, position + 1)}]
            else:
                # Masking output...
                mask_targets = text[{"seqlen": slice(position, position + 1)}].clone()
                if position == 0:
                    mask_targets[{"seqlen": 0}] = TEXT.vocab.stoi["<start>"]
                mask_bad_codons = ntorch.tensor(
                    mask_tbl[mask_targets.values],
                    names=("seqlen", "batch", "vocablen")).float()
                model_input = (output + mask_bad_codons).argmax("vocablen")
                # model_input = (output).argmax("vocablen")
        output = ntorch.cat(outputs, dim="seqlen")
    return output
def forward(self, seq):
    ''' Forward pass '''
    aa_rep = self.aa_embed(seq)
    h_0 = ntorch.zeros(self.num_layers * self.num_directions,
                       aa_rep.shape["batch"], self.hiddenlen,
                       names=("layers", "batch", "hiddenlen")).to(self.device)
    c_0 = ntorch.zeros(self.num_layers * self.num_directions,
                       aa_rep.shape["batch"], self.hiddenlen,
                       names=("layers", "batch", "hiddenlen")).to(self.device)
    h_0 = h_0.transpose("batch", "layers", "hiddenlen")
    c_0 = c_0.transpose("batch", "layers", "hiddenlen")
    hidden_states, (h_n, c_n) = self.LSTM(aa_rep, (h_0, c_0))
    return hidden_states
def init_state(self, N):
    # Cache the initial (h, c) LSTM state; only re-allocate when the batch size changes.
    if self._N != N:
        self._N = N
        self._state = (
            ntorch.zeros(
                self.nlayers, N, self.rnn_sz,
                names=("layers", "batch", "rnns"),
            ).to(self.lutx.weight.device),
            ntorch.zeros(
                self.nlayers, N, self.rnn_sz,
                names=("layers", "batch", "rnns"),
            ).to(self.lutx.weight.device),
        )
    return self._state
def forward(self, seq):
    seq_len = seq.shape["seqlen"]
    batch_size = seq.shape["batch"]
    pad_token = self.text.vocab.stoi["<pad>"]

    additional_padding = ntorch.ones(batch_size, self.longest_n,
                                     names=("batch", "seqlen")).long().to(self.device)
    additional_padding *= pad_token
    seq = ntorch.cat([additional_padding, seq, additional_padding], dim="seqlen")
    amino_acids = self.codon_to_aa[seq.values]

    return_ar = ntorch.zeros(seq_len, batch_size, self.out_vocab,
                             names=("seqlen", "batch", "vocablen"))

    # convert to numpy to leave GPU
    amino_acids = amino_acids.detach().cpu().numpy()
    for batch_item in range(batch_size):
        # start at n, end at seq_len - n
        for seq_item in range(self.longest_n, seq_len - self.longest_n):
            # Must iterate over all dictionaries
            for weight, n, ngram_dict in zip(self.weight_list, self.n_list, self.dict_list):
                # N gram is a 2d numpy array containing an amino acid embedding in each row
                n_gram = amino_acids[batch_item, seq_item - n : seq_item + n + 1]
                # note, we want to populate the return ar before padding!
                return_ar[{"seqlen": seq_item - self.longest_n, "batch": batch_item}] += \
                    weight * ngram_dict[str(n_gram)].float()
    return return_ar.to(self.device)
def evaluate(model, batches):
    model.eval()
    with torch.no_grad():
        loss_fn = ntorch.nn.NLLLoss(reduction='sum').spec('label')
        total_loss = 0
        total_num = 0
        num_correct = 0
        if args.algo == "vae":
            identity = ntorch.NamedTensor(torch.eye(len(LABEL.vocab)),
                                          names=("index", "label"))
            q_sum = ntorch.zeros(len(LABEL.vocab), model.K, names=("label", "model"))
        for i, batch in enumerate(batches):
            log_probs, e = model.forward(batch.premise, batch.hypothesis)
            preds = log_probs.argmax('label')
            total_loss += loss_fn(log_probs, batch.label).item()
            num_correct += get_correct(preds, batch.label)
            total_num += len(batch)
            if args.algo == "vae":
                q_sum = q_sum + identity.index_select("index", batch.label).dot("batch", e)
            if args.algo == "attn" and args.visualize_freq and i % args.visualize_freq == 0:
                fname = './img/' + args.save_img
                visualize_attn(e[{'batch': 0}],
                               batch.premise[{'batch': 0}],
                               batch.hypothesis[{'batch': 0}],
                               save_name=fname)
        if args.algo == "vae":
            q_sum = q_sum / q_sum.mean("model")
            print(q_sum._tensor.cpu().data)
    return total_loss / total_num, 100.0 * num_correct / total_num
def forward(self, text, aa_info):
    ''' Pass in context for the next amino acid '''
    # Reset for each new batch...
    h_0 = ntorch.zeros(text.shape["batch"], self.num_layers, self.hiddenlen,
                       names=("batch", "layers", "hiddenlen")).to(self.device)
    c_0 = ntorch.zeros(text.shape["batch"], self.num_layers, self.hiddenlen,
                       names=("batch", "layers", "hiddenlen")).to(self.device)

    # If we should use all the sequence as input
    if self.teacher_force_prob == 1:
        text_embedding = self.embedding(text)
        hidden_states, (h_n, c_n) = self.LSTM(text_embedding, (h_0, c_0))
        output = self.linear_dropout(hidden_states)
        output = ntorch.cat([output, aa_info], dim="hiddenlen")
        output = self.linear(output)

    # If we should use some combination of teacher forcing
    else:
        # Use for teacher forcing...
        outputs = []
        model_input = text[{"seqlen": slice(0, 1)}]
        h_n, c_n = h_0, c_0
        for position in range(text.shape["seqlen"]):
            text_embedding = self.embedding(model_input)
            hidden_states, (h_n, c_n) = self.LSTM(text_embedding, (h_n, c_n))
            output = self.linear_dropout(hidden_states)
            aa_info_subset = aa_info[{"seqlen": slice(position, position + 1)}]
            output = ntorch.cat([output, aa_info_subset], dim="hiddenlen")
            output = self.linear(output)
            outputs.append(output)

            # Define next input...
            if random.random() < self.teacher_force_prob:
                model_input = text[{"seqlen": slice(position, position + 1)}]
            else:
                # TODO: Should we be masking this output?
                model_input = output.argmax("vocablen")
        output = ntorch.cat(outputs, dim="seqlen")
    return output
def forward(self, seq):
    ''' Forward pass '''
    # Replace start codon...
    seq_copy = seq.clone()
    seq_copy[{"seqlen": 0}] = self.start_index
    seq = seq_copy

    aa_rep = self.aa_embed(seq)
    h_0 = ntorch.zeros(self.num_layers * self.num_directions,
                       aa_rep.shape["batch"], self.hiddenlen,
                       names=("layers", "batch", "hiddenlen")).to(self.device)
    c_0 = ntorch.zeros(self.num_layers * self.num_directions,
                       aa_rep.shape["batch"], self.hiddenlen,
                       names=("layers", "batch", "hiddenlen")).to(self.device)
    h_0 = h_0.transpose("batch", "layers", "hiddenlen")
    c_0 = c_0.transpose("batch", "layers", "hiddenlen")
    hidden_states, (h_n, c_n) = self.LSTM(aa_rep, (h_0, c_0))
    return hidden_states
def loss_function(recon_x, x, var_posterior):
    # Reconstruction term: summed binary cross-entropy between the reconstruction and the input
    BCE = recon_x.reduce2(
        x.stack(h=("ch", "height", "width")),
        lambda x, y: F.binary_cross_entropy(x, y, reduction="sum"),
        ("batch", "x"),
    )
    # KL term: divergence of the variational posterior from a standard normal prior
    prior = ndistributions.Normal(ntorch.zeros(dict(batch=1, z=1)),
                                  ntorch.ones(dict(batch=1, z=1)))
    KLD = ndistributions.kl_divergence(var_posterior, prior).sum()
    return BCE + KLD
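For reference, this appears to be the usual negative evidence lower bound for a VAE with a standard-normal prior: the summed reconstruction cross-entropy plus the KL divergence of the variational posterior from the prior,

\[
\mathcal{L}(x) = \mathrm{BCE}\big(\hat{x}, x\big) + D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,\mathcal{N}(0, I)\big).
\]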
def pe(self):
    pe = ntorch.zeros(
        MAX_LEN, self.d_model, names=(self.dim_length, self.dim_hidden)
    )
    position = ntorch.arange(0, MAX_LEN, names=self.dim_length).float()
    shift = ntorch.arange(0, self.d_model, 2, names=self.dim_hidden)
    div_term = ntorch.exp(
        shift.float() * -(math.log(10000.0) / self.d_model)
    )
    val = ntorch.mul(position, div_term)
    print(val.shape, shift.shape, pe.shape)
    pe[{self.dim_hidden: shift}] = val.sin()
    pe[{self.dim_hidden: shift + 1}] = val.cos()
    return pe
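This looks like the standard sinusoidal positional encoding from the Transformer, with sine values at even hidden indices and cosine at odd ones:

\[
\mathrm{PE}(p, 2i) = \sin\!\left(\frac{p}{10000^{2i/d_{\mathrm{model}}}}\right),
\qquad
\mathrm{PE}(p, 2i+1) = \cos\!\left(\frac{p}{10000^{2i/d_{\mathrm{model}}}}\right).
\]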
def __init__(self, vocab, num_classes, padding_idx):
    super(LR, self).__init__()
    vocab_size = len(vocab.itos)
    self.lut = ntorch.nn.Embedding(
        vocab_size,
        num_classes,
        padding_idx=padding_idx,
    ).augment('classes')
    # self.bias = ntorch.zeros(dict(classes=num_classes))
    self.bias = ntorch.zeros(num_classes, names=("classes",))
    self.bias_param = nn.Parameter(self.bias._tensor)
    self.bias._tensor = self.bias_param
    self.loss_fn = ntorch.nn.NLLLoss(reduction='sum') \
        .reduce(('batch', 'classes'))
def beam(self, src, trg, k, beam_len, num_candidates):
    batch_size = src.shape['batch']
    out_dists = HypothesisMap(
        device=self.device)  # map a hypothesis to a distribution over words
    scores = HypothesisMap(
        keys=[trg[{'trgSeqlen': slice(0, 1)}]],
        vals=[ntorch.zeros(batch_size, names='batch')],
        device=self.device)  # map a hypothesis to its score
    end = HypothesisMap(
        device=self.device)  # special buffer for hypotheses ending in <EOS>
    attn = []
    EOS_IND = 3

    hidden = self.encoder(src)

    # make predictions
    for l in range(beam_len or trg.shape['trgSeqlen'] - 1):
        new_scores = HypothesisMap(device=self.device)
        hyps = scores.get_topk(k) if l > 0 else scores
        for hyp, score in hyps.items():
            inp = hyp[{'trgSeqlen': slice(l, l + 1)}]
            out, hidden = self.decoder(inp, hidden)
            out = out.log_softmax('logit')
            topk = out.topk('logit', k)
            for i in range(k):
                pred_prob = topk[0][{'logit': i, 'trgSeqlen': -1}]
                pred = topk[1][{'logit': i}]
                new_hyp = ntorch.cat([hyp, pred], 'trgSeqlen')
                if hyp in out_dists:
                    out_dists[new_hyp] = ntorch.cat([out_dists[hyp], out], 'trgSeqlen')
                else:
                    out_dists[new_hyp] = out
                if torch.any((pred[{'trgSeqlen': -1}] == EOS_IND).values):
                    end[new_hyp] = score + pred_prob
                    end[new_hyp].masked_fill_(
                        pred[{'trgSeqlen': -1}] != EOS_IND, -float('inf'))
                    pred_prob.masked_fill_(
                        pred[{'trgSeqlen': -1}] == EOS_IND, -float('inf'))
                new_scores[new_hyp] = score + pred_prob
        scores = new_scores

    for hyp, score in end.items():
        scores[hyp] = score
    best = scores.get_topk(num_candidates).keys
    out = [out_dists[k] for k in best]

    # store output
    if 'attn' in hidden:
        attn.append(hidden['attn'])

    # format predictions
    return ntorch.stack(out, 'candidates'), ntorch.cat(attn, dim='trgSeqlen')
def trajectories(self, N=100, dt=0.02):
    perimeter = self.params['perimeter']
    T = self.params["T"]
    n = int(T / dt)
    mu, sigma, b = [
        self.params[i]
        for i in ["mean_rotation", "std_dev_rotation", "std_dev_forward"]
    ]

    # Sample rotational and forward velocities for every step of every trajectory
    rotation_velocities = torch.tensor(
        np.random.normal(mu, sigma, size=(n, N))).float()
    forward_velocities = torch.tensor(
        np.random.rayleigh(b, size=(n, N))).float()

    positions = ntorch.zeros((n, 2, N), names=("t", "ax", "sample"))
    vs = torch.zeros((n, N))
    angles = rotation_velocities
    directions = torch.zeros((n, 2, N))

    vs[0] = self.params["v0"]
    theta = torch.rand(N) * 2 * np.pi
    directions[0] = unit_vector(theta)
    positions[{"t": 0}] = ntorch.tensor(self.scene.random(N), names=("sample", "ax"))

    for i in range(1, n):
        # Slow down and turn away from the wall when too close to it
        dist, phi = self.scene.closestWall(positions[{"t": i - 1}].values,
                                           directions[i - 1])
        wall = (dist < perimeter) & (phi.abs() < np.pi / 2)
        angle_correction = torch.where(
            wall, phi.sign() * (np.pi / 2 - phi.abs()), torch.zeros_like(phi))
        angles[i] += angle_correction
        vs[i] = torch.where(
            wall,
            (1 - self.params["velocity_reduction"]) * (vs[i - 1]),
            forward_velocities[i],
        )
        positions[{"t": i}] = (positions[{"t": i - 1}]
                               + directions[i - 1] * vs[i] * dt)
        mat = rotation_matrix(angles[i] * dt)
        directions[i] = torch.einsum("ijk,jk->ik", mat, directions[i - 1])

    # Subsample the simulated steps down to the requested trajectory length
    idx = np.round(np.linspace(
        0, n - 2, self.params["trajectory_length"])).astype(int)
    # idx = np.array(sorted(np.random.choice(np.arange(n), size=self.params["trajectory_length"], replace=False)))

    dphis = ntorch.tensor(angles[idx] * dt, names=("t", "sample"))
    velocities = ntorch.tensor(vs[idx], names=("t", "sample"))
    vel = ntorch.stack((velocities, dphis.cos(), dphis.sin()), "input")
    xs = ntorch.tensor(positions.values[idx], names=("t", "ax", "sample"))
    # xs0 = positions[{'t': 0}]
    xs0 = ntorch.tensor(self.scene.random(n=N), names=("sample", "ax"))
    hd = torch.atan2(directions[:, 1], directions[:, 0])
    hd0 = ntorch.tensor(hd[0][None], names=("hd", "sample"))
    hd = ntorch.tensor(hd[idx + 1][None], names=("hd", "t", "sample"))

    xs = xs.transpose('sample', 't', 'ax')
    hd = hd.transpose('sample', 't', 'hd')
    vel = vel.transpose('sample', 't', 'input')
    xs0 = xs0.transpose('sample', 'ax')
    hd0 = hd0.transpose('sample', 'hd')
    return xs, hd, vel, xs0, hd0
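Each simulation step above applies the same update: the position advances along the previous heading, and the heading is rotated by the (possibly wall-corrected) angular velocity,

\[
\mathbf{p}_i = \mathbf{p}_{i-1} + v_i\,\Delta t\,\mathbf{d}_{i-1},
\qquad
\mathbf{d}_i = R(\omega_i\,\Delta t)\,\mathbf{d}_{i-1}.
\]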
def forward(self, source, target=None, teacher_forcing=1.,
            max_length=20, encode_only=False):
    if target is not None:
        max_length = target.shape["trgSeqlen"]

    x = self.in_embedding(source)
    out, (h, c) = self.encoder(x)
    h = ntorch.cat((h[{"layers": slice(0, 1)}], h[{"layers": slice(1, 2)}]),
                   dim="rnnOutput")
    c = ntorch.cat((c[{"layers": slice(0, 1)}], c[{"layers": slice(1, 2)}]),
                   dim="rnnOutput")

    if self.attention:
        def attend(x_t):
            alpha = out.dot("rnnOutput", x_t).softmax("srcSeqlen")
            context = alpha.dot("srcSeqlen", out)
            return context

    batch_size = source.shape["batch"]
    output_dists = ntorch.zeros(
        (batch_size, max_length, self.out_vocab_size),
        names=("batch", "trgSeqlen", "outVocab"), device=device)
    output_seq = ntorch.zeros((batch_size, max_length),
                              names=("batch", "trgSeqlen"), device=device)
    # for the above, should set the zeroth index to SOS
    score = ntorch.zeros((batch_size, max_length),
                         names=("batch", "trgSeqlen"), device=device)

    if encode_only:
        return score, out, (h, c), output_seq

    for t in range(max_length - 1):
        if t == 0:
            # always start with the SOS token
            next_input = ntorch.ones((batch_size, 1),
                                     names=("batch", "trgSeqlen"),
                                     device=device).long()
            next_input *= EN.vocab.stoi["<s>"]
        elif np.random.random() < teacher_forcing and target is not None:
            # we will force
            next_input = target[{"trgSeqlen": slice(t, t + 1)}]
        else:
            next_input = sample

        x_t, (h, c) = self.decoder(self.out_embedding(next_input), (h, c))
        if t == 0:
            syntax_out, (s_h, s_c) = self.syntax_decoder(
                self.out_embedding(next_input))
        else:
            syntax_out, (s_h, s_c) = self.syntax_decoder(
                self.out_embedding(next_input), (s_h, s_c))

        if self.attention:
            fc = self.fc(ntorch.cat([attend(x_t), x_t], dim="rnnOutput"))
        else:
            fc = self.fc(x_t)

        s_fc = self.syntax_fc(syntax_out).sum("trgSeqlen")
        s_fc = s_fc.log_softmax("outVocab")

        dist = ntorch.distributions.Categorical(logits=fc, dim_logit="outVocab")
        sample = dist.sample()
        fc = fc.sum("trgSeqlen")
        next_token = sample if target is None else target[{
            "trgSeqlen": slice(t + 1, t + 2)
        }]  # TODO

        # this is the line where the syntax decoder contributes
        fc = fc.log_softmax("outVocab") + s_fc

        indices = next_token.sum("trgSeqlen").rename("batch", "indices")
        batch_indices = ntorch.tensor(
            torch.tensor(np.arange(fc.shape["batch"]), device=device),
            ("batchIndices",))
        newsc = fc.index_select("outVocab", indices).index_select(
            "indices", batch_indices).get("batchIndices", 0)

        score[{"trgSeqlen": t + 1}] = newsc
        output_seq[{"trgSeqlen": t + 1}] = next_token.sum("trgSeqlen")  # TODO
        output_dists[{"trgSeqlen": t + 1}] = fc  # TODO
    return output_seq, output_dists, score
def forward(self, x, s, x_info, r, r_info, vt, ue, ue_info, ut, ut_info,
            v2d, vt2d, y=None):
    emb = self.lutx(x)
    N = emb.shape["batch"]
    T = emb.shape["time"]

    e = self.lute(r[0]).rename("e", "r")
    t = self.lutt(r[1]).rename("t", "r")
    v = self.lutv(r[2]).rename("v", "r")
    # r: R x N x Er, Wa r: R x N x H
    r = self.War(ntorch.cat([e, t, v], dim="r").tanh())

    eA = self.Wae(self.lute(ue))
    tA = self.Wat(self.lutt(ut))

    if self.v2d:
        v2dx = self.lutv(v2d.stack(("t", "e"), "els")).chop(
            "els", ("t", "e"), t=v2d.shape["t"]).rename("v", "rnns")
    else:
        # vt2dx
        v2dx = self.lutx(vt2d.stack(("t", "e"), "time")).chop(
            "time", ("t", "e"), t=v2d.shape["t"]).rename("x", "rnns")

    if not self.inputfeed:
        # rnn_o: T x N x H
        rnn_o, s = self.rnn(emb, s, x_info.lengths)
        # ea: T x N x R
        log_e, ea_E, ec_E = attn(rnn_o, eA, ue_info.mask)
        #log_t, ea_T, ec_T = attn(rnn_o + ec_E, tA, ut_info.mask)
        log_t, ea_T, ec_T = attn(rnn_o, tA, ut_info.mask)
        if self.noattn:
            ec = r.mean("els").repeat("time", ec.shape["time"])
        ec_ET = ec_E + ec_T
        le = log_e.rename("els", "e")
        lt = log_t.rename("els", "t")
        self.ea = le
        self.ta = lt
        aw = (le + lt).exp()
        ec = aw.dot(("t", "e"), v2dx)
        self.a = aw
        # no ent or typ
        #out = self.Wc(ntorch.cat([rnn_o, ec], "rnns")).tanh()
        # cat ent and type, this seems fine
        out = (self.Wc_nov(ntorch.cat([rnn_o, ec_E, ec_T], "rnns"))
               if self.noattnvalues
               else self.Wc(ntorch.cat([rnn_o, ec, ec_E, ec_T], "rnns"))).tanh()
        # add ent and typ
        #out = self.Wc(ntorch.cat([rnn_o, ec + ec_ET], "rnns")).tanh()
    else:
        out = []
        self.ea = []
        self.ta = []
        self.a = []
        ec_ETt = ntorch.zeros(N, self.r_emb_sz,
                              names=("batch", "rnns")).to(emb.values.device)
        for t in range(T):
            ec_ETt = ec_ETt.rename("rnns", "x")
            inp = ntorch.cat([emb.get("time", t), ec_ETt], "x").repeat("time", 1)
            rnn_o, s = self.rnn(inp, s)
            rnn_o = rnn_o.get("time", 0)
            log_e, ea_Et, ec_Et = attn(rnn_o, eA, ue_info.mask)
            log_t, ea_Tt, ec_Tt = attn(rnn_o, tA, ut_info.mask)
            ec_ETt = ec_Et + ec_Tt
            le = log_e.rename("els", "e")
            lt = log_t.rename("els", "t")
            aw = (le + lt).exp()
            ect = aw.dot(("t", "e"), v2dx)
            out.append(ntorch.cat([rnn_o, ect, ec_Et, ec_Tt], "rnns"))
            self.ea.append(ea_Et.detach())
            self.ta.append(ea_Tt.detach())
            self.a.append(aw.detach())
        out = self.Wc(ntorch.stack(out, "time")).tanh()

    # return unnormalized vocab
    return self.proj(self.drop(out)), s
def pa0(self, emb_x, s, x_info, emb_e, ue_info, emb_t, ut_info, v2dx):
    T = emb_x.shape["time"]
    N = emb_x.shape["batch"]

    log_ea, ea, ec = None, None, None
    log_ta, ta, tc = None, None, None
    log_a, a, c = None, None, None
    output = None

    if not self.inputfeed:
        # rnn_o: T x N x H
        rnn_o, s = self.rnn(emb_x, s, x_info.lengths)
        # ea: T x N x R
        log_ea, ea, ec = attn(rnn_o, emb_e, ue_info.mask)
        #log_t, ea_T, ec_T = attn(rnn_o + ec_E, tA, ut_info.mask)
        log_ta, ta, tc = attn(rnn_o, emb_t, ut_info.mask)
        if self.noattn:
            ec = r.mean("els").repeat("time", ec.shape["time"])
        log_ea = log_ea.rename("els", "e")
        log_ta = log_ta.rename("els", "t")
        log_va = log_ea + log_ta
        vc = log_va.exp().dot(("t", "e"), v2dx)
        va = log_va.exp()
        ea = ea.rename("els", "e")
        ta = ta.rename("els", "t")
        output = rnn_o
    else:
        log_ea, ea, ec = [], [], []
        log_ta, ta, tc = [], [], []
        log_va, va, vc = [], [], []
        out = []
        etc_t = ntorch.zeros(
            N, self.r_emb_sz, names=("batch", "rnns")
        ).to(emb_x.values.device)
        for t in range(T):
            etc_t = etc_t.rename("rnns", "x")
            inp = ntorch.cat([emb_x.get("time", t), etc_t], "x").repeat("time", 1)
            rnn_o, s = self.rnn(inp, s)
            rnn_o = rnn_o.get("time", 0)
            log_ea_t, ea_t, ec_t = attn(rnn_o, emb_e, ue_info.mask)
            log_ta_t, ta_t, tc_t = attn(rnn_o, emb_t, ut_info.mask)
            log_ea_t = log_ea_t.rename("els", "e")
            log_ta_t = log_ta_t.rename("els", "t")
            log_va_t = log_ea_t + log_ta_t
            va_t = log_va_t.exp()
            vc_t = va_t.dot(("t", "e"), v2dx)
            out.append(
                self.Wif(ntorch.cat([rnn_o, vc_t, ec_t, tc_t], "rnns"))
            )
            log_ea.append(log_ea_t)
            ea.append(ea_t)
            ec.append(ec_t)
            log_ta.append(log_ta_t)
            ta.append(ta_t)
            tc.append(tc_t)
            log_va.append(log_va_t)
            va.append(va_t)
            vc.append(vc_t)
        output = ntorch.stack(out, "time")
        log_ea = ntorch.stack(log_ea, "time")
        ea = ntorch.stack(ea, "time")
        ec = ntorch.stack(ec, "time")
        log_ta = ntorch.stack(log_ta, "time")
        ta = ntorch.stack(ta, "time")
        tc = ntorch.stack(tc, "time")
        log_va = ntorch.stack(log_va, "time")
        va = ntorch.stack(va, "time")
        vc = ntorch.stack(vc, "time")
        ea = ea.rename("els", "e")
        ta = ta.rename("els", "t")
    return log_ea, ea, ec, log_ta, ta, tc, log_va, va, vc, output, s