def eval_loop(model, dataloader, device=torch.device("cpu")):
    tto = q.ticktock("testing")
    tto.tick("testing")
    tt = q.ticktock("-")
    totaltestbats = len(dataloader)
    model.eval()
    epoch_reset(model)
    outs = []
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            batch = (batch,) if not q.issequence(batch) else batch
            batch = q.recmap(
                batch,
                lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
            batch_reset(model)
            modelouts = model(*batch)
            tt.live("eval - [{}/{}]".format(i + 1, totaltestbats))
            if not q.issequence(modelouts):
                modelouts = (modelouts,)
            if len(outs) == 0:
                outs = [[] for e in modelouts]
            for out_e, mout_e in zip(outs, modelouts):
                out_e.append(mout_e)
    ttmsg = "eval done"
    tt.stoplive()
    tt.tock(ttmsg)
    tto.tock("tested")
    ret = [torch.cat(out_e, 0) for out_e in outs]
    return ret
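# A minimal sketch of the evaluation pattern eval_loop implements (eval mode,
# no_grad, per-batch outputs concatenated along dim 0), in plain PyTorch.
# `toy_model` and `toy_loader` are hypothetical placeholders; the q.* helpers
# (ticktock, issequence, recmap) and epoch_reset/batch_reset used above come
# from the surrounding library and are not reproduced here.
import torch
from torch.utils.data import TensorDataset, DataLoader

toy_model = torch.nn.Linear(4, 2)
toy_loader = DataLoader(TensorDataset(torch.randn(10, 4)), batch_size=3)

toy_model.eval()
collected = []
with torch.no_grad():
    for (xb,) in toy_loader:
        collected.append(toy_model(xb))
result = torch.cat(collected, 0)   # (10, 2), mirroring eval_loop's return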
def forward(self, x, mask=None):
    fwd_ret = self.layer_fwd(x, mask=mask)
    rev_ret = self.layer_rev(x, mask=mask)
    merge_fn = ((lambda a, b: torch.cat([a, b], -1)) if self.mode == "cat"
                else (lambda a, b: a + b))
    if not q.issequence(fwd_ret):
        fwd_ret = [fwd_ret]
    if not q.issequence(rev_ret):
        rev_ret = [rev_ret]
    ret = tuple()
    if self._return_final:
        ret += (merge_fn(fwd_ret[0], rev_ret[0]),)
        fwd_ret = fwd_ret[1:]
        rev_ret = rev_ret[1:]
    if self._return_all:
        ret += (merge_fn(fwd_ret[0], rev_ret[0]),)
    if self._return_mask:
        ret += (mask,)
    if len(ret) == 1:
        return ret[0]
    elif len(ret) == 0:
        print("no output specified")
        return
    else:
        return ret
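# Illustration of the two merge modes above: "cat" concatenates the forward
# and reversed features on the last dim, anything else adds them. The toy
# tensors stand in for layer_fwd/layer_rev outputs.
import torch

fwd = torch.randn(2, 5, 8)   # (batsize, seqlen, hdim) from the forward pass
rev = torch.randn(2, 5, 8)   # same shape from the reversed pass
merged_cat = torch.cat([fwd, rev], -1)   # (2, 5, 16)
merged_sum = fwd + rev                   # (2, 5, 8)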
def __init__(self, fn, register_params=None, register_modules=None):
    super(Lambda, self).__init__()
    self.fn = fn
    # optionally registers passed modules and params
    if register_modules is not None:
        if not q.issequence(register_modules):
            register_modules = [register_modules]
        self.extra_modules = q.ModuleList(register_modules)
    if register_params is not None:
        if not q.issequence(register_params):
            register_params = [register_params]
        self.extra_params = nn.ParameterList(register_params)
def forward(self, x, mask=None):
    x = self.dropout(x)
    if mask is not None:
        _x = torch.nn.utils.rnn.pack_padded_sequence(
            x, mask.sum(-1), batch_first=True, enforce_sorted=False)
    else:
        _x = x
    _outputs, hidden = self.rnn(_x)
    if mask is not None:
        y, _ = torch.nn.utils.rnn.pad_packed_sequence(_outputs,
                                                      batch_first=True)
    else:
        y = _outputs
    hidden = (hidden,) if not q.issequence(hidden) else hidden
    hiddens = []
    for _hidden in hidden:
        i = 0
        _hiddens = tuple()
        while i < _hidden.size(0):
            if self.bidir is True:
                _h = torch.cat([_hidden[i], _hidden[i + 1]], -1)
                i += 2
            else:
                _h = _hidden[i]
                i += 1
            _hiddens = _hiddens + (_h,)
        hiddens.append(_hiddens)
    hiddens = tuple(zip(*hiddens))
    return y, hiddens
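# A self-contained sketch of the packing step above: per-example lengths are
# derived from the binary mask via mask.sum(-1) and fed to
# pack_padded_sequence. Toy shapes replace self.rnn and its inputs.
import torch

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])   # (batsize, seqlen)
x = torch.randn(2, 4, 6)                            # padded inputs
packed = torch.nn.utils.rnn.pack_padded_sequence(
    x, mask.sum(-1), batch_first=True, enforce_sorted=False)
rnn = torch.nn.GRU(6, 5, batch_first=True)
out, h = rnn(packed)
y, lens = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)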
def forward(self, *x, **kw):
    y_l = x
    args, kwargs = None, None
    argmapped = False
    for layer in self.layers:
        if argmapped:
            rargs = args
            rkw = {}
            rkw.update(kw)
            rkw.update(kwargs)
        else:
            rargs = y_l
            rkw = kw
        if isinstance(layer, q.argmap):
            args, kwargs = layer(rargs, rkw, self._saved_slots)
            argmapped = True
        elif isinstance(layer, argsave):
            globols = layer(rargs, rkw, self._saved_slots)
        else:
            y_l = layer(*rargs, **rkw)
            argmapped = False
        if not q.issequence(y_l) and not argmapped:
            y_l = tuple([y_l])
    if argmapped:
        ret = args
    else:
        ret = y_l
    if len(ret) == 1:
        ret = ret[0]
    return ret
def get_entity_property(self, entities, property, language=None):
    if not q.issequence(entities):
        entities = [entities]
    entities = [fbfy(entity) for entity in entities]
    propertychain = [fbfy(p) for p in property.strip().split()]
    propchain = ""
    prevvar = "?s"
    varcount = 0
    for prop in propertychain:
        newvar = "?var{}".format(varcount)
        varcount += 1
        propchain += "{} {} {} .\n".format(prevvar, prop, newvar)
        prevvar = newvar
    propchain = propchain.replace(prevvar, "?o")
    query = """SELECT ?s ?o WHERE {{
        {}
        VALUES ?s {{ {} }}
        {}
    }}""".format(
        propchain, " ".join(entities),
        "FILTER (lang(?o) = '{}')".format(language)
        if language is not None else "")
    ret = {}
    res = self._exec_query(query)
    results = res["results"]["bindings"]
    for result in results:
        s = unfbfy(result["s"]["value"])
        if s not in ret:
            ret[s] = set()
        val = result["o"]["value"]
        if language is None:
            val = unfbfy(val)
        ret[s].add(val)
    return ret
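# Worked example of the property-chain construction above for a two-step
# property "a b". fbfy/unfbfy and _exec_query are helpers from the
# surrounding code; "ns:a" / "ns:b" are made-up identifiers for illustration.
propertychain = ["ns:a", "ns:b"]
propchain, prevvar, varcount = "", "?s", 0
for prop in propertychain:
    newvar = "?var{}".format(varcount)
    varcount += 1
    propchain += "{} {} {} .\n".format(prevvar, prop, newvar)
    prevvar = newvar
propchain = propchain.replace(prevvar, "?o")
print(propchain)
# ?s ns:a ?var0 .
# ?var0 ns:b ?o .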
def forward(self, x: State):
    if "mstate" not in x:
        x.mstate = State()
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode input
        inptensor = x.inp_tensor
        mask = inptensor != 0
        xlmrstates = self.xlmr.extract_features(inptensor)
        inpenc = xlmrstates
        final_enc = xlmrstates[:, 0, :]
        for i in range(len(self.enc_to_dec)):   # iter over layers
            _fenc = self.enc_to_dec[i](final_enc)
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask

    emb = self.out_emb(x.prev_actions)

    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state

    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        mstate.prev_summ = final_enc
    _emb = emb
    if self.feedatt == True:
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)

    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)

    if self.store_attn:
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0,
                                              alphas.size(1),
                                              device=alphas.device)
        x.stored_attentions = torch.cat(
            [x.stored_attentions, alphas.detach()[:, None, :]], 1)

    return outs[0], x
def __init__(self, size_average=True, ignore_index=None, **kw):
    super(DiscreteLoss, self).__init__(**kw)
    if ignore_index is not None:
        if not q.issequence(ignore_index):
            ignore_index = [ignore_index]
        self.ignore_indices = ignore_index
    else:
        self.ignore_indices = None
    self.size_average = size_average
def forward(self, x_t, ctx=None, ctx_mask=None, **kw):
    if ctx is None:
        ctx, ctx_mask = self._saved_ctx, self._saved_ctx_mask
    assert (ctx is not None)
    # if isinstance(self.out, q.rnn.AutoMaskedOut):
    #     self.out.update(x_t)
    self.out.update(x_t)
    embs = self.emb(x_t)   # embed input tokens
    if q.issequence(embs) and len(embs) == 2:   # unpack if necessary
        embs, mask = embs
    if self.feed_att:
        if self._outvec_tm1 is None:
            assert (self.outvec_t0 is not None)   # h_hat_0 must be set when feed_att=True
            self._outvec_tm1 = self.outvec_t0
        core_inp = torch.cat([embs, self._outvec_tm1], 1)   # append previous attention summary
    else:
        core_inp = embs
    core_out = self.core(core_inp)   # feed through rnn

    # do normal attention over input
    alphas, summaries, scores = self.att(core_out, ctx, ctx_mask=ctx_mask,
                                         values=ctx)
    # do attention over decoded sequence
    if self.selfatt is not None and self.prev_coreouts is not None:
        selfalphas, selfsummaries, selfscores = self.selfatt(
            core_out, self.prev_coreouts)   # do self-attention
    else:
        selfalphas, selfsummaries, selfscores = None, None, None
    # TODO ??? use self-attention summaries for output generation too?
    out_vec = self.merge(core_out, summaries, core_inp)
    out_vec = self.dropout(out_vec)
    self._outvec_tm1 = out_vec   # store outvec (this is how Luong, 2015 does it)

    # save coreouts
    if self.selfatt is not None and self.prev_coreouts is not None:
        self.prev_coreouts = torch.cat(
            [self.prev_coreouts, core_out.unsqueeze(1)], 1)
    else:
        self.prev_coreouts = core_out.unsqueeze(1)   # introduce a sequence dimension

    ret = tuple()
    if self.out is None:
        ret += (out_vec,)
    else:
        _out_vec = self.out(out_vec, scores=scores, selfscores=selfscores)
        ret += (_out_vec,)
    # other returns
    if self.return_alphas:
        ret += (alphas,)
    if self.return_scores:
        ret += (scores,)
    if self.return_other:
        ret += (embs, core_out, summaries)
    return ret[0] if len(ret) == 1 else ret
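# A hedged sketch of what an attention module like self.att above computes:
# raw scores between the decoder state and encoder states, masked softmax
# weights (alphas), and a weighted summary. This is a minimal dot-product
# version on toy tensors, not the original attention implementation.
import torch

core_out = torch.randn(2, 8)             # decoder state (batsize, hdim)
ctx = torch.randn(2, 5, 8)               # encoder states (batsize, seqlen, hdim)
ctx_mask = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 0, 0, 0]]).float()

scores = torch.einsum("bd,bsd->bs", core_out, ctx)     # raw scores
scores = scores + torch.log(ctx_mask)                  # padded positions -> -inf
alphas = torch.softmax(scores, -1)                     # attention weights
summaries = torch.einsum("bs,bsd->bd", alphas, ctx)    # weighted summary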
def forward(self, x: State):
    if "mstate" not in x:
        x.mstate = State()
    mstate = x.mstate
    if "ctx" not in mstate:
        # encode input
        inptensor = x.inp_tensor
        mask = inptensor != 0
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_enc = self.inp_enc(inpembs, mask)
        final_enc = final_enc.view(final_enc.size(0), -1).contiguous()
        final_enc = self.enc_to_dec(final_enc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask

    emb = self.out_emb(x.prev_actions)

    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        mstate.rnnstate = init_rnn_state

    if "prev_summ" not in mstate:
        mstate.prev_summ = torch.zeros_like(ctx[:, 0])
    _emb = emb
    if self.feedatt == True:
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)

    if self.nocopy is True:
        outs = self.out_lin(enc)
    else:
        outs = self.out_lin(enc, x.inp_tensor, scores)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)

    if self.store_attn:
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0,
                                              alphas.size(1),
                                              device=alphas.device)
        x.stored_attentions = torch.cat(
            [x.stored_attentions, alphas.detach()[:, None, :]], 1)

    return outs[0], x
def forward(self, x, mask=None, init_states=None, reverse=False):
    # (batsize, seqlen, indim), (batsize, seqlen), [(batsize, hdim)]
    batsize = x.size(0)
    if init_states is not None:
        if not q.issequence(init_states):
            init_states = (init_states,)
        self.cell.set_init_states(*init_states)
    self.cell.reset_state()
    mask = mask if mask is not None else (
        x.mask if hasattr(x, "mask") else None)
    y_list = []
    y_tm1 = None
    y_t = None
    i = x.size(1)
    while i > 0:
        t = i - 1 if reverse else x.size(1) - i
        mask_t = mask[:, t].unsqueeze(1) if mask is not None else None
        x_t = x[:, t]
        cellout = self.cell(x_t, mask_t=mask_t, t=t)
        y_t = cellout
        # mask
        # if mask_t is not None:    # moved to cells (recBN is affected here)
        #     if y_tm1 is None:
        #         y_tm1 = q.var(torch.zeros(y_t.size())).cuda(crit=y_t).v
        #         if x.is_cuda: y_tm1 = y_tm1.cuda()
        #     y_t = y_t * mask_t + y_tm1 * (1 - mask_t)
        #     y_tm1 = y_t
        if self._return_all:
            y_list.append(y_t)
        i -= 1
    ret = tuple()
    if self._return_final:
        ret += (y_t,)
    if self._return_all:
        if reverse:
            y_list.reverse()
        y = torch.stack(y_list, 1)
        ret += (y,)
    if self._return_mask:
        ret += (mask,)
    if len(ret) == 1:
        return ret[0]
    elif len(ret) == 0:
        print("no output specified")
        return
    else:
        return ret
def get_ignore_mask(gold, ignore_indices):
    if ignore_indices is not None and not q.issequence(ignore_indices):
        ignore_indices = [ignore_indices]
    mask = None   # (batsize,)
    if ignore_indices is not None:
        for ignore in ignore_indices:
            mask_i = (gold != ignore)   # zero for ignored ones
            if mask is None:
                mask = mask_i
            else:
                mask = mask & mask_i
    if mask is None:
        mask = torch.ones_like(gold).byte()
    return mask
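# Worked example of the mask get_ignore_mask accumulates: with gold token
# ids and ignore_indices [0, 2], positions holding 0 or 2 end up False.
# Computed here directly in plain PyTorch.
import torch

gold = torch.tensor([3, 0, 2, 5])
mask = (gold != 0) & (gold != 2)   # what the loop above accumulates
print(mask)   # tensor([ True, False, False,  True])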
def forward(self, model_outs, gold, **kw):
    if q.issequence(model_outs):
        x = model_outs[self.which]
    else:
        assert (self.which == 0)
        x = model_outs
    if self.reduction in ["elementwise_mean", "mean"]:
        ret = x.mean()
    elif self.reduction == "sum":
        ret = x.sum()
    else:
        ret = x
    return ret
def __call__(self, pred, gold, _numex=None, **kw):
    l = self.loss(pred, gold, **kw)
    if _numex is None:
        _numex = (pred.size(0) if not q.issequence(pred)
                  else pred[0].size(0))
    if isinstance(l, tuple) and len(l) == 2:   # loss returns numex too
        _numex = l[1]
        l = l[0]
    if isinstance(l, torch.Tensor):
        lp = l.item()
    else:
        lp = l
    self.epoch_agg_values.append(lp)
    self.epoch_agg_sizes.append(_numex)
    return l
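# Hedged sketch of the per-epoch aggregation this wrapper enables: each call
# appends a (loss value, batch size) pair, so an epoch average would be the
# size-weighted mean. agg_values/agg_sizes are toy stand-ins for
# self.epoch_agg_values / self.epoch_agg_sizes; the actual reduction lives
# elsewhere in the library.
agg_values = [0.9, 0.7, 0.8]   # per-batch mean losses
agg_sizes = [32, 32, 16]       # examples per batch
epoch_loss = sum(v * n for v, n in zip(agg_values, agg_sizes)) / sum(agg_sizes)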
def forward(self, *x):   # TODO: multiple inputs and outputs
    x = [xe.contiguous() for xe in x]
    x0 = x[0]
    batsize, seqlen = x0.size(0), x0.size(1)
    i = [xe.view(batsize * seqlen, *xe.size()[2:]) for xe in x]
    y = self.block(*i)
    if not q.issequence(y):
        y = (y,)
    yo = []
    for ye in y:
        ye = ye.view(batsize, seqlen, *ye.size()[1:])
        yo.append(ye)
    if len(yo) == 1:
        return yo[0]
    else:
        return tuple(yo)
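# Illustration of the reshape trick above: the time dimension is folded into
# the batch dimension so a per-timestep block can be applied once, then
# unfolded. Toy shapes stand in for the wrapped self.block.
import torch

block = torch.nn.Linear(6, 3)
x = torch.randn(2, 5, 6)                    # (batsize, seqlen, indim)
y = block(x.view(2 * 5, 6)).view(2, 5, 3)   # fold, apply, unfold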
def forward(self, x: State):
    if "mstate" not in x:
        x.mstate = State()
    mstate = x.mstate
    if "ctx" not in mstate:
        # encode input
        inptensor = x.inp_tensor
        mask = inptensor != 0
        inpembs = self.inp_emb(inptensor)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        init_states = []
        for i in range(len(final_encs)):
            init_states.append(self.enc_to_dec[i](final_encs[i][0]))
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask

    emb = self.out_emb(x.prev_actions)

    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state

    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt == True:
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)

    if self.nocopy is True:
        outs = self.out_lin(enc)
    else:
        outs = self.out_lin(enc, x.inp_tensor, scores)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    return outs[0], x
def forward(self, x_t, ctx=None, ctx_mask=None, **kw):
    assert (ctx is not None)
    embs = self.emb(x_t)
    if q.issequence(embs) and len(embs) == 2:
        embs, mask = embs
    if self.feed_att:
        if self._outvec_tm1 is None:
            assert (self.outvec_t0 is not None)   # h_hat_0 must be set when feed_att=True
            self._outvec_tm1 = self.outvec_t0
        core_inp = torch.cat([embs, self._outvec_tm1], 1)
    else:
        core_inp = embs
    prev_pushpop = self.get_pushpop_from(x_t)   # THIS LINE IS ADDED
    core_out = self.core(core_inp)
    alphas, summaries, scores = self.att(
        core_out, ctx, ctx_mask=ctx_mask, values=ctx,
        prev_pushpop=prev_pushpop)   # THIS LINE IS CHANGED
    out_vec = self.merge(core_out, summaries, core_inp)
    out_vec = self.dropout(out_vec)
    self._outvec_tm1 = out_vec   # store outvec

    ret = tuple()
    if self.out is None:
        ret += (out_vec,)
    else:
        _out_vec = self.out(out_vec)
        ret += (_out_vec,)
    if self.return_alphas:
        ret += (alphas,)
    if self.return_scores:
        ret += (scores,)
    if self.return_other:
        ret += (embs, core_out, summaries)
    return ret[0] if len(ret) == 1 else ret
def forward(self, x_t, ctx=None, ctx_mask=None, **kw):
    assert (ctx is not None)
    if isinstance(self.out, q.rnn.AutoMaskedOut):
        self.out.update(x_t)
    embs = self.emb(x_t)   # embed input tokens
    if q.issequence(embs) and len(embs) == 2:   # unpack if necessary
        embs, mask = embs
    if self.feed_att:
        if self._outvec_tm1 is None:
            assert (self.outvec_t0 is not None)   # h_hat_0 must be set when feed_att=True
            self._outvec_tm1 = self.outvec_t0
        core_inp = torch.cat([embs, self._outvec_tm1], 1)   # append previous attention summary
    else:
        core_inp = embs
    prev_pushpop = self.get_pushpop_from(x_t)   # THIS LINE IS ADDED
    core_out = self.core(core_inp, prev_pushpop=prev_pushpop)   # feed through rnn  # THIS LINE IS CHANGED
    alphas, summaries, scores = self.att(core_out, ctx, ctx_mask=ctx_mask,
                                         values=ctx)   # do attention
    out_vec = self.merge(core_out, summaries, core_inp)
    out_vec = self.dropout(out_vec)
    self._outvec_tm1 = out_vec   # store outvec (this is how Luong, 2015 does it)

    ret = tuple()
    if self.out is None:
        ret += (out_vec,)
    else:
        _out_vec = self.out(out_vec)
        ret += (_out_vec,)
    # other returns
    if self.return_alphas:
        ret += (alphas,)
    if self.return_scores:
        ret += (scores,)
    if self.return_other:
        ret += (embs, core_out, summaries)
    return ret[0] if len(ret) == 1 else ret
def forward(self, x: BasicStateBatch):
    if "ctx" not in x.batched_states:
        # encode input
        inptensor = x.batched_states["inp_tensor"]
        mask = inptensor != 0
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_enc = self.inp_enc(inpembs, mask)
        final_enc = final_enc.view(final_enc.size(0), -1).contiguous()
        final_enc = self.enc_to_dec(final_enc)
        x.batched_states["ctx"] = inpenc
        x.batched_states["ctx_mask"] = mask
    ctx = x.batched_states["ctx"]
    ctx_mask = x.batched_states["ctx_mask"]

    emb = self.out_emb(x.batched_states["prev_token"])

    if "rnn" not in x.batched_states:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        x.batched_states["rnn"] = init_rnn_state

    # DONE: concat previous attention summary to emb
    if "prev_summ" not in x.batched_states:
        x.batched_states["prev_summ"] = torch.zeros_like(ctx[:, 0])
    _emb = emb
    if self.feedatt == True:
        _emb = torch.cat([_emb, x.batched_states["prev_summ"]], 1)
    enc = self.out_rnn(_emb, x.batched_states["rnn"])
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    x.batched_states["prev_summ"] = summ
    enc = torch.cat([enc, summ], -1)

    outs = self.out_lin(enc, x, scores)
    outs = (outs,) if not q.issequence(outs) else outs
    return outs[0], x
def forward(self, x_t, ctx=None, ctx_mask=None, **kw):
    if ctx is None:
        ctx, ctx_mask = self._saved_ctx, self._saved_ctx_mask
    assert (ctx is not None)
    if self.out is not None and hasattr(self.out, "update"):
        self.out.update(x_t)   # update output layer with current input
    embs = self.emb(x_t)   # embed input tokens
    if q.issequence(embs) and len(embs) == 2:   # unpack if necessary
        embs, mask = embs
    if self.feed_att:
        if self._outvec_tm1 is None:
            assert (self.outvec_t0 is not None)   # h_hat_0 must be set when feed_att=True
            self._outvec_tm1 = self.outvec_t0
        core_inp = torch.cat([embs, self._outvec_tm1], 1)   # append previous attention summary
    else:
        core_inp = embs
    core_out = self.core(core_inp)   # feed through rnn
    alphas, summaries, scores = self.att(core_out, ctx, ctx_mask=ctx_mask,
                                         values=ctx)   # do attention
    out_vec = self.merge(core_out, summaries, core_inp)
    out_vec = self.dropout(out_vec)
    self._outvec_tm1 = out_vec   # store outvec (this is how Luong, 2015 does it)

    if self.out is None:
        ret_normal = out_vec
    else:
        if isinstance(self.out, PointerGeneratorOut):
            _out_vec = self.out(out_vec, scores=scores)
        else:
            _out_vec = self.out(out_vec)
        ret_normal = _out_vec
    l = locals()
    ret = tuple([l[k] for k in sum(self.returns, [])])
    return ret[0] if len(ret) == 1 else ret
def train_batch(batch=None,
                model=None,
                optim=None,
                losses=None,
                device=torch.device("cpu"),
                batch_number=-1,
                max_batches=0,
                current_epoch=0,
                max_epochs=0,
                on_start=tuple(),
                on_before_optim_step=tuple(),
                on_after_optim_step=tuple(),
                on_end=tuple()):
    """
    Runs a single batch of SGD on the provided batch and settings.
    :param batch: batch to run on
    :param model: torch.nn.Module of the model
    :param optim: torch optimizer
    :param losses: list of loss wrappers
    :param device: device
    :param batch_number: which batch
    :param max_batches: total number of batches
    :param current_epoch: current epoch
    :param max_epochs: total number of epochs
    :param on_start: collection of functions to call when starting the training batch
    :param on_before_optim_step: collection of functions to call before the optimization step is taken (e.g. gradient clipping)
    :param on_after_optim_step: collection of functions to call after the optimization step is taken
    :param on_end: collection of functions to call when the batch is done
    :return:
    """
    [e() for e in on_start]
    optim.zero_grad()
    model.train()

    batch = (batch,) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
    numex = batch[0].size(0)

    if q.no_gold(losses):
        batch_in = batch
        gold = None
    else:
        batch_in = batch[:-1]
        gold = batch[-1]

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, gold, _numex=numex)
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)

    cost = trainlosses[0]

    # penalties
    penalties = 0
    for loss_obj, trainloss in zip(losses, trainlosses):
        if isinstance(loss_obj.loss, q.loss.PenaltyGetter):
            penalties += trainloss
    cost = cost + penalties

    if torch.isnan(cost).any():
        print("Cost is NaN!")
        embed()

    cost.backward()

    [e() for e in on_before_optim_step]
    optim.step()
    [e() for e in on_after_optim_step]

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )

    [e() for e in on_end]
    return ttmsg
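# Hedged example of an on_before_optim_step hook: the docstring above mentions
# gradient clipping, and a zero-argument closure like the one below (using the
# standard torch.nn.utils.clip_grad_norm_) would fit that slot.
# make_gradclip_hook is a hypothetical helper, not part of the original code.
import torch

def make_gradclip_hook(model, max_norm=5.0):
    def hook():
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    return hook

# usage: train_batch(..., on_before_optim_step=[make_gradclip_hook(model)])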
def train_batch_distill(batch=None,
                        model=None,
                        optim=None,
                        losses=None,
                        device=torch.device("cpu"),
                        batch_number=-1,
                        max_batches=0,
                        current_epoch=0,
                        max_epochs=0,
                        on_start=tuple(),
                        on_before_optim_step=tuple(),
                        on_after_optim_step=tuple(),
                        on_end=tuple(),
                        run=False,
                        mbase=None,
                        goldgetter=None):
    """
    Runs a single batch of SGD with distillation on the provided batch and settings.
    :param batch: batch to run on
    :param model: torch.nn.Module of the student model
    :param optim: torch optimizer
    :param losses: list of loss wrappers
    :param device: device
    :param batch_number: which batch
    :param max_batches: total number of batches
    :param current_epoch: current epoch
    :param max_epochs: total number of epochs
    :param on_start: collection of functions to call when starting the training batch
    :param on_before_optim_step: collection of functions to call before the optimization step is taken (e.g. gradient clipping)
    :param on_after_optim_step: collection of functions to call after the optimization step is taken
    :param on_end: collection of functions to call when the batch is done
    :param mbase: base (teacher) model to distill from; takes inputs and produces output distributions for the student model to match. Not used if goldgetter is specified.
    :param goldgetter: takes the gold and produces a soft gold
    :return:
    """
    # if run is False:
    #     kwargs = locals().copy()
    #     return partial(train_batch, **kwargs)
    [e() for e in on_start]
    optim.zero_grad()
    model.train()

    batch = (batch,) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
    batch_in = batch[:-1]
    gold = batch[-1]

    # run batch_in through the teacher model to get teacher output distributions
    if goldgetter is not None:
        softgold = goldgetter(gold)
    elif mbase is not None:
        mbase.eval()
        q.batch_reset(mbase)
        with torch.no_grad():
            softgold = mbase(*batch_in)
    else:
        raise q.SumTingWongException(
            "goldgetter and mbase can not both be None")

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, (softgold, gold))
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)

    cost = trainlosses[0]
    cost.backward()

    [e() for e in on_before_optim_step]
    optim.step()
    [e() for e in on_after_optim_step]

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )

    [e() for e in on_end]
    return ttmsg
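# Hedged sketch of a distillation loss over (softgold, gold) as passed to the
# loss wrappers above: a common choice is KL divergence between student
# log-probs and teacher probs, shown here on toy logits. This is not the
# original loss code, just one typical instantiation.
import torch
import torch.nn.functional as F

student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)   # stands in for softgold
soft_loss = F.kl_div(F.log_softmax(student_logits, -1),
                     F.softmax(teacher_logits, -1),
                     reduction="batchmean")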
def forward(self, x: State):
    if "mstate" not in x:
        x.mstate = State()
        x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0),
                                             dtype=torch.long,
                                             device=x.inp_tensor.device)
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode input
        inptensor = x.inp_tensor
        mask = inptensor != 0
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):   # iter over layers
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask

        if self.training and q.v(self.beta) < 1:
            # sample one of the orders
            golds = x._gold_tensors
            goldsmask = (golds != 0).any(-1).float()
            numgolds = goldsmask.sum(-1)
            gold_select_prob = (torch.ones_like(goldsmask) * goldsmask
                                / numgolds[:, None])
            selector = gold_select_prob.multinomial(1)[:, 0]
            gold = golds.gather(
                1, selector[:, None, None].repeat(1, 1, golds.size(2)))[:, 0]
            # interpolate with the original gold
            original_gold = x.gold_tensor
            beta_selector = (torch.rand_like(numgolds) <= q.v(self.beta)).long()
            gold_ = (original_gold * beta_selector[:, None]
                     + gold * (1 - beta_selector[:, None]))
            x.gold_tensor = gold_
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask

    emb = self.out_emb(x.prev_actions)

    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state

    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt == True:
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)

    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)

    if self.store_attn:
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0,
                                              alphas.size(1),
                                              device=alphas.device)
        x.stored_attentions = torch.cat(
            [x.stored_attentions, alphas.detach()[:, None, :]], 1)

    mstate.decoding_step = mstate.decoding_step + 1
    return outs[0], x
def forward(self, x: State):
    if "mstate" not in x:
        x.mstate = State()
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode input
        inptensor = x.inp_tensor
        mask = inptensor != 0
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):   # iter over layers
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask

    emb = self.out_emb(x.prev_actions)

    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state

    # ONR stuff: !!! assumes LISP-style queries with parentheses as separate
    # tokens and only parentheses opening and closing clauses
    stack_actions = torch.zeros_like(x.prev_actions)
    stack_actions += (x.prev_actions == self.open_id).long() * +1
    stack_actions += (x.prev_actions == self.close_id).long() * -1

    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt == True:
        _ctx = mstate.prev_summ
    else:
        _ctx = torch.zeros(_emb.size(0), 0, device=_emb.device)
    enc, new_rnnstate = self.out_rnn(_emb, _ctx, stack_actions,
                                     mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)

    if self.nocopy is True:
        outs = self.out_lin(enc)
    else:
        outs = self.out_lin(enc, x.inp_tensor, scores)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)

    if self.store_attn:
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0,
                                              alphas.size(1),
                                              device=alphas.device)
        x.stored_attentions = torch.cat(
            [x.stored_attentions, alphas.detach()[:, None, :]], 1)

    return outs[0], x
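# Worked example of the stack_actions computation above: with hypothetical
# ids open_id=1 and close_id=2, opening parens map to +1, closing parens to
# -1, and everything else to 0.
import torch

prev_actions = torch.tensor([[1, 5, 2, 1]])   # "(", token, ")", "("
open_id, close_id = 1, 2
stack_actions = ((prev_actions == open_id).long()
                 - (prev_actions == close_id).long())
print(stack_actions)   # tensor([[ 1,  0, -1,  1]])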
def forward(self, x: State):
    if "mstate" not in x:
        x.mstate = State()
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode input
        inptensor = x.inp_tensor
        mask = inptensor != 0
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):   # iter over layers
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask

    emb = self.out_emb(x.prev_actions)

    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state

    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt == True:
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate

    if "prevstates" not in mstate:
        _ctx = ctx
        _ctx_mask = ctx_mask
        mstate.prevstates = enc[:, None, :]
    else:
        _ctx = torch.cat([ctx, mstate.prevstates], 1)
        _ctx_mask = torch.cat([
            ctx_mask,
            torch.ones(mstate.prevstates.size(0),
                       mstate.prevstates.size(1),
                       dtype=ctx_mask.dtype,
                       device=ctx_mask.device)
        ], 1)
        mstate.prevstates = torch.cat([mstate.prevstates, enc[:, None, :]], 1)
    alphas, summ, scores = self.att(enc, _ctx, _ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)

    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)

    if self.store_attn:
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0,
                                              alphas.size(1),
                                              device=alphas.device)
        atts = q.pad_tensors(
            [x.stored_attentions, alphas.detach()[:, None, :]], 2, 0)
        x.stored_attentions = torch.cat(atts, 1)

    return outs[0], x
def test_epoch(model=None,
               dataloader=None,
               losses=None,
               device=torch.device("cpu"),
               current_epoch=0,
               max_epochs=0,
               print_every_batch=False,
               on_start=tuple(),
               on_start_batch=tuple(),
               on_end_batch=tuple(),
               on_end=tuple()):
    """
    Performs a single test (evaluation) epoch over the given dataloader.
    :param model:
    :param dataloader:
    :param losses:
    :param device:
    :param current_epoch:
    :param max_epochs:
    :param on_start:
    :param on_start_batch:
    :param on_end_batch:
    :param on_end:
    :return:
    """
    tt = q.ticktock("-")
    model.eval()
    q.epoch_reset(model)
    [e() for e in on_start]
    with torch.no_grad():
        for loss_obj in losses:
            loss_obj.push_epoch_to_history()
            loss_obj.reset_agg()
            loss_obj.loss.to(device)
        for i, _batch in enumerate(dataloader):
            [e() for e in on_start_batch]

            _batch = (_batch,) if not q.issequence(_batch) else _batch
            _batch = q.recmap(
                _batch,
                lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
            batch = _batch
            numex = batch[0].size(0)

            if q.no_gold(losses):
                batch_in = batch
                gold = None
            else:
                batch_in = batch[:-1]
                gold = batch[-1]

            q.batch_reset(model)
            modelouts = model(*batch_in)

            testlosses = []
            for loss_obj in losses:
                loss_val = loss_obj(modelouts, gold, _numex=numex)
                loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
                testlosses.extend(loss_val)

            ttmsg = "test - Epoch {}/{} - [{}/{}]: {}".format(
                current_epoch + 1, max_epochs, i + 1, len(dataloader),
                q.pp_epoch_losses(*losses))
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)
            [e() for e in on_end_batch]
    tt.stoplive()
    [e() for e in on_end]
    ttmsg = q.pp_epoch_losses(*losses)
    return ttmsg
def forward(self, x: State):
    if "mstate" not in x:
        x.mstate = State()
        x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0),
                                             dtype=torch.long,
                                             device=x.inp_tensor.device)
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode input
        inptensor = x.inp_tensor
        mask = inptensor != 0
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):   # iter over layers
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask

    if "outenc" not in mstate:
        if self.training:
            outtensor = x.gold_tensor
            omask = outtensor != 0
            outembs = self.out_emb_vae(outtensor)
            finalenc, _ = self.out_enc(outembs, omask)
            finalenc, _ = (finalenc
                           + torch.log(omask.float()[:, :, None])).max(1)   # max pool
            # reparameterize
            mu = self.out_mu(finalenc)
            logvar = self.out_logvar(finalenc)
            std = torch.exp(.5 * logvar)
            eps = torch.randn_like(std)
            outenc = mu + eps * std
            mstate.outenc = outenc
            kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
            kld = torch.sum(kld.clamp_min(self.minkl), -1)
            mstate.kld = kld
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask

    emb = self.out_emb(x.prev_actions)

    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state

    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        mstate.prev_summ = final_encs[-1][0]

    if self.training:
        outenc = mstate.outenc
        # outenc = outenc.gather(1, mstate.decoding_step[:, None, None].repeat(1, 1, outenc.size(2)))[:, 0]
    else:
        outenc = torch.randn(emb.size(0), self.zdim, device=emb.device)
    _emb = torch.cat([emb, outenc], 1)
    if self.feedatt == True:
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)

    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)

    if self.store_attn:
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0,
                                              alphas.size(1),
                                              device=alphas.device)
        x.stored_attentions = torch.cat(
            [x.stored_attentions, alphas.detach()[:, None, :]], 1)

    mstate.decoding_step = mstate.decoding_step + 1
    return outs[0], x
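# Illustration of the reparameterization step above: z = mu + eps * std with
# eps ~ N(0, I), plus the clamped KL term the forward accumulates. Toy
# tensors replace the learned out_mu/out_logvar heads, and 0.0 stands in for
# self.minkl.
import torch

mu = torch.randn(4, 16)
logvar = torch.randn(4, 16)
std = torch.exp(0.5 * logvar)
z = mu + torch.randn_like(std) * std                   # sampled latent
kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())   # per-dimension KL
kld = kld.clamp_min(0.0).sum(-1)                       # clamped and summed, as above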