def __init__(
    self,
    module: BaseModuleClass,
    n_obs_training,
    lr=1e-3,
    weight_decay=1e-6,
    n_steps_kl_warmup: Union[int, None] = None,
    n_epochs_kl_warmup: Union[int, None] = 400,
    reduce_lr_on_plateau: bool = False,
    lr_factor: float = 0.6,
    lr_patience: int = 30,
    lr_threshold: float = 0.0,
    lr_scheduler_metric: Literal[
        "elbo_validation", "reconstruction_loss_validation",
        "kl_local_validation"] = "elbo_validation",
    lr_min: float = 0,
    adversarial_classifier: Union[bool, Classifier] = False,
    scale_adversarial_loss: Union[float, Literal["auto"]] = "auto",
    **loss_kwargs,
):
    super().__init__(
        module=module,
        n_obs_training=n_obs_training,
        lr=lr,
        weight_decay=weight_decay,
        n_steps_kl_warmup=n_steps_kl_warmup,
        n_epochs_kl_warmup=n_epochs_kl_warmup,
        reduce_lr_on_plateau=reduce_lr_on_plateau,
        lr_factor=lr_factor,
        lr_patience=lr_patience,
        lr_threshold=lr_threshold,
        lr_scheduler_metric=lr_scheduler_metric,
        lr_min=lr_min,
    )
    if adversarial_classifier is True:
        self.n_output_classifier = self.module.n_batch
        self.adversarial_classifier = Classifier(
            n_input=self.module.n_latent,
            n_hidden=32,
            n_labels=self.n_output_classifier,
            n_layers=2,
            logits=True,
        )
    else:
        self.adversarial_classifier = adversarial_classifier
    self.scale_adversarial_loss = scale_adversarial_loss
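Here `Classifier` is the scvi-tools classifier module; with `logits=True` it is expected to return unnormalized scores rather than probabilities. As a rough mental model only (a sketch with a hypothetical name, not the library's implementation), it behaves like a small MLP that maps the latent representation to one score per batch category:

import torch
from torch import nn

class ToyLatentClassifier(nn.Module):
    """Illustrative stand-in for Classifier(n_input=..., n_hidden=32, n_labels=..., n_layers=2, logits=True)."""

    def __init__(self, n_input: int, n_hidden: int, n_labels: int, n_layers: int = 2):
        super().__init__()
        layers, dim = [], n_input
        for _ in range(n_layers):
            layers += [nn.Linear(dim, n_hidden), nn.ReLU()]
            dim = n_hidden
        layers.append(nn.Linear(dim, n_labels))  # raw logits, no softmax
        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (n_cells, n_latent) -> (n_cells, n_labels) batch logits
        return self.net(z)

The training plan later pushes these logits through a LogSoftmax in loss_adversarial_classifier, so no activation is applied at the classifier output.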
Example #2
    def __init__(
        self,
        n_input,
        n_batch,
        n_labels,
        n_hidden=128,
        n_latent=10,
        n_layers=1,
        dropout_rate=0.1,
        y_prior=None,
        dispersion="gene",
        log_variational=True,
        gene_likelihood="zinb",
    ):
        super().__init__(
            n_input,
            n_batch,
            n_labels,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            log_variational=log_variational,
            gene_likelihood=gene_likelihood,
            use_observed_lib_size=False,
        )

        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_cat_list=[n_batch, n_labels],
            n_hidden=n_hidden,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
        )
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_cat_list=[n_batch, n_labels],
            n_layers=n_layers,
            n_hidden=n_hidden,
        )

        self.y_prior = torch.nn.Parameter(
            y_prior if y_prior is not None else
            (1 / n_labels) * torch.ones(1, n_labels),
            requires_grad=False,
        )

        self.classifier = Classifier(n_input,
                                     n_hidden,
                                     n_labels,
                                     n_layers=n_layers,
                                     dropout_rate=dropout_rate)
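In this conditional VAE, `n_cat_list=[n_batch, n_labels]` tells the encoder and decoder to condition on the batch and label covariates. A common way to implement such conditioning, and a reasonable assumption about what the scvi-tools layers do internally, is to one-hot encode each covariate and concatenate it to the layer input; a minimal self-contained sketch:

import torch
import torch.nn.functional as F
from torch import nn

def concat_one_hot(x: torch.Tensor, cats: list, n_cat_list: list) -> torch.Tensor:
    """Append one-hot encodings of categorical covariates to the feature matrix."""
    pieces = [x]
    for cat, n_cat in zip(cats, n_cat_list):
        pieces.append(F.one_hot(cat.reshape(-1).long(), n_cat).float())
    return torch.cat(pieces, dim=-1)

# toy data: 4 cells, 100 genes, 2 batches, 3 labels
x = torch.randn(4, 100)
batch_index = torch.tensor([0, 1, 1, 0])
labels = torch.tensor([2, 0, 1, 2])

x_cond = concat_one_hot(x, [batch_index, labels], n_cat_list=[2, 3])
first_hidden = nn.Linear(100 + 2 + 3, 128)  # the first layer sees genes plus covariates
h = first_hidden(x_cond)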
Example #3
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    if kwargs["adversarial_classifier"] is True:
        self.n_output_classifier = 2
        self.adversarial_classifier = Classifier(
            n_input=self.module.n_latent,
            n_hidden=32,
            n_labels=self.n_output_classifier,
            n_layers=3,
            logits=True,
        )
    else:
        self.adversarial_classifier = kwargs["adversarial_classifier"]
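One caveat about this variant: kwargs["adversarial_classifier"] raises a KeyError if the caller relies on the parent class's default instead of passing the keyword explicitly. A more defensive reading of the flag (my own sketch, not the library's code) would be:

def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # fall back to False when the keyword was not supplied explicitly
    adversarial_classifier = kwargs.get("adversarial_classifier", False)
    if adversarial_classifier is True:
        self.n_output_classifier = 2  # e.g. doublet vs. singlet
        self.adversarial_classifier = Classifier(
            n_input=self.module.n_latent,
            n_hidden=32,
            n_labels=self.n_output_classifier,
            n_layers=3,
            logits=True,
        )
    else:
        self.adversarial_classifier = adversarial_classifier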
Example #4
    def __init__(
        self,
        adata: AnnData,
        **classifier_kwargs,
    ):
        # TODO: catch the UserWarning here and log a warning
        # about non-count data
        super().__init__(adata)

        self.module = Classifier(
            n_input=self.summary_stats.n_vars,
            n_labels=2,
            logits=True,
            **classifier_kwargs,
        )
        self._model_summary_string = "Solo model"
        self.init_params_ = self._get_init_params(locals())
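The last two lines are bookkeeping: `_get_init_params(locals())` is assumed to snapshot the constructor arguments so the model can be re-instantiated later (for example when loading from disk). A generic sketch of that pattern, with hypothetical names:

import inspect

class InitParamsMixin:
    """Generic sketch: remember constructor arguments so an object can be rebuilt later."""

    def _get_init_params(self, local_vars: dict) -> dict:
        # keep only names that are parameters of __init__, dropping bulky data objects
        sig = inspect.signature(self.__init__)
        return {
            name: local_vars[name]
            for name in sig.parameters
            if name in local_vars and name != "adata"
        }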
Example #5
class AdversarialTrainingPlan(TrainingPlan):
    """
    Train VAEs with an optional adversarial loss to encourage latent space mixing.

    Parameters
    ----------
    module
        A module instance from class ``BaseModuleClass``.
    lr
        Learning rate used for optimization with :class:`~torch.optim.Adam`.
    weight_decay
        Weight decay used in :class:`~torch.optim.Adam`.
    n_steps_kl_warmup
        Number of training steps (minibatches) over which to scale the weight on the KL
        divergence terms from 0 to 1. Only activated when `n_epochs_kl_warmup` is set to None.
    n_epochs_kl_warmup
        Number of epochs over which to scale the weight on the KL divergence terms from 0 to 1.
        Overrides `n_steps_kl_warmup` when both are not `None`.
    reduce_lr_on_plateau
        Whether to monitor the validation loss and reduce the learning rate when the
        validation `lr_scheduler_metric` plateaus.
    lr_factor
        Factor by which to reduce the learning rate.
    lr_patience
        Number of epochs with no improvement after which learning rate will be reduced.
    lr_threshold
        Threshold for measuring the new optimum.
    lr_scheduler_metric
        Which metric to track for learning rate reduction.
    lr_min
        Minimum learning rate allowed.
    adversarial_classifier
        Whether to use an adversarial classifier in the latent space.
    scale_adversarial_loss
        Scaling factor on the adversarial components of the loss.
        By default, the adversarial loss is scaled from 1 to 0, following the opposite of
        the KL warmup schedule.
    **loss_kwargs
        Keyword args to pass to the loss method of the `module`.
        `kl_weight` should not be passed here and is handled automatically.
    """
    def __init__(
        self,
        module: BaseModuleClass,
        lr=1e-3,
        weight_decay=1e-6,
        n_steps_kl_warmup: Union[int, None] = None,
        n_epochs_kl_warmup: Union[int, None] = 400,
        reduce_lr_on_plateau: bool = False,
        lr_factor: float = 0.6,
        lr_patience: int = 30,
        lr_threshold: float = 0.0,
        lr_scheduler_metric: Literal[
            "elbo_validation", "reconstruction_loss_validation",
            "kl_local_validation"] = "elbo_validation",
        lr_min: float = 0,
        adversarial_classifier: Union[bool, Classifier] = False,
        scale_adversarial_loss: Union[float, Literal["auto"]] = "auto",
        **loss_kwargs,
    ):
        super().__init__(
            module=module,
            lr=lr,
            weight_decay=weight_decay,
            n_steps_kl_warmup=n_steps_kl_warmup,
            n_epochs_kl_warmup=n_epochs_kl_warmup,
            reduce_lr_on_plateau=reduce_lr_on_plateau,
            lr_factor=lr_factor,
            lr_patience=lr_patience,
            lr_threshold=lr_threshold,
            lr_scheduler_metric=lr_scheduler_metric,
            lr_min=lr_min,
        )
        if adversarial_classifier is True:
            self.n_output_classifier = self.module.n_batch
            self.adversarial_classifier = Classifier(
                n_input=self.module.n_latent,
                n_hidden=32,
                n_labels=self.n_output_classifier,
                n_layers=2,
                logits=True,
            )
        else:
            self.adversarial_classifier = adversarial_classifier
        self.scale_adversarial_loss = scale_adversarial_loss

    def loss_adversarial_classifier(self,
                                    z,
                                    batch_index,
                                    predict_true_class=True):
        n_classes = self.n_output_classifier
        cls_logits = torch.nn.LogSoftmax(dim=1)(self.adversarial_classifier(z))

        if predict_true_class:
            cls_target = one_hot(batch_index, n_classes)
        else:
            one_hot_batch = one_hot(batch_index, n_classes)
            cls_target = torch.zeros_like(one_hot_batch)
            # the true class stays at zero; spread 1 / (n_classes - 1) over the other classes
            cls_target.masked_scatter_(
                ~one_hot_batch.bool(),
                torch.ones_like(one_hot_batch) / (n_classes - 1))

        l_soft = cls_logits * cls_target
        loss = -l_soft.sum(dim=1).mean()

        return loss

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        kappa = (1 - self.kl_weight if self.scale_adversarial_loss == "auto"
                 else self.scale_adversarial_loss)
        batch_tensor = batch[_CONSTANTS.BATCH_KEY]
        if optimizer_idx == 0:
            loss_kwargs = dict(kl_weight=self.kl_weight)
            inference_outputs, _, scvi_loss = self.forward(
                batch, loss_kwargs=loss_kwargs)
            loss = scvi_loss.loss
            # fool classifier if doing adversarial training
            if kappa > 0 and self.adversarial_classifier is not False:
                z = inference_outputs["z"]
                fool_loss = self.loss_adversarial_classifier(
                    z, batch_tensor, False)
                loss += fool_loss * kappa

            reconstruction_loss = scvi_loss.reconstruction_loss
            self.log("train_loss", loss, on_epoch=True)
            return {
                "loss": loss,
                "reconstruction_loss_sum": reconstruction_loss.sum(),
                "kl_local_sum": scvi_loss.kl_local.sum(),
                "kl_global": scvi_loss.kl_global,
                "n_obs": reconstruction_loss.shape[0],
            }

        # train adversarial classifier
        # only reached when self.adversarial_classifier is not False and a second optimizer exists
        if optimizer_idx == 1:
            inference_inputs = self.module._get_inference_input(batch)
            outputs = self.module.inference(**inference_inputs)
            z = outputs["z"]
            loss = self.loss_adversarial_classifier(z.detach(), batch_tensor,
                                                    True)
            loss *= kappa

            return loss

    def training_epoch_end(self, outputs):
        # with two optimizers, only report the loss signature of the first (model) optimizer
        if self.adversarial_classifier:
            super().training_epoch_end(outputs[0])
        else:
            super().training_epoch_end(outputs)

    def configure_optimizers(self):
        params1 = filter(lambda p: p.requires_grad, self.module.parameters())
        optimizer1 = torch.optim.Adam(params1,
                                      lr=self.lr,
                                      eps=0.01,
                                      weight_decay=self.weight_decay)
        config1 = {"optimizer": optimizer1}
        if self.reduce_lr_on_plateau:
            scheduler1 = ReduceLROnPlateau(
                optimizer1,
                patience=self.lr_patience,
                factor=self.lr_factor,
                threshold=self.lr_threshold,
                min_lr=self.lr_min,
                threshold_mode="abs",
                verbose=True,
            )
            config1.update(
                {
                    "lr_scheduler": scheduler1,
                    "monitor": self.lr_scheduler_metric,
                },
            )

        if self.adversarial_classifier is not False:
            params2 = filter(lambda p: p.requires_grad,
                             self.adversarial_classifier.parameters())
            optimizer2 = torch.optim.Adam(params2,
                                          lr=1e-3,
                                          eps=0.01,
                                          weight_decay=self.weight_decay)
            config2 = {"optimizer": optimizer2}

            # a quirk in PyTorch Lightning requires returning optimizers and schedulers as plain lists
            opts = [config1.pop("optimizer"), config2["optimizer"]]
            if "lr_scheduler" in config1:
                config1["scheduler"] = config1.pop("lr_scheduler")
                scheds = [config1]
                return opts, scheds
            else:
                return opts

        return config1
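
To make the "auto" scaling concrete: the adversarial weight `kappa` used in `training_step` mirrors the KL warmup, so while `kl_weight` ramps from 0 to 1 over the warmup period, `kappa = 1 - kl_weight` decays from 1 to 0 and the adversarial terms fade out. A standalone sketch of that schedule (the linear-warmup arithmetic is an assumption based on the docstring, not code copied from the library):

def kl_weight(step: int, n_warmup_steps: int) -> float:
    """Linear KL warmup: ramps from 0 to 1 over n_warmup_steps, then stays at 1."""
    return min(1.0, step / max(1, n_warmup_steps))

def adversarial_kappa(step: int, n_warmup_steps: int) -> float:
    """'auto' scaling: the adversarial weight starts at 1 and decays to 0 as KL warms up."""
    return 1.0 - kl_weight(step, n_warmup_steps)

for step in (0, 100, 200, 400):
    print(step, kl_weight(step, 400), adversarial_kappa(step, 400))
# 0 -> 0.0 / 1.0, 100 -> 0.25 / 0.75, 200 -> 0.5 / 0.5, 400 -> 1.0 / 0.0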