Example #1
0
    def run_train_step(self, batch, train_global_state):
        """Run one supervised training step.

        Forward pass, loss computation, backpropagation, an optimizer step
        (honoring gradient accumulation), and a log entry for the step.
        """
        self.model.train()
        batch = batch.to(self.device)

        # Forward without label ids; the delegate dispatches on the task type.
        model_outputs = forward_batch_delegate(
            model=self.model,
            batch=batch,
            omit_label_ids=True,
            task_type=self.task.TASK_TYPE,
        )
        logits = model_outputs[0]
        loss = compute_loss_from_model_output(
            logits=logits,
            loss_criterion=self.loss_criterion,
            batch=batch,
            task_type=self.task.TASK_TYPE,
        )

        # Backprop (may rescale the loss, e.g. for accumulation), then grab the scalar.
        loss = self.complex_backpropagate(loss)
        loss_value = loss.item()

        optim_step_grad_accum(
            optimizer_scheduler=self.optimizer_scheduler,
            train_global_state=train_global_state,
            gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
        )

        log_entry = {
            "epoch": train_global_state.epoch,
            "epoch_step": train_global_state.epoch_step,
            "global_step": train_global_state.global_step,
            "loss_val": loss_value,
            "pred_entropy": compute_pred_entropy_clean(logits),
        }
        self.log_writer.write_entry("loss_train", log_entry)
Example #2
0
    def run_train_step(self, batch_m_triplet, train_global_state: TrainGlobalState):
        """Run one LLP+UDA training step.

        Combines the LLP loss on the supervised batch with the UDA consistency
        loss on the unsupervised orig/aug pair, backpropagates, steps the
        optimizer, refreshes the embedding memory bank, and logs the loss.
        """
        supervised_loss = self.compute_llp_loss(
            sup_batch_m=batch_m_triplet.sup,
        )
        consistency_loss = self.compute_uda_loss(
            unsup_orig_batch_m=batch_m_triplet.unsup_orig,
            unsup_aug_batch_m=batch_m_triplet.unsup_aug,
        )
        total_loss = supervised_loss + self.llpuda_params.uda_coeff * consistency_loss
        total_loss = self.complex_backpropagate(total_loss)
        loss_value = total_loss.item()

        optim_step_grad_accum(
            optimizer_scheduler=self.optimizer_scheduler,
            train_global_state=train_global_state,
            gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
        )

        # Update memory bank: exponential moving average of the fresh embeddings
        # at the rows selected by the batch's example ids.
        example_ids = batch_m_triplet.sup.metadata["example_id"]
        with torch.no_grad():
            fresh_embedding = self.model.forward_batch(batch_m_triplet.sup.batch).embedding
        momentum = self.llp_params.llp_mem_bank_t
        self.llp_state.big_m_tensor[example_ids] = (
            (1 - momentum) * self.llp_state.big_m_tensor[example_ids]
            + momentum * fresh_embedding
        )

        self.log_writer.write_entry("loss_train", {
            "epoch": train_global_state.epoch,
            "epoch_step": train_global_state.epoch_step,
            "global_step": train_global_state.global_step,
            "loss_val": loss_value,
        })
Example #3
0
    def run_train_step(self, batch_triplet,
                       train_global_state: TrainGlobalState):
        """Run one UDA training step.

        Computes the supervised loss and, when ``uda_params.use_unsup`` is set,
        the coefficient-weighted unsupervised consistency loss; backpropagates
        the combined loss, steps the optimizer (honoring gradient
        accumulation), and writes/flushes a log entry with the loss components.
        """
        # Fix: removed the dead local `example_count`, which was accumulated
        # (inconsistently: len(sup) vs len(unsup_orig[0])) but never read.
        self.model.train()
        # Supervised branch.
        sup_loss, sup_logits = uda_ops.sup_train_step(
            model=self.model,
            sup_batch=batch_triplet.sup[0].to(self.device),
            task=self.task,
            global_step=train_global_state.global_step,
            train_schedule=self.train_schedule,
            uda_params=self.uda_params,
            zlogger=self.log_writer,
        )
        if self.uda_params.use_unsup:
            # Unsupervised consistency branch on the original/augmented pair.
            unsup_loss, unsup_orig_logits, unsup_aug_logits = uda_ops.unsup_train_step(
                model=self.model,
                unsup_orig_batch=batch_triplet.unsup_orig[0].to(self.device),
                unsup_aug_batch=batch_triplet.unsup_aug[0].to(self.device),
                uda_params=self.uda_params,
                zlogger=self.log_writer,
            )
            weighted_unsup_loss = self.uda_params.uda_coeff * unsup_loss
            loss = sup_loss + weighted_unsup_loss
        else:
            unsup_orig_logits, unsup_aug_logits = None, None
            weighted_unsup_loss = 0
            loss = sup_loss
        loss = self.complex_backpropagate(loss)

        optim_step_grad_accum(
            optimizer_scheduler=self.optimizer_scheduler,
            train_global_state=train_global_state,
            gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
        )

        log_data = {
            "epoch": train_global_state.epoch,
            "epoch_step": train_global_state.epoch_step,
            "global_step": train_global_state.global_step,
            "sup": get_val(sup_loss),
            "unsup": get_val(weighted_unsup_loss),
            "total": get_val(loss),
            "sup_pred_entropy": compute_pred_entropy_clean(sup_logits),
        }
        if self.uda_params.use_unsup:
            log_data["unsup_orig_pred_entropy"] = compute_pred_entropy_clean(
                unsup_orig_logits)
            log_data["unsup_aug_pred_entropy"] = compute_pred_entropy_clean(
                unsup_aug_logits)
        self.log_writer.write_entry("loss_train", log_data)
        self.log_writer.flush()
