def run_train_step(self, batch, train_global_state):
    """Run one supervised training step.

    Forward the batch, compute the task loss, backpropagate, take an
    (accumulation-aware) optimizer step, and log the step's metrics.

    Args:
        batch: task batch; moved onto ``self.device`` before the forward pass.
        train_global_state: mutable step/epoch counters, advanced by
            ``optim_step_grad_accum``.
    """
    self.model.train()
    device_batch = batch.to(self.device)
    task_type = self.task.TASK_TYPE  # hoisted: used by both delegate calls below

    # forward_batch_delegate returns a tuple; index 0 is the logits
    model_logits = forward_batch_delegate(
        model=self.model,
        batch=device_batch,
        omit_label_ids=True,
        task_type=task_type,
    )[0]
    task_loss = compute_loss_from_model_output(
        logits=model_logits,
        loss_criterion=self.loss_criterion,
        batch=device_batch,
        task_type=task_type,
    )

    # complex_backpropagate runs backward() and returns the (possibly rescaled) loss
    backprop_loss = self.complex_backpropagate(task_loss)
    optim_step_grad_accum(
        optimizer_scheduler=self.optimizer_scheduler,
        train_global_state=train_global_state,
        gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
    )

    log_entry = {
        "epoch": train_global_state.epoch,
        "epoch_step": train_global_state.epoch_step,
        "global_step": train_global_state.global_step,
        "loss_val": backprop_loss.item(),
        "pred_entropy": compute_pred_entropy_clean(model_logits),
    }
    self.log_writer.write_entry("loss_train", log_entry)
def run_train_step(self, batch_m_triplet, train_global_state: TrainGlobalState):
    """Run one LLP+UDA training step.

    Combines the LLP loss on the supervised batch with the UDA consistency
    loss on the unsupervised (orig, aug) pair, backpropagates, steps the
    optimizer, then EMA-updates the LLP memory bank rows for the supervised
    examples and logs the step.

    Args:
        batch_m_triplet: triplet with ``.sup``, ``.unsup_orig``, ``.unsup_aug``
            batch-with-metadata members.
        train_global_state: mutable step/epoch counters, advanced by
            ``optim_step_grad_accum``.
    """
    # Consistency fix: the sibling run_train_step implementations all switch
    # the model to train mode before the forward passes; this one previously
    # relied on the caller having done so (no-op if already in train mode).
    self.model.train()

    llp_loss = self.compute_llp_loss(
        sup_batch_m=batch_m_triplet.sup,
    )
    uda_loss = self.compute_uda_loss(
        unsup_orig_batch_m=batch_m_triplet.unsup_orig,
        unsup_aug_batch_m=batch_m_triplet.unsup_aug,
    )
    loss = llp_loss + self.llpuda_params.uda_coeff * uda_loss

    # complex_backpropagate runs backward() and returns the (possibly rescaled) loss
    loss = self.complex_backpropagate(loss)
    loss_val = loss.item()
    optim_step_grad_accum(
        optimizer_scheduler=self.optimizer_scheduler,
        train_global_state=train_global_state,
        gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
    )

    # Update memory bank: exponential moving average of the supervised
    # examples' fresh embeddings, keyed by example_id.
    with torch.no_grad():
        new_embedding = self.model.forward_batch(batch_m_triplet.sup.batch).embedding
        example_ids = batch_m_triplet.sup.metadata["example_id"]  # hoisted: indexed twice
        ema_t = self.llp_params.llp_mem_bank_t
        self.llp_state.big_m_tensor[example_ids] = (
            (1 - ema_t) * self.llp_state.big_m_tensor[example_ids]
            + ema_t * new_embedding
        )

    self.log_writer.write_entry("loss_train", {
        "epoch": train_global_state.epoch,
        "epoch_step": train_global_state.epoch_step,
        "global_step": train_global_state.global_step,
        "loss_val": loss_val,
    })
