Example #1
    def run_train_step(self, batch, train_global_state):
        self.model.train()
        batch = batch.to(self.device)
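        # Forward pass without label ids; the loss criterion below reads the labels from the batch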
        logits = forward_batch_delegate(
            model=self.model,
            batch=batch,
            omit_label_ids=True,
            task_type=self.task.TASK_TYPE,
        )[0]
        loss = compute_loss_from_model_output(
            logits=logits,
            loss_criterion=self.loss_criterion,
            batch=batch,
            task_type=self.task.TASK_TYPE,
        )

        loss = self.complex_backpropagate(loss)
        loss_val = loss.item()

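        # Step the optimizer/scheduler only once the configured number of gradient-accumulation steps has elapsed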
        optim_step_grad_accum(
            optimizer_scheduler=self.optimizer_scheduler,
            train_global_state=train_global_state,
            gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
        )
        self.log_writer.write_entry(
            "loss_train", {
                "epoch": train_global_state.epoch,
                "epoch_step": train_global_state.epoch_step,
                "global_step": train_global_state.global_step,
                "loss_val": loss_val,
                "pred_entropy": compute_pred_entropy_clean(logits)
            })
Example #2
def run_val(val_examples, val_dataloader, model, task, loss_criterion, device,
            local_rank, verbose):
    # Only run validation in the non-distributed setting (local_rank == -1)
    if local_rank != -1:
        return
    model.eval()
    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0
    all_logits = []
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(val_dataloader,
                       desc="Evaluating (Val)",
                       verbose=verbose)):
        batch = batch.to(device)

        with torch.no_grad():
            logits = forward_batch_delegate(
                model=model,
                batch=batch,
                omit_label_ids=True,
                task_type=task.TASK_TYPE,
            )[0]
            tmp_eval_loss = compute_loss_from_model_output(
                logits=logits,
                loss_criterion=loss_criterion,
                batch=batch,
                task_type=task.TASK_TYPE,
            )

        logits = logits.detach().cpu().numpy()
        total_eval_loss += tmp_eval_loss.mean().item()

        nb_eval_examples += len(batch)
        nb_eval_steps += 1
        all_logits.append(logits)
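    # Average the loss over evaluation steps and stack per-batch logits for metric computation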
    eval_loss = total_eval_loss / nb_eval_steps
    all_logits = np.concatenate(all_logits, axis=0)

    return {
        "logits": all_logits,
        "loss": eval_loss,
        "metrics": evaluate.compute_task_metrics(task, all_logits,
                                                 val_examples),
    }
Example #3
    def run_test(self, test_examples, verbose=True):
        test_dataloader = self.get_eval_dataloader(test_examples)
        self.model.eval()
        all_logits = []
        for step, (batch, batch_metadata) in enumerate(
                maybe_tqdm(test_dataloader, desc="Predictions (Test)", verbose=verbose)):
            batch = batch.to(self.device)
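            # Prediction only: no labels are needed and no gradients are tracked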
            with torch.no_grad():
                logits = forward_batch_delegate(
                    model=self.model,
                    batch=batch,
                    omit_label_ids=True,
                    task_type=self.task.TASK_TYPE,
                )[0]
            logits = logits.detach().cpu().numpy()
            all_logits.append(logits)

        all_logits = np.concatenate(all_logits, axis=0)
        return all_logits
Example #4
    def run_train_step(self, batch_duplet: TrainDataDuplet, train_global_state: TrainGlobalState):
        self.model.train()
        self.teacher_model_wrapper.model.train()

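        # batch_duplet carries a supervised (labeled) sub-batch and, when enabled, an unsupervised one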
        sup_batch = batch_duplet.sup.to(self.device)

        # Classification [SUP]
        sup_logits = forward_batch_delegate(
            model=self.model,
            batch=sup_batch.batch,
            omit_label_ids=True,
            task_type=self.task.TASK_TYPE,
        )[0]
        classification_loss = compute_loss_from_model_output(
            logits=sup_logits,
            loss_criterion=self.loss_criterion,
            batch=sup_batch.batch,
            task_type=self.task.TASK_TYPE,
        )
        # Consistency: teacher forward pass on the supervised batch (no gradients flow to the teacher)
        with torch.no_grad():
            teacher_sup_logits = forward_batch_delegate(
                model=self.teacher_model_wrapper.model,
                batch=sup_batch.batch,
                omit_label_ids=True,
                task_type=self.task.TASK_TYPE,
            )[0]

        # Consistency: optionally include the unsupervised batch in the student/teacher logits
        if self.mt_params.use_unsup:
            unsup_batch = batch_duplet.unsup.to(self.device)
            unsup_logits = forward_batch_delegate(
                model=self.model,
                batch=unsup_batch.batch,
                omit_label_ids=True,
                task_type=self.task.TASK_TYPE,
            )[0]
            teacher_unsup_logits = forward_batch_delegate(
                model=self.teacher_model_wrapper.model,
                batch=unsup_batch.batch,
                omit_label_ids=True,
                task_type=self.task.TASK_TYPE,
            )[0]
            student_logits = torch.cat([sup_logits, unsup_logits], dim=0)
            teacher_logits = torch.cat([teacher_sup_logits, teacher_unsup_logits], dim=0)
        else:
            student_logits = sup_logits
            teacher_logits = teacher_sup_logits

        raw_consistency_loss = compute_raw_consistency_loss(
            student_logits=student_logits,
            teacher_logits=teacher_logits,
            mt_params=self.mt_params,
        )
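        # Scale the raw consistency loss by a step-dependent weight (mean teacher typically ramps this up early in training)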
        consistency_weight = get_current_consistency_weight(
            global_step=train_global_state.global_step,
            mt_params=self.mt_params,
        )
        consistency_loss = consistency_weight * raw_consistency_loss

        # Combine
        loss = classification_loss + consistency_loss
        loss = self.complex_backpropagate(loss)

        optim_step_grad_accum(
            optimizer_scheduler=self.optimizer_scheduler,
            train_global_state=train_global_state,
            gradient_accumulation_steps=self.train_schedule.gradient_accumulation_steps,
        )
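        # Mean-teacher update: the teacher's weights track an exponential moving average of the student's, controlled by alpha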
        update_teacher(
            student_wrapper=self.model_wrapper,
            teacher_wrapper=self.teacher_model_wrapper,
            alpha=self.mt_params.alpha,
            global_step=train_global_state.global_step,
        )
        self.log_writer.write_entry("loss_train", {
            "epoch": train_global_state.epoch,
            "epoch_step": train_global_state.epoch_step,
            "global_step": train_global_state.global_step,
            "classification_loss": classification_loss.item(),
            "consistency_loss": consistency_loss.item(),
            "total_loss": loss.item(),
            "pred_entropy": compute_pred_entropy_clean(sup_logits)
        })