Example #4
0
    def run_train_step(self, batch, batch_metadata,
                       train_global_state: TrainGlobalState):
        """Run one representation-learning training step.

        Computes the representation loss, backpropagates, steps the optimizer,
        refreshes the embedding memory bank via an exponential moving average,
        logs the loss plus per-component means, and returns the loss details.
        """
        self.model.train()
        batch = batch.to(self.device)
        loss, loss_details, model_output = self.compute_representation_loss(
            batch, batch_metadata)
        loss = self.complex_backpropagate(loss)
        loss_value = loss.item()

        optim_step_grad_accum(
            optimizer_scheduler=self.optimizer_scheduler,
            train_global_state=train_global_state,
            gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
        )

        # Update memory bank: EMA of fresh embeddings at this batch's example ids.
        example_ids = batch_metadata["example_id"]
        with torch.no_grad():
            fresh_embedding = self.model.forward_batch(batch).embedding
        momentum = self.llp_params.llp_mem_bank_t
        self.llp_state.big_m_tensor[example_ids] = (
            (1 - momentum) * self.llp_state.big_m_tensor[example_ids]
            + momentum * fresh_embedding
        )

        # Scalar mean per loss component, for logging only.
        detail_means = {k: v.mean().item() for k, v in loss_details.items()}
        base_entry = {
            "epoch": train_global_state.epoch,
            "epoch_step": train_global_state.epoch_step,
            "global_step": train_global_state.global_step,
            "loss_val": loss_value,
            "pred_entropy": torch_utils.compute_pred_entropy_clean(model_output.logits),
        }
        self.log_writer.write_entry("loss_train", combine_dicts([base_entry, detail_means]))
        self.log_writer.flush()

        return loss_details
Example #5
0
    def run_train_step(self, batch_duplet: TrainDataDuplet, train_global_state: TrainGlobalState):
        """Run one mean-teacher training step.

        Combines the supervised classification loss with a weighted consistency
        loss between student and teacher predictions, backpropagates, steps the
        optimizer, updates the EMA teacher, and logs all loss components.
        """
        self.model.train()
        self.teacher_model_wrapper.model.train()

        labeled = batch_duplet.sup.to(self.device)

        # Student forward on labeled data -> classification loss.
        student_sup_logits = forward_batch_delegate(
            model=self.model,
            batch=labeled.batch,
            omit_label_ids=True,
            task_type=self.task.TASK_TYPE,
        )[0]
        classification_loss = compute_loss_from_model_output(
            logits=student_sup_logits,
            loss_criterion=self.loss_criterion,
            batch=labeled.batch,
            task_type=self.task.TASK_TYPE,
        )
        # Teacher forward on labeled data; no gradients needed for teacher targets.
        with torch.no_grad():
            teacher_sup_logits = forward_batch_delegate(
                model=self.teacher_model_wrapper.model,
                batch=labeled.batch,
                omit_label_ids=True,
                task_type=self.task.TASK_TYPE,
            )[0]

        # Consistency inputs: labeled only, or labeled + unlabeled concatenated.
        if self.mt_params.use_unsup:
            unlabeled = batch_duplet.unsup.to(self.device)
            student_unsup_logits = forward_batch_delegate(
                model=self.model,
                batch=unlabeled.batch,
                omit_label_ids=True,
                task_type=self.task.TASK_TYPE,
            )[0]
            # NOTE(review): unlike the labeled teacher pass above, this teacher
            # forward is not under torch.no_grad() -- preserved as-is; confirm intent.
            teacher_unsup_logits = forward_batch_delegate(
                model=self.teacher_model_wrapper.model,
                batch=unlabeled.batch,
                omit_label_ids=True,
                task_type=self.task.TASK_TYPE,
            )[0]
            student_logits = torch.cat([student_sup_logits, student_unsup_logits], dim=0)
            teacher_logits = torch.cat([teacher_sup_logits, teacher_unsup_logits], dim=0)
        else:
            student_logits = student_sup_logits
            teacher_logits = teacher_sup_logits

        raw_consistency_loss = compute_raw_consistency_loss(
            student_logits=student_logits,
            teacher_logits=teacher_logits,
            mt_params=self.mt_params,
        )
        # Consistency weight is ramped by global step.
        consistency_weight = get_current_consistency_weight(
            global_step=train_global_state.global_step,
            mt_params=self.mt_params,
        )
        consistency_loss = consistency_weight * raw_consistency_loss

        total_loss = self.complex_backpropagate(classification_loss + consistency_loss)

        optim_step_grad_accum(
            optimizer_scheduler=self.optimizer_scheduler,
            train_global_state=train_global_state,
            gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
        )
        # EMA update of teacher weights from the student.
        update_teacher(
            student_wrapper=self.model_wrapper,
            teacher_wrapper=self.teacher_model_wrapper,
            alpha=self.mt_params.alpha,
            global_step=train_global_state.global_step,
        )
        self.log_writer.write_entry("loss_train", {
            "epoch": train_global_state.epoch,
            "epoch_step": train_global_state.epoch_step,
            "global_step": train_global_state.global_step,
            "classification_loss": classification_loss.item(),
            "consistency_loss": consistency_loss.item(),
            "total_loss": total_loss.item(),
            "pred_entropy": compute_pred_entropy_clean(student_sup_logits)
        })