def run_train_step(self, batch_triplet, train_global_state: TrainGlobalState):
    """Run one UDA training step.

    Computes the supervised loss, optionally adds the weighted unsupervised
    consistency loss (orig vs. augmented batch), backpropagates, steps the
    optimizer, and logs per-component losses and prediction entropies.

    Args:
        batch_triplet: triplet with ``.sup``, ``.unsup_orig``, ``.unsup_aug``
            members; each is indexed with ``[0]`` to get the batch proper.
        train_global_state: mutable step/epoch counters, advanced by
            ``optim_step_grad_accum``.
    """
    self.model.train()
    # Dead-code removal: the original accumulated an `example_count`
    # (len(batch_triplet.sup), plus len(batch_triplet.unsup_orig[0]) in the
    # unsup branch) that was never read anywhere in this method.
    sup_loss, sup_logits = uda_ops.sup_train_step(
        model=self.model,
        sup_batch=batch_triplet.sup[0].to(self.device),
        task=self.task,
        global_step=train_global_state.global_step,
        train_schedule=self.train_schedule,
        uda_params=self.uda_params,
        zlogger=self.log_writer,
    )
    if self.uda_params.use_unsup:
        unsup_loss, unsup_orig_logits, unsup_aug_logits = uda_ops.unsup_train_step(
            model=self.model,
            unsup_orig_batch=batch_triplet.unsup_orig[0].to(self.device),
            unsup_aug_batch=batch_triplet.unsup_aug[0].to(self.device),
            uda_params=self.uda_params,
            zlogger=self.log_writer,
        )
        weighted_unsup_loss = self.uda_params.uda_coeff * unsup_loss
        loss = sup_loss + weighted_unsup_loss
    else:
        unsup_orig_logits, unsup_aug_logits = None, None
        weighted_unsup_loss = 0  # get_val below is assumed to accept plain numbers
        loss = sup_loss

    # complex_backpropagate runs backward() and returns the (possibly rescaled) loss
    loss = self.complex_backpropagate(loss)
    optim_step_grad_accum(
        optimizer_scheduler=self.optimizer_scheduler,
        train_global_state=train_global_state,
        gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
    )

    log_data = {
        "epoch": train_global_state.epoch,
        "epoch_step": train_global_state.epoch_step,
        "global_step": train_global_state.global_step,
        "sup": get_val(sup_loss),
        "unsup": get_val(weighted_unsup_loss),
        "total": get_val(loss),
        "sup_pred_entropy": compute_pred_entropy_clean(sup_logits),
    }
    if self.uda_params.use_unsup:
        log_data["unsup_orig_pred_entropy"] = compute_pred_entropy_clean(
            unsup_orig_logits)
        log_data["unsup_aug_pred_entropy"] = compute_pred_entropy_clean(
            unsup_aug_logits)
    self.log_writer.write_entry("loss_train", log_data)
    self.log_writer.flush()
def run_train_step(self, batch, batch_metadata, train_global_state: TrainGlobalState):
    """Run one LLP representation-learning training step.

    Computes the representation loss, backpropagates, takes an optimizer
    step, EMA-updates the memory bank rows for this batch's examples, and
    logs the loss components.

    Args:
        batch: task batch; moved onto ``self.device`` before use.
        batch_metadata: dict-like; ``"example_id"`` indexes the memory bank.
        train_global_state: mutable step/epoch counters, advanced by
            ``optim_step_grad_accum``.

    Returns:
        The per-component ``loss_details`` mapping from
        ``compute_representation_loss``.
    """
    self.model.train()
    device_batch = batch.to(self.device)
    rep_loss, loss_details, model_output = self.compute_representation_loss(
        device_batch, batch_metadata)

    # complex_backpropagate runs backward() and returns the (possibly rescaled) loss
    rep_loss = self.complex_backpropagate(rep_loss)
    loss_val = rep_loss.item()
    optim_step_grad_accum(
        optimizer_scheduler=self.optimizer_scheduler,
        train_global_state=train_global_state,
        gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
    )

    # Update memory bank: EMA of a fresh (gradient-free) embedding pass,
    # keyed by example_id.
    with torch.no_grad():
        fresh_embedding = self.model.forward_batch(device_batch).embedding
        example_ids = batch_metadata["example_id"]
        ema_t = self.llp_params.llp_mem_bank_t
        self.llp_state.big_m_tensor[example_ids] = (
            (1 - ema_t) * self.llp_state.big_m_tensor[example_ids]
            + ema_t * fresh_embedding)

    loss_details_logged = {
        name: detail.mean().item()
        for name, detail in loss_details.items()
    }
    base_entry = {
        "epoch": train_global_state.epoch,
        "epoch_step": train_global_state.epoch_step,
        "global_step": train_global_state.global_step,
        "loss_val": loss_val,
        "pred_entropy": torch_utils.compute_pred_entropy_clean(model_output.logits),
    }
    self.log_writer.write_entry(
        "loss_train", combine_dicts([base_entry, loss_details_logged]))
    self.log_writer.flush()
    return loss_details
