def elbo_reinforce(self, premise, hypothesis, label):
    # computing the q distribution: p(c | a, b, y)
    q = self.q(premise, hypothesis, label).rename('label', 'latent')
    latent_dist = ds.Categorical(logits=q, dim_logit='latent')

    # generating some samples
    samples = latent_dist.sample([self.sample_size], names=('samples',))

    # bucketing samples by the sampled model to maximize efficiency
    buckets = defaultdict(list)
    premise_lst = premise.unbind('batch')
    hypothesis_lst = hypothesis.unbind('batch')
    samples_list = samples.transpose('batch', 'samples').tolist()
    for i, batch in enumerate(samples_list):
        p, h = premise_lst[i], hypothesis_lst[i]
        for sample in batch:
            buckets[sample].append((i, p, h))

    # evaluating the sampled models efficiently using batching
    orig_batch_size = premise.shape['batch']
    counts = [0] * orig_batch_size
    res = [None] * (self.sample_size * orig_batch_size)
    correct = label.tolist()
    for c, items in buckets.items():
        # stacking data points into batches
        batch_premise = ntorch.stack([p for _, p, _ in items], 'batch')
        batch_hypothesis = ntorch.stack([h for _, _, h in items], 'batch')
        ids = [i for i, _, _ in items]

        # evaluating the model on that batch
        predictions = self.models[c](batch_premise, batch_hypothesis)

        # updating the result at the appropriate index
        for i, log_probs in zip(ids, predictions.unbind('batch')):
            res[self.sample_size * i + counts[i]] = log_probs.values[correct[i]]
            counts[i] += 1

    # reforming and averaging the results for each sample
    # (name the dim 'samples' so it lines up with the sampled tensor above)
    res = torch.stack(res, dim=0).reshape(orig_batch_size, self.sample_size)
    res = ntorch.tensor(res, names=('batch', 'samples'))

    # computing a surrogate objective for REINFORCE
    # https://pyro.ai/examples/svi_part_iii.html
    q_log_prob = latent_dist.log_prob(samples)
    surrogate_objective = (q_log_prob * res.detach() + res).mean('samples')

    # adding on the KL regularizing term
    ones = ntorch.ones(self.K, names='latent').log_softmax(dim='latent')
    uniform_dist = ds.Categorical(logits=ones, dim_logit='latent')
    kl = ds.kl_divergence(latent_dist, uniform_dist) * self.kl_importance

    # reporting the surrogate objective as well as the actual elbo
    loss = -(surrogate_objective - kl).mean()
    elbo = -(res.detach().mean('samples') - kl.detach()).mean()
    return loss, elbo
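# A minimal, self-contained sketch of the score-function (REINFORCE) surrogate
# used by `elbo_reinforce`, in plain torch. The names here (`q_logits`,
# `log_py_given_c`) are illustrative assumptions, not part of the class above.
# Differentiating `log_q * f.detach() + f` yields the estimator
# E[f * grad log q] + E[grad f].
import torch
import torch.distributions as D

def reinforce_surrogate(q_logits, log_py_given_c, n_samples=8):
    # q_logits: (batch, K) unnormalized posterior over which expert to use
    # log_py_given_c: (batch, K) log-likelihood of the label under each expert
    q = D.Categorical(logits=q_logits)
    c = q.sample((n_samples,))                      # (n_samples, batch)
    log_q = q.log_prob(c)                           # (n_samples, batch)
    # pick out log p(y | c) for each sampled expert via advanced indexing
    f = log_py_given_c[torch.arange(q_logits.size(0)), c]
    surrogate = (log_q * f.detach() + f).mean(0)    # per-example objective
    return -surrogate.mean()                        # loss to minimize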
def reinforce(self, premise, hypothesis, label):
    # REINFORCE
    q = self.q(premise, hypothesis, label).rename('label', 'latent')
    latent_dist = nds.Categorical(logits=q, dim_logit='latent')

    # Sample to approximate E[]
    samples = latent_dist.sample([self.num_samples], names=('samples',))

    # Batch premises and hypotheses
    batches = defaultdict(list)
    premise_n = premise.unbind('batch')
    hypothesis_n = hypothesis.unbind('batch')

    # Get some samples
    samples_n = samples.transpose('batch', 'samples').tolist()

    # Idea is to group data points by their sampled model
    for i, batch in enumerate(samples_n):
        p = premise_n[i]
        h = hypothesis_n[i]
        for sample in batch:
            batches[sample].append((i, p, h))

    # Can now evaluate sampled models with batching
    batch_size = premise.shape['batch']
    counts = [0] * batch_size
    res = [None] * (self.num_samples * batch_size)
    correct = label.tolist()
    for c, items in batches.items():
        batch_p = ntorch.stack([p for _, p, _ in items], 'batch')
        batch_h = ntorch.stack([h for _, _, h in items], 'batch')
        batch_i = [i for i, _, _ in items]

        # Evaluate model per batch, then update
        preds = self.models[c](batch_p, batch_h)
        for i, log_probs in zip(batch_i, preds.unbind('batch')):
            res[self.num_samples * i + counts[i]] = log_probs.values[correct[i]]
            counts[i] += 1

    # Finally average the results per sample
    # (name the dim 'samples' so it lines up with the sampled tensor above)
    res = torch.stack(res, dim=0).reshape(batch_size, self.num_samples)
    res = ntorch.tensor(res, names=('batch', 'samples'))

    # Onward to estimating the gradient + calculating the loss
    surrogate = (latent_dist.log_prob(samples) * res.detach() + res).mean('samples')

    prior = ntorch.ones(self.K, names='latent').log_softmax(dim='latent')
    prior = nds.Categorical(logits=prior, dim_logit='latent')
    KLD = nds.kl_divergence(latent_dist, prior) * self.kl_weight

    loss = (KLD - surrogate._tensor).mean()  # i.e. -(surrogate - KLD)
    elbo = (KLD.detach() - res.detach().mean('samples')._tensor).mean()
    return loss, elbo
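# Both variants regularize q toward a uniform prior over the K experts.
# KL(q || Uniform(K)) has the closed form log K - H(q); a quick sanity check
# in plain torch.distributions (K and the logits are made up for illustration):
import math
import torch
import torch.distributions as D

K = 4
q = D.Categorical(logits=torch.randn(3, K))
uniform = D.Categorical(logits=torch.zeros(3, K))
kl = D.kl_divergence(q, uniform)
assert torch.allclose(kl, math.log(K) - q.entropy())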
def get_topk(self, k):
    """Get the top-k items as a HypothesisMap."""
    # keys = ntorch.stack(self.keys, 'map').to(self.device)
    vals = ntorch.stack(self.vals, 'map').to(self.device)
    vals, inds = vals.topk('map', k)

    keys_list = []
    for m in range(inds.shape['map']):
        newbatch = []
        for b in range(inds.shape['batch']):
            # gather the key that actually produced the m-th best value for
            # this batch element (rather than the m-th stored key)
            idx = int(inds[{'map': m, 'batch': b}].values)
            newbatch.append(self.keys[idx][{'batch': b}])
        keys_list.append(ntorch.stack(newbatch, 'batch'))

    vals_list = [vals[{'map': i}] for i in range(vals.shape['map'])]
    return HypothesisMap(keys=keys_list, vals=vals_list, device=self.device)
def state_to_tensor(self, states):
    inputs, scratchs, committeds, outputs, masks, last_actions = zip(*states)

    inputs = np.stack(inputs)
    input_tensor = ntorch.tensor(inputs, ('batch', 'Examples', 'strLen'))
    scratchs = np.stack(scratchs)
    scratch_tensor = ntorch.tensor(scratchs, ('batch', 'Examples', 'strLen'))
    committeds = np.stack(committeds)
    committed_tensor = ntorch.tensor(committeds, ('batch', 'Examples', 'strLen'))
    outputs = np.stack(outputs)
    output_tensor = ntorch.tensor(outputs, ('batch', 'Examples', 'strLen'))

    chars = ntorch.stack(
        [input_tensor, scratch_tensor, committed_tensor, output_tensor],
        'stateLoc')
    chars = chars.transpose('batch', 'Examples', 'strLen', 'stateLoc').long()

    masks = np.stack(masks)
    masks = ntorch.tensor(masks, ('batch', 'Examples', 'inFeatures', 'strLen'))
    masks = masks.transpose('batch', 'Examples', 'strLen', 'inFeatures').float()

    last_actions = np.stack(last_actions)
    last_actions = ntorch.tensor(last_actions, 'batch').long()

    if self.use_cuda:
        return chars.cuda(), masks.cuda(), last_actions.cuda()
    return chars, masks, last_actions
def predict(self, x):
    # y = (self.W.index_select('vocab', x.long()).sum('vocab') + self.b).sigmoid()
    y_ = self.W(x).sigmoid().sum('singular').sigmoid()  # this is a huge hack
    y = ntorch.stack([y_, 1 - y_], 'classes')  # .log_softmax('classes')
    return y
def enumerate(self, premise, hypothesis, models=None):
    predictions = []
    for model in (models or self.models):
        predictions.append(model(premise, hypothesis))
    # average the experts' probabilities, then move back to log space
    return (ntorch.stack(predictions, "experts")
            .softmax('logit')
            .mean('experts')
            .log()
            .rename('logit', 'logprob'))
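# `enumerate` averages the experts' probabilities (not their logits) before
# returning to log space. A plain-torch equivalent of that reduction, using
# logsumexp for numerical stability (shapes here are assumptions):
import math
import torch

def ensemble_log_probs(expert_logits):
    # expert_logits: (experts, batch, classes)
    log_probs = expert_logits.log_softmax(dim=-1)
    # log(mean_e p_e) = logsumexp_e(log p_e) - log(num experts)
    return torch.logsumexp(log_probs, dim=0) - math.log(expert_logits.size(0))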
def forward(self, a, b):
    """The inputs are vectors, for now:
        a: batch x seqlenA x embedding
        b: batch x seqlenB x embedding
    """
    y = ntorch.stack([model(a, b) for model in self.models],
                     name='ensemble').mean('ensemble')
    return y
def predict(self, x):
    # TODO: sign function, mm
    # y = ntorch.tensor(torch.sign(self.W.index_select(x, 'vocab').sum('vocab') + self.b), ['classes', 'batch'])
    y_ = self.W.dot('vocab', x) + self.b
    y = ntorch.stack([y_ < 0, y_ >= 0], 'classes')
    return y
def kl(self, qa, log_qa, log_pa, lens, dim):
    kl = []
    for i, l in enumerate(lens.tolist()):
        qa0 = qa.get("batch", i).narrow(dim, 0, l)
        log_qa0 = log_qa.get("batch", i).narrow(dim, 0, l)
        log_pa0 = log_pa.get("batch", i).narrow(dim, 0, l)
        kl0 = qa0 * (log_qa0 - log_pa0)
        infmask = log_qa0 != float("-inf")
        # workaround for a namedtensor bug that puts empty tensors on
        # different devices
        kl0 = kl0._new(
            kl0.values.where(infmask.values, torch.zeros_like(kl0.values))
        ).sum(dim)
        kl.append(kl0)
    return ntorch.stack(kl, "batch")
def kl_qz(qz, log_qz, log_pz, lens, dim, batchdim="time"):
    # batchdim other than "batch"; could be time or els
    kl = []
    for i, l in enumerate(lens.tolist()):
        qz0 = qz.get("batch", i).narrow(dim, 0, l)
        log_qz0 = log_qz.get("batch", i).narrow(dim, 0, l)
        log_pz0 = log_pz.get("batch", i).narrow(dim, 0, l)
        kl0 = qz0 * (log_qz0 - log_pz0)
        infmask = log_qz0 != float("-inf")
        # workaround for a namedtensor bug that puts empty tensors on
        # different devices
        kl0 = kl0.transpose(batchdim, dim)
        kl0 = kl0._new(
            kl0.values.where(infmask.values, torch.zeros_like(kl0.values))
        ).sum(dim)
        kl.append(kl0)
    return ntorch.stack(kl, "batch")
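# Both helpers above compute sum_a q(a) (log q(a) - log p(a)) per example,
# zeroing positions where log q = -inf: there q = 0 and the summand's limit
# is 0, but 0 * -inf evaluates to nan in floating point. The masking pattern
# in plain torch (a sketch; the namedtensor versions also handle padding):
import torch

def masked_kl(q, log_q, log_p, dim=-1):
    term = q * (log_q - log_p)
    # replace nan-producing positions (q == 0, log_q == -inf) with 0
    term = torch.where(log_q == float("-inf"),
                       torch.zeros_like(term), term)
    return term.sum(dim)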
def forward(self, x, s, x_info, r, r_info, ue, ue_info, ut, ut_info, v2d):
    emb = self.lutx(x)
    N = emb.shape["batch"]
    T = emb.shape["time"]

    e = self.lute(r[0]).rename("e", "r")
    t = self.lutt(r[1]).rename("t", "r")
    v = self.lutv(r[2]).rename("v", "r")
    # r: R x N x Er, Wa r: R x N x H
    r = self.Wa(ntorch.cat([e, t, v], dim="r").tanh())

    if not self.inputfeed:
        # rnn_o: T x N x H
        rnn_o, s = self.rnn(emb, s, x_info.lengths)
        # ea: T x N x R
        _, ea, ec = attn(rnn_o, r, r_info.mask)
        if self.noattn:
            ec = r.mean("els").repeat("time", ec.shape["time"])
        self.ea = ea
        out = self.Wc(ntorch.cat([rnn_o, ec], "rnns")).tanh()
    else:
        out = []
        ect = NamedTensor(
            torch.zeros(N, self.r_emb_sz).to(emb.values.device),
            names=("batch", "rnns"),
        )
        for t in range(T):
            inp = ntorch.cat(
                [emb.get("time", t), ect.rename("rnns", "x")],
                "x").repeat("time", 1)
            rnn_o, s = self.rnn(inp, s)
            rnn_o = rnn_o.get("time", 0)
            _, eat, ect = attn(rnn_o, r, r_info.mask)
            out.append(ntorch.cat([rnn_o, ect], "rnns"))
        out = self.Wc(ntorch.stack(out, "time")).tanh()

    # return unnormalized vocab
    return self.proj(self.drop(out)), s
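# `attn` is used throughout with the signature attn(query, keys, mask) ->
# (log-weights, weights, context); its definition is not shown here. A
# plausible dot-product version in plain torch, assuming (T, N, H) queries,
# (R, N, H) keys, and an (R, N) validity mask (a sketch, not the project's
# actual helper):
import torch

def attn(query, keys, mask=None):
    scores = torch.einsum("tnh,rnh->tnr", query, keys)
    if mask is not None:
        # mask: (R, N) with True = valid; broadcast to (T, N, R)
        scores = scores.masked_fill(~mask.t().unsqueeze(0), float("-inf"))
    log_w = scores.log_softmax(dim=-1)
    w = log_w.exp()
    context = torch.einsum("tnr,rnh->tnh", w, keys)
    return log_w, w, context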
def predict(self, x):
    y_ = self.U(self.V(x).relu().sum("seqlen")).sigmoid().sum('score')
    y = ntorch.stack([y_, 1 - y_], 'classes')
    return y
def pa0(self, emb_x, s, x_info, emb_e, ue_info, emb_t, ut_info, v2dx):
    T = emb_x.shape["time"]
    N = emb_x.shape["batch"]

    log_ea, ea, ec = None, None, None
    log_ta, ta, tc = None, None, None
    log_a, a, c = None, None, None
    output = None

    if not self.inputfeed:
        # rnn_o: T x N x H
        rnn_o, s = self.rnn(emb_x, s, x_info.lengths)
        # ea: T x N x R
        log_ea, ea, ec = attn(rnn_o, emb_e, ue_info.mask)
        # log_t, ea_T, ec_T = attn(rnn_o + ec_E, tA, ut_info.mask)
        log_ta, ta, tc = attn(rnn_o, emb_t, ut_info.mask)
        if self.noattn:
            # NOTE: `r` is not defined in this function; this branch appears
            # to be dead code carried over from another model.
            ec = r.mean("els").repeat("time", ec.shape["time"])
        log_ea = log_ea.rename("els", "e")
        log_ta = log_ta.rename("els", "t")
        log_va = log_ea + log_ta
        vc = log_va.exp().dot(("t", "e"), v2dx)
        va = log_va.exp()
        ea = ea.rename("els", "e")
        ta = ta.rename("els", "t")
        output = rnn_o
    else:
        log_ea, ea, ec = [], [], []
        log_ta, ta, tc = [], [], []
        log_va, va, vc = [], [], []
        out = []
        etc_t = ntorch.zeros(
            N, self.r_emb_sz, names=("batch", "rnns")
        ).to(emb_x.values.device)
        for t in range(T):
            etc_t = etc_t.rename("rnns", "x")
            inp = ntorch.cat(
                [emb_x.get("time", t), etc_t], "x").repeat("time", 1)
            rnn_o, s = self.rnn(inp, s)
            rnn_o = rnn_o.get("time", 0)

            log_ea_t, ea_t, ec_t = attn(rnn_o, emb_e, ue_info.mask)
            log_ta_t, ta_t, tc_t = attn(rnn_o, emb_t, ut_info.mask)
            log_ea_t = log_ea_t.rename("els", "e")
            log_ta_t = log_ta_t.rename("els", "t")
            log_va_t = log_ea_t + log_ta_t
            va_t = log_va_t.exp()
            vc_t = va_t.dot(("t", "e"), v2dx)

            out.append(
                self.Wif(ntorch.cat([rnn_o, vc_t, ec_t, tc_t], "rnns")))
            log_ea.append(log_ea_t)
            ea.append(ea_t)
            ec.append(ec_t)
            log_ta.append(log_ta_t)
            ta.append(ta_t)
            tc.append(tc_t)
            log_va.append(log_va_t)
            va.append(va_t)
            vc.append(vc_t)

        output = ntorch.stack(out, "time")
        log_ea = ntorch.stack(log_ea, "time")
        ea = ntorch.stack(ea, "time")
        ec = ntorch.stack(ec, "time")
        log_ta = ntorch.stack(log_ta, "time")
        ta = ntorch.stack(ta, "time")
        tc = ntorch.stack(tc, "time")
        log_va = ntorch.stack(log_va, "time")
        va = ntorch.stack(va, "time")
        vc = ntorch.stack(vc, "time")
        ea = ea.rename("els", "e")
        ta = ta.rename("els", "t")

    return log_ea, ea, ec, log_ta, ta, tc, log_va, va, vc, output, s
def marginal_nll(self, x, s, x_info, r, r_info, vt, y, y_info,
                 T=128, E=32, R=4, learn=False):
    assert not learn
    # r: R x N x Er
    # Wa r: R x N x H
    e = self.lute(r[0]).rename("e", "r")
    t = self.lutt(r[1]).rename("t", "r")
    v = self.lutv(r[2]).rename("v", "r")
    # r = self.Wa(ntorch.cat([e, t, v], "r").tanh())
    r = ntorch.cat([e, t, v], "r")
    rW = self.Wa(r)

    emb_x = self.lutx(x)
    log_pa, pa, ec, rnn_o, s = self.pa0(emb_x, s, rW, x_info, r_info)
    # what should we do with pa, ec, and rnn_o?
    # ec is only used for a baseline

    R, N, H = r.shape
    K = self.K

    # if y is not None:
    #     emb_y = self.lutx(y)
    #     log_pa_y, pay, eyc = self.pa_y(emb_y, r, x_info, r_info)
    # else:
    #     log_pa_y, pay = None, None

    ctxt = rW
    lyas = []
    # sum_Ai sum_Cj [ sum_{a in Ai} p(a) sum_{c in Ci} p(y|a,c) p(c|a) ]
    for rnn_t, lpa_t, y_t in zip(
        rnn_o.split(T, "time"),
        log_pa.split(T, "time"),
        y.split(T, "time"),
    ):
        lyas_t = []
        for ctxt_e, lpa_e, vt_e in zip(
            ctxt.split(E, "els"),
            lpa_t.split(E, "els"),
            vt.split(E, "els"),
        ):
            out = self.Wc(ntorch.cat(
                [
                    rnn_t.expand("els", ctxt_e.shape["els"]),
                    ctxt_e.expand("time", rnn_t.shape["time"]),
                ],
                "rnns",
            )).tanh()  # is this right?

            log_pc_a = self.log_pc_a(rnn_t, ctxt_e)
            log_pc0_a = log_pc_a.get("copy", 0)
            log_pc1_a = log_pc_a.get("copy", 1)

            # TODO: Need to generalize conditional dists.
            # log p(y | a, c=0): generate from the vocabulary
            log_py_ac0 = (self.proj(out)
                          .log_softmax("vocab")
                          .gather(
                              "vocab",
                              y_t.chop("batch", ("lol", "batch"), lol=1),
                              "lol",
                          )).get("lol", 0)
            # log p(y | a, c=1): copy the aligned value
            py_ac1 = vt_e == y_t
            py_ac1 = py_ac1._new(py_ac1.type(self.dtype))
            log_py_ac1 = py_ac1.log()
            # log p(y | a), marginalizing over the copy decision c
            log_py_a = logaddexp(
                log_py_ac0 + log_pc0_a,
                log_py_ac1 + log_pc1_a,
            )
            log_pya = log_py_a + lpa_e
            lyas_t.append(log_pya.logsumexp("els"))
        lyas.append(ntorch.stack(lyas_t, "els").logsumexp("els"))

    # log p(y)
    log_py = ntorch.cat(lyas, "time")
    rvinfo = RvInfo(log_py=log_py)
    # need E log p(y|a)?
    return rvinfo, s
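# `logaddexp` mixes the generate (c=0) and copy (c=1) terms in log space:
# log(exp(x) + exp(y)). If the helper is not already in scope, a numerically
# stable definition looks like this (torch >= 1.6 also ships torch.logaddexp);
# it assumes at least one argument is finite at every position:
import torch

def logaddexp(x, y):
    # subtract the elementwise max so neither exponential can overflow
    m = torch.max(x, y)
    return m + ((x - m).exp() + (y - m).exp()).log()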
def __init__(self, data=None, dataset=None, device=None):
    """Create a Batch from a list of examples."""
    if data is not None:
        self.batch_size = len(data)
        self.dataset = dataset
        self.fields = dataset.fields.keys()  # copy field names
        self.input_fields = [k for k, v in dataset.fields.items()
                             if v is not None and not v.is_target]
        self.target_fields = [k for k, v in dataset.fields.items()
                              if v is not None and v.is_target]

        for (name, field) in dataset.fields.items():
            if field is not None:
                batch = [getattr(x, name) for x in data]
                setattr(self, name, field.process(batch, device=device))

        # 2d attn
        maxe = max(len(x.uentities) for x in data)
        maxt = max(len(x.utypes) for x in data)
        padded = []
        tostack = []
        tostack_v = []
        for x in data:
            ue = x.uentities
            ut = x.utypes
            lene = len(ue)
            lent = len(ut)
            etmap = {(e, t): v
                     for e, t, v in zip(x.entities, x.types, x.values)}
            array = [
                [etmap[(e, t)] if (e, t) in etmap else NONE for t in ut]
                + [PAD] * (maxt - lent)
                for e in ue
            ] + [[PAD] * maxt] * (maxe - lene)
            tensor = dataset.fields["values_text"].numericalize(
                (array, [lent] * lene), device=device)
            tensor_v = dataset.fields["values"].numericalize(
                (array, [lent] * lene), device=device)
            ue, ue_info = self.uentities
            ut, ut_info = self.utypes
            padded.append(array)
            tostack.append(tensor[0].rename("els", "t").rename("batch", "e"))
            tostack_v.append(
                tensor_v[0].rename("els", "t").rename("batch", "e"))
        setattr(self, "vt2d", ntorch.stack(tostack, "batch"))
        setattr(self, "v2d", ntorch.stack(tostack_v, "batch"))

        # ie stuff
        ie_etv = []
        ie_d = []
        num_cells = 0
        for x in data:
            etvx = x.ie_etv
            etvs = []
            for etv in etvx:
                T = len(etv)
                e = [x[0] for x in etv]
                t = [x[1] for x in etv]
                v0 = [x[2] for x in etv]
                e, _ = dataset.fields["entities"].numericalize(
                    ([e], [T]), device="cpu")
                t, _ = dataset.fields["types"].numericalize(
                    ([t], [T]), device="cpu")
                v, _ = dataset.fields["values"].numericalize(
                    ([v0], [T]), device="cpu")
                vt, _ = dataset.fields["values_text"].numericalize(
                    ([v0], [T]), device="cpu")
                etvs.append((
                    e.get("batch", 0).rename("els", "e"),
                    t.get("batch", 0).rename("els", "t"),
                    v.get("batch", 0).rename("els", "v"),
                    vt.get("batch", 0).rename("els", "x"),
                ))
                num_cells += T
            ie_etv.append(etvs)

            ie_et_d = x.ie_et_d
            d = {}
            for k, v in ie_et_d.items():
                T = len(v)
                e = [x[0] for x in v]
                t = [x[1] for x in v]
                e, _ = dataset.fields["entities"].numericalize(
                    ([e], [T]), device="cpu")
                t, _ = dataset.fields["types"].numericalize(
                    ([t], [T]), device="cpu")
                d[k] = (
                    e.get("batch", 0).rename("els", "e"),
                    t.get("batch", 0).rename("els", "t"),
                )
                num_cells += T
            ie_d.append(d)
        setattr(self, "ie_etv", ie_etv)
        setattr(self, "ie_et_d", ie_d)
        setattr(self, "num_cells", num_cells)

        # indices
        idxs = [x.idx for x in data]
        setattr(self, "idxs", idxs)
def forward(self, x, s, x_info, r, r_info, vt, ue, ue_info, ut, ut_info,
            v2d, vt2d, y=None):
    emb = self.lutx(x)
    N = emb.shape["batch"]
    T = emb.shape["time"]

    e = self.lute(r[0]).rename("e", "r")
    t = self.lutt(r[1]).rename("t", "r")
    v = self.lutv(r[2]).rename("v", "r")
    # r: R x N x Er, Wa r: R x N x H
    r = self.War(ntorch.cat([e, t, v], dim="r").tanh())
    eA = self.Wae(self.lute(ue))
    tA = self.Wat(self.lutt(ut))

    if self.v2d:
        v2dx = self.lutv(
            v2d.stack(("t", "e"), "els")
        ).chop("els", ("t", "e"), t=v2d.shape["t"]).rename("v", "rnns")
    else:
        # vt2dx
        v2dx = self.lutx(
            vt2d.stack(("t", "e"), "time")
        ).chop("time", ("t", "e"), t=v2d.shape["t"]).rename("x", "rnns")

    if not self.inputfeed:
        # rnn_o: T x N x H
        rnn_o, s = self.rnn(emb, s, x_info.lengths)
        # ea: T x N x R
        log_e, ea_E, ec_E = attn(rnn_o, eA, ue_info.mask)
        # log_t, ea_T, ec_T = attn(rnn_o + ec_E, tA, ut_info.mask)
        log_t, ea_T, ec_T = attn(rnn_o, tA, ut_info.mask)
        if self.noattn:
            # take the time length from rnn_o; `ec` is not yet defined here
            ec = r.mean("els").repeat("time", rnn_o.shape["time"])
        ec_ET = ec_E + ec_T
        le = log_e.rename("els", "e")
        lt = log_t.rename("els", "t")
        self.ea = le
        self.ta = lt
        aw = (le + lt).exp()
        ec = aw.dot(("t", "e"), v2dx)
        self.a = aw
        # no ent or typ
        # out = self.Wc(ntorch.cat([rnn_o, ec], "rnns")).tanh()
        # cat ent and type, this seems fine
        out = (self.Wc_nov(ntorch.cat([rnn_o, ec_E, ec_T], "rnns"))
               if self.noattnvalues
               else self.Wc(ntorch.cat([rnn_o, ec, ec_E, ec_T], "rnns"))).tanh()
        # add ent and typ
        # out = self.Wc(ntorch.cat([rnn_o, ec + ec_ET], "rnns")).tanh()
    else:
        out = []
        self.ea = []
        self.ta = []
        self.a = []
        ec_ETt = ntorch.zeros(
            N, self.r_emb_sz, names=("batch", "rnns")
        ).to(emb.values.device)
        for t in range(T):
            ec_ETt = ec_ETt.rename("rnns", "x")
            inp = ntorch.cat(
                [emb.get("time", t), ec_ETt], "x").repeat("time", 1)
            rnn_o, s = self.rnn(inp, s)
            rnn_o = rnn_o.get("time", 0)
            log_e, ea_Et, ec_Et = attn(rnn_o, eA, ue_info.mask)
            log_t, ea_Tt, ec_Tt = attn(rnn_o, tA, ut_info.mask)
            ec_ETt = ec_Et + ec_Tt
            le = log_e.rename("els", "e")
            lt = log_t.rename("els", "t")
            aw = (le + lt).exp()
            ect = aw.dot(("t", "e"), v2dx)
            out.append(ntorch.cat([rnn_o, ect, ec_Et, ec_Tt], "rnns"))
            self.ea.append(ea_Et.detach())
            self.ta.append(ea_Tt.detach())
            self.a.append(aw.detach())
        out = self.Wc(ntorch.stack(out, "time")).tanh()

    # return unnormalized vocab
    return self.proj(self.drop(out)), s
def beam(self, src, trg, k, beam_len, num_candidates):
    batch_size = src.shape['batch']
    # map a hypothesis to a distribution over words
    out_dists = HypothesisMap(device=self.device)
    # map a hypothesis to its score
    scores = HypothesisMap(
        keys=[trg[{'trgSeqlen': slice(0, 1)}]],
        vals=[ntorch.zeros(batch_size, names='batch')],
        device=self.device)
    # special buffer for hypotheses ending in <EOS>
    end = HypothesisMap(device=self.device)
    attn = []
    EOS_IND = 3

    hidden = self.encoder(src)

    # make predictions
    for l in range(beam_len or trg.shape['trgSeqlen'] - 1):
        new_scores = HypothesisMap(device=self.device)
        hyps = scores.get_topk(k) if l > 0 else scores
        for hyp, score in hyps.items():
            inp = hyp[{'trgSeqlen': slice(l, l + 1)}]
            out, hidden = self.decoder(inp, hidden)
            out = out.log_softmax('logit')
            topk = out.topk('logit', k)
            for i in range(k):
                pred_prob = topk[0][{'logit': i, 'trgSeqlen': -1}]
                pred = topk[1][{'logit': i}]
                new_hyp = ntorch.cat([hyp, pred], 'trgSeqlen')

                if hyp in out_dists:
                    out_dists[new_hyp] = ntorch.cat(
                        [out_dists[hyp], out], 'trgSeqlen')
                else:
                    out_dists[new_hyp] = out

                if torch.any((pred[{'trgSeqlen': -1}] == EOS_IND).values):
                    end[new_hyp] = score + pred_prob
                    end[new_hyp].masked_fill_(
                        pred[{'trgSeqlen': -1}] != EOS_IND, -float('inf'))
                    pred_prob.masked_fill_(
                        pred[{'trgSeqlen': -1}] == EOS_IND, -float('inf'))

                new_scores[new_hyp] = score + pred_prob
        scores = new_scores

    for hyp, score in end.items():
        scores[hyp] = score
    best = scores.get_topk(num_candidates).keys
    out = [out_dists[k] for k in best]

    # store output
    if 'attn' in hidden:
        attn.append(hidden['attn'])

    # format predictions
    return ntorch.stack(out, 'candidates'), ntorch.cat(attn, dim='trgSeqlen')
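# The core expansion step inside `beam`: add each hypothesis' running score
# to its next-token log-probs and keep the best k continuations. A minimal
# batched sketch in plain torch (names and shapes are assumptions, not the
# HypothesisMap machinery above):
import torch

def expand_beam(scores, step_log_probs, k):
    # scores: (beam,) running log-prob per hypothesis
    # step_log_probs: (beam, vocab) next-token log-probs per hypothesis
    total = scores.unsqueeze(1) + step_log_probs       # (beam, vocab)
    top_scores, flat_idx = total.view(-1).topk(k)
    beam_idx = flat_idx // step_log_probs.size(1)      # which hypothesis
    token_idx = flat_idx % step_log_probs.size(1)      # which next token
    return top_scores, beam_idx, token_idx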
def __call__(self, batch_text):
    return ntorch.stack(
        [model(batch_text) for model in self.models],
        "model").mean("model")
def __getitem__(self, key):
    ret = []
    for b in range(key.shape['batch']):
        ret.append(self.key2val[str(key[{'batch': b}])])
    return ntorch.stack(ret, 'batch').to(self.device)
def trajectories(self, N=100, dt=0.02):
    perimeter = self.params['perimeter']
    T = self.params["T"]
    n = int(T / dt)
    mu, sigma, b = [
        self.params[i]
        for i in ["mean_rotation", "std_dev_rotation", "std_dev_forward"]
    ]

    rotation_velocities = torch.tensor(
        np.random.normal(mu, sigma, size=(n, N))).float()
    forward_velocities = torch.tensor(
        np.random.rayleigh(b, size=(n, N))).float()

    positions = ntorch.zeros((n, 2, N), names=("t", "ax", "sample"))
    vs = torch.zeros((n, N))
    angles = rotation_velocities
    directions = torch.zeros((n, 2, N))

    vs[0] = self.params["v0"]
    theta = torch.rand(N) * 2 * np.pi
    directions[0] = unit_vector(theta)
    positions[{"t": 0}] = ntorch.tensor(
        self.scene.random(N), names=("sample", "ax"))

    for i in range(1, n):
        dist, phi = self.scene.closestWall(
            positions[{"t": i - 1}].values, directions[i - 1])
        wall = (dist < perimeter) & (phi.abs() < np.pi / 2)
        angle_correction = torch.where(
            wall, phi.sign() * (np.pi / 2 - phi.abs()),
            torch.zeros_like(phi))
        angles[i] += angle_correction
        vs[i] = torch.where(
            wall,
            (1 - self.params["velocity_reduction"]) * vs[i - 1],
            forward_velocities[i],
        )
        positions[{"t": i}] = (
            positions[{"t": i - 1}] + directions[i - 1] * vs[i] * dt)
        mat = rotation_matrix(angles[i] * dt)
        directions[i] = torch.einsum("ijk,jk->ik", mat, directions[i - 1])

    idx = np.round(np.linspace(
        0, n - 2, self.params["trajectory_length"])).astype(int)
    # idx = np.array(sorted(np.random.choice(
    #     np.arange(n), size=self.params["trajectory_length"], replace=False)))

    dphis = ntorch.tensor(angles[idx] * dt, names=("t", "sample"))
    velocities = ntorch.tensor(vs[idx], names=("t", "sample"))
    vel = ntorch.stack((velocities, dphis.cos(), dphis.sin()), "input")
    xs = ntorch.tensor(positions.values[idx], names=("t", "ax", "sample"))
    # xs0 = positions[{'t': 0}]
    xs0 = ntorch.tensor(self.scene.random(n=N), names=("sample", "ax"))

    hd = torch.atan2(directions[:, 1], directions[:, 0])
    hd0 = ntorch.tensor(hd[0][None], names=("hd", "sample"))
    hd = ntorch.tensor(hd[idx + 1][None], names=("hd", "t", "sample"))

    xs = xs.transpose('sample', 't', 'ax')
    hd = hd.transpose('sample', 't', 'hd')
    vel = vel.transpose('sample', 't', 'input')
    xs0 = xs0.transpose('sample', 'ax')
    hd0 = hd0.transpose('sample', 'hd')
    return xs, hd, vel, xs0, hd0
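# `trajectories` calls two helpers, `unit_vector` and `rotation_matrix`, that
# are defined elsewhere in the module. Plausible implementations consistent
# with how they are called above (an assumption, not the original code):
import torch

def unit_vector(theta):
    # theta: (N,) angles -> (2, N) unit direction vectors
    return torch.stack([theta.cos(), theta.sin()], dim=0)

def rotation_matrix(phi):
    # phi: (N,) angles -> (2, 2, N) rotation matrices, matching the
    # einsum("ijk,jk->ik", mat, directions) contraction in the loop above
    c, s = phi.cos(), phi.sin()
    return torch.stack([torch.stack([c, -s], dim=0),
                        torch.stack([s, c], dim=0)], dim=0)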