Example No. 1
    def init_llp_state(self,
                       train_examples,
                       verbose=True,
                       zero_out_unlabeled_confidence=True):
        self.llp_state = self.create_empty_llp_state(
            train_examples=train_examples)
        train_dataset_with_metadata = self.convert_examples_to_dataset(
            train_examples, verbose=verbose)
        # Inference-only pass to populate the LLP state: the eval batch size
        # can be used, and gold labels are kept (no label overriding).
        train_dataloader = self.get_train_dataloader(
            train_dataset_with_metadata=train_dataset_with_metadata,
            use_eval_batch_size=True,
            do_override_labels=False,
            verbose=verbose,
        )
        self.populate_llp_state(train_dataloader=train_dataloader,
                                verbose=verbose)
        if zero_out_unlabeled_confidence:
            self.zero_out_unlabeled_confidence()
        # epoch=-1 marks this log entry as the pre-training population pass.
        self.log_writer.write_entry(
            "populate_logs",
            combine_dicts([
                populate_logs(llp_state=self.llp_state,
                              llp_params=self.llp_params),
                {"epoch": -1},
            ]))
        self.log_writer.flush()
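This method builds an empty LLP state, runs an inference pass over the training set to populate it, and optionally zeroes the confidence of unlabeled examples, presumably so that only gold labels seed the first propagation round. As a rough illustration of that last step, here is a minimal sketch, assuming the state holds a per-example confidence tensor and a boolean labeled mask; the attribute names and exact semantics in the source may differ:

    import torch

    # Hypothetical sketch of zeroing unlabeled confidence; `confidence` and
    # `labeled_mask` are illustrative names, not the source's actual fields.
    def zero_out_unlabeled_confidence(confidence: torch.Tensor,
                                      labeled_mask: torch.Tensor) -> torch.Tensor:
        # Keep confidence only where a gold label exists.
        return confidence * labeled_mask.float()

    confidence = torch.rand(6)
    labeled_mask = torch.tensor([1, 0, 1, 0, 0, 1], dtype=torch.bool)
    print(zero_out_unlabeled_confidence(confidence, labeled_mask))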
Example No. 2
    def run_train_epoch_context(self,
                                train_dataset_with_metadata,
                                train_global_state: TrainGlobalState,
                                populate_after=True,
                                verbose=True):
        train_dataloader = self.get_train_dataloader(
            train_dataset_with_metadata=train_dataset_with_metadata,
            do_override_labels=True,
            verbose=verbose,
        )
        for batch, batch_metadata in maybe_tqdm(train_dataloader,
                                                desc="Training",
                                                verbose=verbose):
            self.run_train_step(
                batch=batch,
                batch_metadata=batch_metadata,
                train_global_state=train_global_state,
            )
            # Yield after every step so the caller can interleave its own
            # logic mid-epoch.
            yield batch, train_global_state
        if populate_after:
            # Refresh the LLP state from the updated model at epoch end.
            self.populate_llp_state(
                train_dataloader=train_dataloader,
                verbose=verbose,
            )
            self.log_writer.write_entry(
                "populate_logs",
                combine_dicts([
                    populate_logs(llp_state=self.llp_state,
                                  llp_params=self.llp_params),
                    {"epoch": train_global_state.epoch},
                ]))
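Because this method is a generator, the caller drives the epoch and can run its own logic (evaluation, checkpointing, early exit) between steps. A hypothetical driver might look like the following, assuming `runner` and `dataset` are built elsewhere and that `TrainGlobalState` is default-constructible; its real signature may differ:

    tgs = TrainGlobalState()
    for batch, tgs in runner.run_train_epoch_context(
            train_dataset_with_metadata=dataset,
            train_global_state=tgs,
            populate_after=True,
            verbose=True):
        pass  # mid-epoch hooks go here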
Example No. 3
    def run_train_step(self, batch, batch_metadata,
                       train_global_state: TrainGlobalState):
        self.model.train()
        batch = batch.to(self.device)
        loss, loss_details, model_output = self.compute_representation_loss(
            batch, batch_metadata)
        loss = self.complex_backpropagate(loss)
        loss_val = loss.item()

        optim_step_grad_accum(
            optimizer_scheduler=self.optimizer_scheduler,
            train_global_state=train_global_state,
            gradient_accumulation_steps=(
                self.train_schedule.gradient_accumulation_steps),
        )

        # Update memory bank: exponential moving average of each example's
        # embedding slot, keyed by example_id, with rate llp_mem_bank_t.
        with torch.no_grad():
            new_embedding = self.model.forward_batch(batch).embedding
        self.llp_state.big_m_tensor[batch_metadata["example_id"]] = (
            (1 - self.llp_params.llp_mem_bank_t) *
            self.llp_state.big_m_tensor[batch_metadata["example_id"]] +
            self.llp_params.llp_mem_bank_t * new_embedding)

        loss_details_logged = {
            k: v.mean().item()
            for k, v in loss_details.items()
        }
        self.log_writer.write_entry(
            "loss_train",
            combine_dicts([{
                "epoch": train_global_state.epoch,
                "epoch_step": train_global_state.epoch_step,
                "global_step": train_global_state.global_step,
                "loss_val": loss_val,
                "pred_entropy": torch_utils.compute_pred_entropy_clean(
                    model_output.logits),
            }, loss_details_logged]))
        self.log_writer.flush()

        return loss_details
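The memory-bank update above is an exponential moving average over per-example embedding slots. Stripped of the surrounding runner, the operation reduces to the self-contained sketch below, where `t` plays the role of `llp_params.llp_mem_bank_t` and all sizes are made up for illustration:

    import torch

    def ema_update(memory, example_ids, new_embedding, t):
        # In-place EMA: blend each stored slot with the fresh embedding.
        memory[example_ids] = (1 - t) * memory[example_ids] + t * new_embedding

    memory = torch.zeros(1000, 128)  # one 128-dim slot per training example
    ids = torch.tensor([3, 17, 42])  # example_ids of the current batch
    emb = torch.randn(3, 128)        # embeddings from the no-grad forward pass
    ema_update(memory, ids, emb, t=0.1)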
Example No. 4
    def run_train_epoch_context(self, train_dataset_with_metadata, uda_task_data,
                                train_global_state: TrainGlobalState,
                                populate_after=True, verbose=True):
        self.model.train()
        sup_dataloader = self.get_sup_dataloader(
            train_dataset_with_metadata=train_dataset_with_metadata,
            do_override_labels=True, verbose=verbose,
        )
        unsup_dataloaders = self.get_unsup_dataloaders(
            sup_dataloader=sup_dataloader,
            uda_task_data=uda_task_data,
        )
        dataloader_triplet = self.form_dataloader_triplet(
            sup_dataloader=sup_dataloader,
            unsup_orig_loader=unsup_dataloaders.unsup_orig,
            unsup_aug_loader=unsup_dataloaders.unsup_aug,
        )
        # Iterate the three streams in lockstep; the supervised loader sets
        # the epoch length.
        train_iterator = maybe_tqdm(zip(
            dataloader_triplet.sup,
            dataloader_triplet.unsup_orig,
            dataloader_triplet.unsup_aug,
        ), total=len(dataloader_triplet.sup), desc="Training", verbose=verbose)
        for sup_batch_m, unsup_orig_batch_m, unsup_aug_batch_m in train_iterator:
            batch_m_triplet = uda_runner.TrainDataTriplet(
                sup=sup_batch_m.to(self.device),
                unsup_orig=unsup_orig_batch_m.to(self.device),
                unsup_aug=unsup_aug_batch_m.to(self.device),
            )
            self.run_train_step(
                batch_m_triplet=batch_m_triplet,
                train_global_state=train_global_state,
            )
            yield batch_m_triplet, train_global_state
        if populate_after:
            self.populate_llp_state(
                train_dataloader=sup_dataloader,
                verbose=verbose,
            )
            self.log_writer.write_entry("populate_logs", combine_dicts([
                llp_runner.populate_logs(llp_state=self.llp_state,
                                         llp_params=self.llp_params),
                {"epoch": train_global_state.epoch},
            ]))
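The triplet iteration boils down to zipping three batch streams, with the supervised loader governing the epoch length (as the `total=len(dataloader_triplet.sup)` progress-bar hint suggests). A toy sketch of the pattern, with ranges standing in for the dataloaders:

    sup_loader = range(4)           # stand-in for dataloader_triplet.sup
    unsup_orig_loader = range(100)  # stand-ins for the unsupervised streams
    unsup_aug_loader = range(100)
    for sup_b, orig_b, aug_b in zip(sup_loader,
                                    unsup_orig_loader,
                                    unsup_aug_loader):
        # zip stops at the shortest iterable, so with longer unsupervised
        # streams one epoch runs exactly len(sup_loader) steps, consuming
        # one batch from each stream per step.
        pass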
Example No. 5
    def forward(self,
                mask,
                y_hat,
                y_hat_from_masked_x,
                y,
                classifier_loss_from_masked_x,
                use_p,
                reduce=True):
        _, max_indexes = y_hat.detach().max(1)
        _, max_indexes_on_masked_x = y_hat_from_masked_x.detach().max(1)
        correct_on_clean = y.eq(max_indexes).long()
        metadata = {}

        # Per-image mean mask activation (a 224-pixel kernel with stride 1
        # reduces a 224x224 mask to its mean).
        # Potentially rename to mask_size or mask_size_for_reg.
        mask_mean = F.avg_pool2d(mask, 224, stride=1).squeeze()
        metadata["mask_mean"] = mask_mean
        if self.add_prob_layers:
            metadata["use_p"] = use_p
            # Penalize deviation of the mask size from the target p.
            mask_mean = mask_mean - use_p
            if self.prob_loss_func == "l1":
                mask_mean = mask_mean.abs()
            elif self.prob_loss_func == "l2":
                mask_mean = mask_mean.pow(2)
            else:
                raise KeyError(self.prob_loss_func)

        # Apply regularization loss only on non-trivially confused images.
        mask_reg, mask_reg_metadata = self.compute_mask_regularization(
            y=y,
            max_indexes_on_masked_x=max_indexes_on_masked_x,
            correct_on_clean=correct_on_clean,
            mask_mean=mask_mean,
            reduce=reduce,
        )
        tv_reg = tv_loss(mask=mask, tv_weight=self.lambda_tv)
        regularization = mask_reg + tv_reg

        loss, loss_metadata = self.compute_only_loss(
            y_hat_from_masked_x=y_hat_from_masked_x,
            y=y,
            classifier_loss_from_masked_x=classifier_loss_from_masked_x,
            correct_on_clean=correct_on_clean,
            reduce=reduce,
        )
        masker_loss = loss + regularization

        if reduce:
            metadata["regularization"] = regularization.mean()
            metadata["correct_on_clean"] = correct_on_clean.float().mean()
        else:
            metadata["regularization"] = regularization
            metadata["correct_on_clean"] = correct_on_clean.float()

        metadata["loss"] = loss
        metadata["tv_reg"] = tv_reg

        metadata = combine_dicts([metadata, mask_reg_metadata], strict=True)
        if self.objective_type == "entropy":
            metadata["negative_entropy"] = loss_metadata["negative_entropy"]
        return masker_loss, metadata
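Two regularizers are combined here: a mask-size term and a total-variation term. The `tv_loss` helper is not shown in this listing; below is a minimal sketch of a standard total-variation penalty for an (N, 1, H, W) mask, offered only as an illustration of the technique, since the source's actual implementation may differ in its reduction details:

    import torch

    def tv_loss(mask, tv_weight):
        # Mean absolute difference between vertically and horizontally
        # adjacent mask pixels; encourages spatially smooth masks.
        dh = (mask[:, :, 1:, :] - mask[:, :, :-1, :]).abs().mean()
        dw = (mask[:, :, :, 1:] - mask[:, :, :, :-1]).abs().mean()
        return tv_weight * (dh + dw)

    mask = torch.rand(8, 1, 224, 224)
    print(tv_loss(mask, tv_weight=1e-3))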