def run_train_step(self, batch_duplet: TrainDataDuplet, train_global_state: TrainGlobalState):
    """Run one Mean-Teacher training step.

    Computes the student's classification loss on the supervised batch plus a
    ramped consistency loss between student and teacher logits (optionally
    over the unsupervised batch as well), backpropagates, steps the
    optimizer, EMA-updates the teacher from the student, and logs the step.

    Args:
        batch_duplet: duplet with ``.sup`` and ``.unsup`` batch-with-metadata
            members (each exposes ``.batch`` and ``.to(device)``).
        train_global_state: mutable step/epoch counters, advanced by
            ``optim_step_grad_accum``.
    """
    self.model.train()
    self.teacher_model_wrapper.model.train()
    sup_batch = batch_duplet.sup.to(self.device)
    # Classification [SUP]: student forward pass (delegate returns a tuple;
    # index 0 is the logits).
    sup_logits = forward_batch_delegate(
        model=self.model,
        batch=sup_batch.batch,
        omit_label_ids=True,
        task_type=self.task.TASK_TYPE,
    )[0]
    classification_loss = compute_loss_from_model_output(
        logits=sup_logits,
        loss_criterion=self.loss_criterion,
        batch=sup_batch.batch,
        task_type=self.task.TASK_TYPE,
    )
    # Consistency: teacher forward on the supervised batch, gradient-free.
    with torch.no_grad():
        teacher_sup_logits = forward_batch_delegate(
            model=self.teacher_model_wrapper.model,
            batch=sup_batch.batch,
            omit_label_ids=True,
            task_type=self.task.TASK_TYPE,
        )[0]
    # Consistency: optionally extend the student/teacher logit pairs with the
    # unsupervised batch before computing the consistency loss.
    if self.mt_params.use_unsup:
        unsup_batch = batch_duplet.unsup.to(self.device)
        unsup_logits = forward_batch_delegate(
            model=self.model,
            batch=unsup_batch.batch,
            omit_label_ids=True,
            task_type=self.task.TASK_TYPE,
        )[0]
        # NOTE(review): as sequenced in the source, this teacher forward is
        # NOT wrapped in torch.no_grad() (unlike the sup-batch teacher pass
        # above) — presumably compute_raw_consistency_loss detaches teacher
        # logits; confirm, or wrap this call in no_grad as well.
        teacher_unsup_logits = forward_batch_delegate(
            model=self.teacher_model_wrapper.model,
            batch=unsup_batch.batch,
            omit_label_ids=True,
            task_type=self.task.TASK_TYPE,
        )[0]
        student_logits = torch.cat([sup_logits, unsup_logits], dim=0)
        teacher_logits = torch.cat([teacher_sup_logits, teacher_unsup_logits], dim=0)
    else:
        student_logits = sup_logits
        teacher_logits = teacher_sup_logits
    raw_consistency_loss = compute_raw_consistency_loss(
        student_logits=student_logits,
        teacher_logits=teacher_logits,
        mt_params=self.mt_params,
    )
    # Consistency weight is ramped up as a function of the global step.
    consistency_weight = get_current_consistency_weight(
        global_step=train_global_state.global_step,
        mt_params=self.mt_params,
    )
    consistency_loss = consistency_weight * raw_consistency_loss
    # Combine
    loss = classification_loss + consistency_loss
    # complex_backpropagate runs backward() and returns the (possibly rescaled) loss.
    loss = self.complex_backpropagate(loss)
    optim_step_grad_accum(
        optimizer_scheduler=self.optimizer_scheduler,
        train_global_state=train_global_state,
        gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
    )
    # EMA-update the teacher's weights from the student after the optimizer step.
    update_teacher(
        student_wrapper=self.model_wrapper,
        teacher_wrapper=self.teacher_model_wrapper,
        alpha=self.mt_params.alpha,
        global_step=train_global_state.global_step,
    )
    self.log_writer.write_entry("loss_train", {
        "epoch": train_global_state.epoch,
        "epoch_step": train_global_state.epoch_step,
        "global_step": train_global_state.global_step,
        "classification_loss": classification_loss.item(),
        "consistency_loss": consistency_loss.item(),
        "total_loss": loss.item(),
        "pred_entropy": compute_pred_entropy_clean(sup_logits)
    })