Example #1
# Requires `from typing import List`; `ParseInstance` and the tensor-backend
# wrapper `BK` come from the surrounding project.
def fb_on_batch(self,
                annotated_insts: List[ParseInstance],
                training=True,
                loss_factor=1.,
                **kwargs):
    self.refresh_batch(training)
    # encode: contextual representations plus a padding mask
    input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(annotated_insts, training)
    mask_expr = BK.input_real(mask_arr)
    # the parsing loss: first-order arc and label scores are summed
    arc_score = self.scorer_helper.score_arc(enc_repr)
    lab_score = self.scorer_helper.score_label(enc_repr)
    full_score = arc_score + lab_score
    parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr)
    # optional auxiliary losses: joint POS tagging and score regularization
    jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
    reg_loss = self.reg_scores_loss(arc_score, lab_score)
    # accumulate the total loss and record scalar values for logging
    info["loss_parse"] = BK.get_value(parsing_loss).item()
    final_loss = parsing_loss
    if jpos_loss is not None:
        info["loss_jpos"] = BK.get_value(jpos_loss).item()
        final_loss = parsing_loss + jpos_loss
    if reg_loss is not None:
        final_loss = final_loss + reg_loss
    info["fb"] = 1
    if training:
        BK.backward(final_loss, loss_factor)
    return info
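
Example #1 bundles the forward and backward passes into one call: it builds a
single scalar loss, records diagnostics in an `info` dict, and calls
`BK.backward` only when training, with `loss_factor` rescaling the gradients.
A minimal sketch of how such a method might be driven from a training loop;
`parser`, `optimizer`, `batches`, and `accum_steps` are hypothetical
stand-ins, not names from the source:

# Hypothetical training loop; only `fb_on_batch` comes from the example above.
def train_epoch(parser, optimizer, batches, accum_steps=2):
    for step, insts in enumerate(batches):
        # fb_on_batch runs forward + backward internally when training=True;
        # loss_factor rescales gradients, here to average over accumulation steps.
        info = parser.fb_on_batch(insts, training=True,
                                  loss_factor=1.0 / accum_steps)
        if (step + 1) % accum_steps == 0:
            optimizer.step()       # apply the accumulated gradients
            optimizer.zero_grad()  # clear them for the next window
        print(f"step {step}: parse loss = {info['loss_parse']:.4f}")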
Example #2
# Requires the same imports as Example #1 (List, ParseInstance, BK).
def fb_on_batch(self,
                annotated_insts: List[ParseInstance],
                training=True,
                loss_factor=1.,
                **kwargs):
    self.refresh_batch(training)
    # todo(note): here always using training lambdas
    full_score, original_scores, jpos_pack, mask_expr, valid_mask_d, _ = \
        self._score(annotated_insts, False,
                    self.lambda_g1_arc_training, self.lambda_g1_lab_training)
    parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr,
                                    valid_mask_d)
    # optional auxiliary losses: joint POS tagging and score regularization
    jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
    reg_loss = self.reg_scores_loss(*original_scores)
    # accumulate the total loss and record scalar values for logging
    info["loss_parse"] = BK.get_value(parsing_loss).item()
    final_loss = parsing_loss
    if jpos_loss is not None:
        info["loss_jpos"] = BK.get_value(jpos_loss).item()
        final_loss = parsing_loss + jpos_loss
    if reg_loss is not None:
        final_loss = final_loss + reg_loss
    info["fb"] = 1
    if training:
        BK.backward(final_loss, loss_factor)
    return info
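
Example #2 differs from #1 only in the scoring step: instead of scoring arcs
and labels inline, it delegates to `self._score(...)` with the training-time
interpolation lambdas and threads an extra `valid_mask_d` into the loss; the
tail (auxiliary losses, logging, backward) is identical. A plausible reading
of the `BK.backward(final_loss, loss_factor)` call, assuming `BK` is a thin
wrapper over PyTorch (a sketch under that assumption, not the project's
actual implementation):

import torch

def backward(loss: torch.Tensor, loss_factor: float = 1.0) -> None:
    # Scale before backprop so every gradient is multiplied by loss_factor,
    # e.g. to average over gradient-accumulation steps.
    if loss_factor != 1.0:
        loss = loss * loss_factor
    loss.backward()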
Example #3
# Requires `from collections import OrderedDict` and `from typing import
# Dict, List`; `LossHelper` and the tensor-backend wrapper `BK` come from
# the surrounding project.
def collect_loss_and_backward(self, loss_info_cols: List[Dict],
                              training: bool, loss_factor: float):
    # merge the per-component loss dicts: loss_name -> {"sum": ..., "count": ...}
    final_loss_dict = LossHelper.combine_multiple(loss_info_cols)
    if len(final_loss_dict) <= 0:
        return {}  # no loss!
    final_losses = []
    ret_info_vals = OrderedDict()
    for loss_name, loss_info in final_loss_dict.items():
        # normalize each loss by its count; 1e-5 guards against a zero count
        final_losses.append(loss_info['sum'] / (loss_info['count'] + 1e-5))
        # record every field as a plain float for logging
        for k, one_item in loss_info.items():
            ret_info_vals[f"loss:{loss_name}_{k}"] = \
                one_item.item() if hasattr(one_item, "item") else float(one_item)
    final_loss = BK.stack(final_losses).sum()
    if training and final_loss.requires_grad:
        BK.backward(final_loss, loss_factor)
    return ret_info_vals
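
Example #3 generalizes the pattern to many named losses: each entry of
`loss_info_cols` maps a loss name to running "sum" and "count" values,
`LossHelper.combine_multiple` merges them across entries, each merged loss is
normalized by its count (the `1e-5` guards against a zero count), and the
normalized losses are stacked and summed into one scalar before the backward
call. A self-contained sketch with a mock `combine_multiple`; the real
`LossHelper` is project-specific, and this mock only merges by key and
accumulates the "sum"/"count" fields:

import torch
from typing import Dict, List

def combine_multiple(cols: List[Dict]) -> Dict[str, Dict]:
    # Mock of LossHelper.combine_multiple: merge per-component loss dicts,
    # accumulating "sum" and "count" per loss name.
    out: Dict[str, Dict] = {}
    for col in cols:
        for name, info in col.items():
            slot = out.setdefault(name, {"sum": 0., "count": 0.})
            slot["sum"] = slot["sum"] + info["sum"]
            slot["count"] = slot["count"] + info["count"]
    return out

# Two batches, each contributing a token-summed arc loss and a token count.
cols = [
    {"arc": {"sum": torch.tensor(12.0, requires_grad=True), "count": 8}},
    {"arc": {"sum": torch.tensor(9.0, requires_grad=True), "count": 6}},
]
merged = combine_multiple(cols)
mean_losses = [v["sum"] / (v["count"] + 1e-5) for v in merged.values()]
final_loss = torch.stack(mean_losses).sum()
final_loss.backward()  # what BK.backward does, modulo loss_factor scaling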