def _handle_batch(self, batch: Mapping[str, Any]) -> None:
    """Run the CycleGAN generator forward pass for one batch.

    Populates ``self.output`` with the cross-domain translations
    (``generated_a``/``generated_b``) and their cycle reconstructions
    (``reconstructed_a``/``reconstructed_b``).

    Args:
        batch: mapping with ``"real_a"`` and ``"real_b"`` input tensors
    """
    # Unwrap the (possibly DDP-wrapped) model once instead of on every access.
    model = utils.get_nn_from_ddp_module(self.model)
    generated_b = model["generator_ab"](batch["real_a"])
    generated_a = model["generator_ba"](batch["real_b"])
    self.output = {
        "generated_b": generated_b,
        "generated_a": generated_a,
        # Cycle consistency: translate each fake back to its source domain.
        "reconstructed_a": model["generator_ba"](generated_b),
        "reconstructed_b": model["generator_ab"](generated_a),
    }
def on_batch_end(self, runner: "IRunner") -> None:
    """
    On batch end action.

    Args:
        runner: runner
    """
    model = utils.get_nn_from_ddp_module(runner.model)
    real_a = runner.input["real_a"]
    real_b = runner.input["real_b"]
    # Identity mappings: each generator fed a sample from its own target
    # domain should return it unchanged.
    # NOTE(review): the b-direction key is hardcoded ("generator_ab") while
    # the a-direction uses self.ba_key — confirm this asymmetry is intended.
    identical_b = model["generator_ab"](real_b)
    identical_a = model[self.ba_key](real_a)
    criterion = runner.criterion["identical"]
    loss_id_b = criterion(identical_b, real_b)
    loss_id_a = criterion(identical_a, real_a)
    runner.batch_metrics["identical_loss"] = (
        self.lambda_a * loss_id_a + self.lambda_b * loss_id_b
    )
def on_batch_end(self, runner: "IRunner") -> None:
    """
    On batch end action.

    Args:
        runner: runner
    """
    model = utils.get_nn_from_ddp_module(runner.model)
    gan_criterion = runner.criterion["gan"]
    # Adversarial loss: each generator's fakes are scored by the opposite
    # discriminator against the "real" target label.
    loss_a = gan_criterion(
        inp=model["discriminator_b"](runner.output["generated_b"]),
        is_real=True,
    )
    loss_b = gan_criterion(
        inp=model["discriminator_a"](runner.output["generated_a"]),
        is_real=True,
    )
    runner.batch_metrics["gan_loss"] = (
        self.lambda_a * loss_a + self.lambda_b * loss_b
    )
def _get_loss(self, manifold: str, runner: "CycleGANRunner") -> torch.Tensor:
    """Compute the discriminator loss for one manifold.

    Averages the GAN loss on real samples (labelled real) with the loss on
    buffered generated samples (labelled fake).

    Args:
        manifold: manifold suffix, e.g. ``"a"`` or ``"b"``
        runner: runner holding inputs, outputs, buffers and criteria

    Returns:
        averaged discriminator loss tensor
    """
    model = utils.get_nn_from_ddp_module(runner.model)
    discriminator = model[f"discriminator_{manifold}"]
    gan_criterion = runner.criterion["gan"]
    real_loss = gan_criterion(
        discriminator(runner.input[f"real_{manifold}"]), True
    )
    # Draw a (possibly historical) fake from the replay buffer; detach so the
    # generator receives no gradient from this discriminator update.
    fake = runner.buffers[manifold].get(
        runner.output[f"generated_{manifold}"]
    )
    fake_loss = gan_criterion(discriminator(fake.detach()), False)
    return (fake_loss + real_loss) / 2
def set_requires_grad(self, model_keys: List[str], req: bool) -> None:
    """
    Setting requires grad value for specified models.

    Args:
        model_keys: models to be set
        req: value to set
    """
    model = utils.get_nn_from_ddp_module(self.model)
    for key in model_keys:
        # Toggle gradient tracking on every parameter of the sub-model.
        for parameter in model[key].parameters():
            parameter.requires_grad = req
def _handle_batch(self, batch: Mapping[str, Any]) -> None:
    """Forward pass for generator distillation on one batch.

    The student (``self.student_key``) translates b -> a and exposes its
    hidden activations; the teacher (presumably ``"generator_ba"`` — it is
    the module run under ``no_grad`` below, while ``self.teacher_key`` is
    frozen; confirm they refer to the same module) produces target
    activations for the same input. Cycle reconstructions are computed for
    the CycleGAN losses as well.

    Args:
        batch: mapping with ``"real_a"`` and ``"real_b"`` input tensors
    """
    # The teacher must not receive gradient updates during distillation.
    self.set_requires_grad([self.teacher_key], False)
    # Unwrap the (possibly DDP-wrapped) model once instead of on every access.
    model = utils.get_nn_from_ddp_module(self.model)
    # Student forward with hidden-state capture (second argument toggles it).
    generated_a, hiddens_s = model[self.student_key](batch["real_b"], True)
    generated_b = model["generator_ab"](batch["real_a"])
    self.output = {
        "generated_b": generated_b,
        "generated_a": generated_a,
        "hiddens_s": hiddens_s,
        # Cycle consistency: translate each fake back to its source domain.
        "reconstructed_a": model[self.student_key](generated_b),
        "reconstructed_b": model["generator_ab"](generated_a),
    }
    # Teacher forward under no_grad: distillation targets only, no autograd.
    with torch.no_grad():
        generated_t, hiddens_t = model["generator_ba"](batch["real_b"], True)
    self.output["hiddens_t"] = hiddens_t
    self.output["generated_t"] = generated_t