def loss_on_batch(self, annotated_insts: List, loss_factor=1., training=True, **kwargs):
    self.refresh_batch(training)
    # --
    sents: List[Sent] = list(yield_sents(annotated_insts))
    # ==
    # extend to events
    import numpy as np
    bsize = sum(len(z.events) for z in sents)  # one row per event
    mlen = max(len(z) for z in sents)  # max sentence length in the batch
    arr_preds = np.full([bsize, mlen], 0., dtype=np.int32)
    arr_inputs = np.full([bsize, mlen], b'<pad>', dtype=object)
    arr_labels = np.full([bsize, mlen], b'<pad>', dtype=object)
    ii = 0
    for sent in sents:
        for evt in sent.events:
            widx, wlen = evt.mention.get_span()
            assert wlen == 1  # single-token trigger expected
            # --
            arr_preds[ii, widx] = 1  # mark the predicate/trigger position
            arr_inputs[ii, :len(sent)] = [s.lower().encode() for s in sent.seq_word.vals]
            # --
            # build BIO role labels for this event's arguments
            tmp_labels = ["O"] * len(sent)
            for arg in evt.args:
                role = arg.role
                a_widx, a_wlen = arg.arg.mention.get_span()
                a_labs = ["B-" + role] + ["I-" + role] * (a_wlen - 1)
                assert all(z == "O" for z in tmp_labels[a_widx:a_widx + a_wlen])  # no overlapping args
                tmp_labels[a_widx:a_widx + a_wlen] = a_labs
            # --
            arr_labels[ii, :len(sent)] = [z.encode() for z in tmp_labels]
            # --
            ii += 1
    assert ii == bsize
    features, labels = data.lookup(
        ({"preds": NpWarapper(arr_preds), "inputs": NpWarapper(arr_inputs)}, NpWarapper(arr_labels)),
        "train", self.params)
    # ==
    final_loss = self.M(features, labels)
    info = {"inst": len(annotated_insts), "sent": len(sents), "fb": 1, "loss": final_loss.item()}
    if training:
        assert final_loss.requires_grad
        BK.backward(final_loss, loss_factor)
    zlog(f"batch shape = {len(annotated_insts)} {bsize} {mlen} {bsize*mlen}")
    return info
def loss_on_batch(self, ibatch: InputBatch, loss_factor=1., training=True, **kwargs):
    self.refresh_batch(training)
    self._mark_active(ibatch)
    # --
    if self.ddp is None:
        final_loss, info = self.forward(ibatch)
    else:
        final_loss, info = self.ddp.forward(ibatch)
    # --
    if training:
        # if BK.get_value(final_loss).item() >= 0:  # note: loss should be >=0 usually!!
        if final_loss.requires_grad:  # note: only backward if requires_grad
            BK.backward(final_loss, loss_factor)
        else:  # no need to backward if there is no loss
            # under DDP, every rank must run backward to keep gradient sync in step
            assert self.ddp is None, "Cannot bypass backward in DDP mode!"
            info["fb0"] = 1
    return info
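For context, a minimal sketch of the distributed setup this variant presumably assumes: `self.ddp` looks like a wrapper along the lines of torch's DistributedDataParallel, and the assert above reflects that under DDP every rank must run backward each step so the gradient all-reduce stays synchronized. The helper below is illustrative only and not part of this codebase.

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(module: nn.Module, device_id: int) -> nn.Module:
    # assumes torch.distributed.init_process_group(...) has already been called
    # and that `device_id` is this rank's local GPU index
    return DDP(module.to(device_id), device_ids=[device_id])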
def loss_on_batch(self, insts: List, loss_factor=1., training=True, force_lidx=None, **kwargs):
    conf: ZmtlModelConf = self.conf
    self.refresh_batch(training)
    # --
    # import torch
    # torch.autograd.set_detect_anomaly(True)
    # --
    actual_insts = list(self._yield_insts(insts))
    med = self.med
    enc_cached_input = self.enc.prepare_inputs(actual_insts)
    # ==
    # if needed, forward other models (can be self) to collect augmented scores
    aug_scores = {}
    with BK.no_grad_env():
        if conf.aug_times >= 1:  # forward all at once!!
            _mm_input = enc_cached_input if (conf.aug_times == 1) else self.enc.prepare_inputs(actual_insts * conf.aug_times)
            for mm in self.aug_models:  # add them all to aug_scores!!
                mm.enc_forward(_mm_input, aug_scores, conf.aug_training_flag)
    # ==
    self.refresh_batch(training)
    med.force_lidx = force_lidx  # note: special assign
    # enc
    self.enc.forward(None, med, cached_input=enc_cached_input)
    # dec
    med.aug_scores = aug_scores  # note: assign here!!
    all_losses = med.do_losses()
    # --
    # final loss and backward
    info = {"inst0": len(insts), "inst": len(actual_insts), "fb": 1, "fb0": 0}
    final_loss, loss_info = self.collect_loss(all_losses, ret_dict=(self.pcgrad is not None))
    info.update(loss_info)
    if training:
        if self.pcgrad is not None:
            # self.pcgrad.do_backward(self.parameters(), final_loss, loss_factor)
            # note: only the enc's params get special treatment; for the others, grads are always accumulated!
            self.pcgrad.do_backward(self.enc.parameters(), final_loss, loss_factor)
        else:  # as usual
            # assert final_loss.requires_grad
            if BK.get_value(final_loss).item() > 0:  # note: loss should be >0 usually!!
                BK.backward(final_loss, loss_factor)
            else:  # no need to backward if there is no loss
                info["fb0"] = 1
    med.restart()  # clean!
    med.force_lidx = None  # clear!
    return info
def loss_on_batch(self, annotated_insts: List, loss_factor=1., training=True, **kwargs):
    self.refresh_batch(training)
    # --
    sents: List[Sent] = list(yield_sents(annotated_insts))
    # emb and enc
    mask_expr, input_expr, enc_expr = self._emb_and_enc(sents)
    # frame
    f_loss = self.framer.loss(sents, enc_expr, mask_expr)
    # --
    # final loss and backward
    info = {"inst": len(annotated_insts), "sent": len(sents), "fb": 1}
    final_loss, loss_info = self.collect_loss([f_loss])
    info.update(loss_info)
    if training:
        assert final_loss.requires_grad
        BK.backward(final_loss, loss_factor)
    return info
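A minimal usage sketch of how these loss_on_batch variants might be driven from a training loop; the driver, optimizer, and batch iterator names below are hypothetical and only illustrate that backward runs inside the call (via BK.backward) while the parameter update stays outside.

def train_one_epoch(model, batches, optimizer, loss_factor=1.):
    # hypothetical driver: `model` exposes loss_on_batch as above, and
    # `optimizer` is any torch.optim optimizer over model.parameters()
    stats = []
    for batch in batches:
        optimizer.zero_grad()
        # gradients are produced inside loss_on_batch, so only the update remains here
        info = model.loss_on_batch(batch, loss_factor=loss_factor, training=True)
        optimizer.step()
        stats.append(info)  # bookkeeping: instance counts, loss values, fb flags
    return stats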