Code Example #1
def training_step(self, batch: TernarySample, batch_idx: int) -> None:
    assert isinstance(batch.x, Tensor)
    opt = cast(Optimizer, self.optimizers())

    opt.zero_grad()

    model_out: ModelOut = self.forward(batch.x)
    loss_adv, loss_clf, loss = self._get_losses(model_out=model_out, batch=batch)

    logging_dict = {
        "adv_loss": loss_adv.item(),
        "clf_loss": loss_clf.item(),
        "loss": loss.item(),
    }
    logging_dict = prefix_keys(dict_=logging_dict, prefix="train", sep="/")
    self.log_dict(logging_dict)

    # Backpropagate manually: projected gradients for the shared encoder,
    # plain gradients for the adversary and classifier heads.
    compute_proj_grads(model=self.enc, loss_p=loss_clf, loss_a=loss_adv, alpha=1.0)
    compute_grad(model=self.adv, loss=loss_adv)
    compute_grad(model=self.clf, loss=loss_clf)

    opt.step()

    # Step the LR scheduler by hand, honouring the configured interval.
    if (self.lr_sched_interval is TrainingMode.step) and (
        self.global_step % self.lr_sched_freq == 0
    ):
        sch = cast(LRScheduler, self.lr_schedulers())
        sch.step()
    if (self.lr_sched_interval is TrainingMode.epoch) and self.trainer.is_last_batch:
        sch = cast(LRScheduler, self.lr_schedulers())
        sch.step()
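
Example #1 calls opt.zero_grad(), opt.step(), and sch.step() itself, which in PyTorch Lightning only works with manual optimization enabled. A minimal sketch of the constructor this presupposes; the class name is hypothetical, and only lr_sched_interval and lr_sched_freq are taken from the example:

import pytorch_lightning as pl

class AdvTrainingModule(pl.LightningModule):  # hypothetical class name
    def __init__(self, lr_sched_interval: TrainingMode, lr_sched_freq: int) -> None:
        super().__init__()
        # Turn off Lightning's automatic optimization so that training_step
        # may call opt.zero_grad() / opt.step() and sch.step() itself.
        self.automatic_optimization = False
        self.lr_sched_interval = lr_sched_interval
        self.lr_sched_freq = lr_sched_freq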
Code Example #2
def training_step(self, batch: TernarySample, batch_idx: int, optimizer_idx: int) -> Tensor:
    assert isinstance(batch.x, Tensor)
    if optimizer_idx == 0:
        # Main-model update: re-enable the main model (the adversarial branch
        # below freezes it) and freeze the adversary.
        self.set_requires_grad([self.enc, self.dec, self.clf], requires_grad=True)
        self.set_requires_grad(self.adv, requires_grad=False)
        model_out = self.forward(x=batch.x, s=batch.s)
        laftr_loss = self._loss_laftr(y_pred=model_out.y, recon=model_out.x, batch=batch)
        adv_loss = self._loss_adv(s_pred=model_out.s, batch=batch)
        loss = laftr_loss + adv_loss
        _acc = accuracy(y_pred=model_out.y, y_true=batch.y)
        logging_dict = {
            "loss": loss.item(),
            "model_loss": laftr_loss.item(),
            "acc": _acc,
        }
    elif optimizer_idx == 1:
        # Adversarial update: freeze the main model, train only the adversary.
        self.set_requires_grad([self.enc, self.dec, self.clf], requires_grad=False)
        self.set_requires_grad(self.adv, requires_grad=True)
        model_out = self.forward(x=batch.x, s=batch.s)
        adv_loss = self._loss_adv(s_pred=model_out.s, batch=batch)
        laftr_loss = self._loss_laftr(y_pred=model_out.y, recon=model_out.x, batch=batch)
        target = batch.y.view(-1).long()
        _acc = self.train_acc(model_out.y.argmax(-1), target)
        logging_dict = {
            "loss": (laftr_loss + adv_loss).item(),
            "adv_loss": adv_loss.item(),
            "acc": _acc,
        }
        # Negate the loss so that stepping this optimizer maximises it.
        loss = -(laftr_loss + adv_loss)
    else:
        raise RuntimeError("Expected exactly two optimizers, but received a third.")
    logging_dict = prefix_keys(dict_=logging_dict, prefix="train", sep="/")
    self.log_dict(logging_dict)
    return loss
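
Lightning only passes optimizer_idx to training_step when configure_optimizers returns more than one optimizer (the pre-2.0 convention used here). A minimal sketch under that assumption; Adam and the learning rates are placeholders, with parameters grouped to mirror which modules each branch trains:

from itertools import chain
from typing import List

from torch.optim import Adam, Optimizer

def configure_optimizers(self) -> List[Optimizer]:
    # Optimizer 0 updates the main model (encoder, decoder, classifier);
    # optimizer 1 updates only the adversary. Lightning then calls
    # training_step once per optimizer with the matching optimizer_idx.
    main_params = chain(self.enc.parameters(), self.dec.parameters(), self.clf.parameters())
    opt_main = Adam(main_params, lr=1e-3)  # lr values are assumed placeholders
    opt_adv = Adam(self.adv.parameters(), lr=1e-3)
    return [opt_main, opt_adv]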
Code Example #3
def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
    results_dict = self.inference_epoch_end(outputs=outputs, stage=Stage.test)
    results_dict = prefix_keys(dict_=results_dict, prefix=str(Stage.test), sep="/")
    self.log_dict(results_dict)
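
Every example routes its metrics through prefix_keys before calling self.log_dict. The helper itself is not shown; judging from the call sites (keyword-only dict_, prefix, and sep arguments), a minimal sketch could be:

from typing import Dict, TypeVar

T = TypeVar("T")

def prefix_keys(*, dict_: Dict[str, T], prefix: str, sep: str = "/") -> Dict[str, T]:
    # Return a copy of dict_ with every key prefixed, e.g. "loss" -> "train/loss".
    return {f"{prefix}{sep}{key}": value for key, value in dict_.items()}

Keeping the stage prefix out of the metric names is what lets a single inference_step (Examples #7 and #8) serve both validation and test.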
Code Example #4
def _get_loss(self, *, l_pos: Tensor, l_neg: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
    loss, log_dict = self._inst_disc_loss(l_pos=l_pos, l_neg=l_neg)
    log_dict = prefix_keys(dict_=log_dict, prefix="inst_disc", sep="/")
    return loss, log_dict
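
Example #4 presupposes an _inst_disc_loss that reduces positive- and negative-pair logits to a scalar. One common formulation, and purely an assumption here, is the InfoNCE objective: concatenate l_pos and l_neg, scale by a temperature, and treat the positive (column 0) as the target class of a cross-entropy problem:

from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

def _inst_disc_loss(self, *, l_pos: Tensor, l_neg: Tensor, temperature: float = 0.07) -> Tuple[Tensor, Dict[str, Tensor]]:
    # l_pos: (N, 1) logits for each sample against its positive key;
    # l_neg: (N, K) logits against K negatives.
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    # The positive occupies column 0 of every row, so the target is class 0.
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    loss = F.cross_entropy(logits, labels)
    return loss, {"loss": loss.detach()}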
Code Example #5
def training_step(self, batch: BinarySample, batch_idx: int) -> Tensor:
    assert isinstance(batch.x, Tensor)
    logits = self.forward(batch.x)
    loss = self._get_loss(logits=logits, batch=batch)
    results_dict = {
        "loss": loss.item(),
        "acc": accuracy(y_pred=logits, y_true=batch.y),
    }
    results_dict = prefix_keys(dict_=results_dict, prefix=str(Stage.fit), sep="/")
    self.log_dict(results_dict)

    return loss
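
Example #5 leaves _get_loss abstract. A plausible stand-in for a classification task, again purely an assumption, is plain cross-entropy over integer targets:

import torch.nn.functional as F
from torch import Tensor

def _get_loss(self, *, logits: Tensor, batch: BinarySample) -> Tensor:
    # Assumes batch.y holds integer class indices.
    return F.cross_entropy(logits, batch.y.view(-1).long())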
Code Example #6
def training_step(self, batch: TernarySample, batch_idx: int) -> Tensor:
    assert isinstance(batch.x, Tensor)
    model_out: ModelOut = self.forward(batch.x)
    loss_adv, loss_clf, loss = self._get_losses(model_out=model_out, batch=batch)

    # prefix_keys adds the "train/" prefix below, so the keys stay bare here.
    logging_dict = {
        "adv_loss": loss_adv.item(),
        "clf_loss": loss_clf.item(),
        "loss": loss.item(),
    }
    logging_dict = prefix_keys(dict_=logging_dict, prefix="train", sep="/")
    self.log_dict(logging_dict)

    return loss
Code Example #7
def inference_step(self, batch: TernarySample, *, stage: Stage) -> STEP_OUTPUT:
    assert isinstance(batch.x, Tensor)
    model_out = self.forward(x=batch.x, s=batch.s)
    logging_dict = {
        "laftr_loss": self._loss_laftr(y_pred=model_out.y, recon=model_out.x, batch=batch),
        "adv_loss": self._loss_adv(s_pred=model_out.s, batch=batch),
    }
    logging_dict = prefix_keys(dict_=logging_dict, prefix=str(stage), sep="/")
    self.log_dict(logging_dict)

    return {
        "targets": batch.y.view(-1),
        "subgroup_inf": batch.s.view(-1),
        "logits_y": model_out.y,
    }
Code Example #8
def inference_step(self, batch: TernarySample, *, stage: Stage) -> STEP_OUTPUT:
    assert isinstance(batch.x, Tensor)
    model_out: ModelOut = self.forward(batch.x)
    loss_adv, loss_clf, loss = self._get_losses(model_out=model_out, batch=batch)
    logging_dict = {
        "loss": loss.item(),
        "loss_adv": loss_adv.item(),
        "loss_clf": loss_clf.item(),
    }
    logging_dict = prefix_keys(dict_=logging_dict, prefix=str(stage), sep="/")
    self.log_dict(logging_dict)

    return {
        "targets": batch.y.view(-1),
        "subgroup_inf": batch.s.view(-1),
        "logits": model_out.y,
    }
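
The dictionaries returned by Examples #7 and #8 are per-batch step outputs, which Lightning collects into a list and hands to the epoch-end hooks (see Example #3). A sketch of what the inference_epoch_end referenced there might do, assuming it concatenates the batches and defers to a hypothetical compute_metrics (key names follow Example #8; Example #7 uses "logits_y" instead of "logits"):

import torch

def inference_epoch_end(self, outputs: EPOCH_OUTPUT, *, stage: Stage) -> Dict[str, float]:
    # Stitch the per-batch step outputs back into full-split tensors.
    targets = torch.cat([step["targets"] for step in outputs], dim=0)
    subgroups = torch.cat([step["subgroup_inf"] for step in outputs], dim=0)
    logits = torch.cat([step["logits"] for step in outputs], dim=0)
    # compute_metrics is a hypothetical stand-in for whatever accuracy /
    # fairness metrics the module reports over the whole split.
    return compute_metrics(logits=logits, targets=targets, subgroups=subgroups)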