Пример #1
0
    def __init__(
        self,
        n_input,
        n_batch,
        n_labels,
        n_hidden=128,
        n_latent=10,
        n_layers=1,
        dropout_rate=0.1,
        y_prior=None,
        dispersion="gene",
        log_variational=True,
        gene_likelihood="zinb",
    ):
        """Build the model's encoder, decoder, label prior and classifier.

        The parent initialiser is called first with the same sizes and
        likelihood options and with ``use_observed_lib_size=False``.
        Afterwards this method assigns:

        * ``z_encoder`` -- an ``Encoder`` from ``n_input`` to ``n_latent``
          that receives ``[n_batch, n_labels]`` as ``n_cat_list``.
        * ``decoder`` -- a ``DecoderSCVI`` from ``n_latent`` to ``n_input``
          with the same ``n_cat_list``.
        * ``y_prior`` -- a non-trainable parameter; uniform over the labels
          (shape ``(1, n_labels)``) unless ``y_prior`` is supplied.
        * ``classifier`` -- a ``Classifier`` fed with ``n_input`` features.
        """
        super().__init__(
            n_input,
            n_batch,
            n_labels,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            log_variational=log_variational,
            gene_likelihood=gene_likelihood,
            use_observed_lib_size=False,
        )

        # Encoder and decoder get their own list object each, so neither can
        # observe a mutation made through the other.
        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_hidden=n_hidden,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            n_cat_list=[n_batch, n_labels],
        )
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_hidden=n_hidden,
            n_layers=n_layers,
            n_cat_list=[n_batch, n_labels],
        )

        # Fall back to a uniform prior over the labels when none is given.
        if y_prior is None:
            prior = (1 / n_labels) * torch.ones(1, n_labels)
        else:
            prior = y_prior
        self.y_prior = torch.nn.Parameter(prior, requires_grad=False)

        self.classifier = Classifier(
            n_input,
            n_hidden,
            n_labels,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
        )
Пример #2
0
 def __init__(self, *args, **kwargs):
     """Initialise the plan and configure its adversarial classifier.

     All positional and keyword arguments are forwarded unchanged to the
     parent initialiser. The ``adversarial_classifier`` option is then
     read from ``kwargs`` (it is only seen when passed by keyword):

     * ``True`` -- build a ``Classifier`` with 2 outputs, 3 layers and
       32 hidden units on top of ``self.module.n_latent`` inputs.
     * any other value -- stored as-is on ``self.adversarial_classifier``.
     * omitted -- treated as ``False`` instead of raising ``KeyError``.
     """
     super().__init__(*args, **kwargs)

     # The option is optional for callers; indexing kwargs directly used to
     # raise KeyError when it was left out.
     # NOTE(review): a value passed positionally is not visible here and is
     # treated as omitted -- confirm callers always use the keyword.
     adversarial_classifier = kwargs.get("adversarial_classifier", False)

     if adversarial_classifier is True:
         self.n_output_classifier = 2
         self.adversarial_classifier = Classifier(
             n_input=self.module.n_latent,
             n_hidden=32,
             n_labels=self.n_output_classifier,
             n_layers=3,
             logits=True,
         )
     else:
         self.adversarial_classifier = adversarial_classifier
Пример #3
0
 def __init__(
     self,
     vae_model: BaseModuleClass,
     n_obs_training,
     lr=1e-3,
     weight_decay=1e-6,
     n_steps_kl_warmup: Union[int, None] = None,
     n_epochs_kl_warmup: Union[int, None] = 400,
     reduce_lr_on_plateau: bool = False,
     lr_factor: float = 0.6,
     lr_patience: int = 30,
     lr_threshold: float = 0.0,
     lr_scheduler_metric: Literal[
         "elbo_validation", "reconstruction_loss_validation",
         "kl_local_validation"] = "elbo_validation",
     lr_min: float = 0,
     adversarial_classifier: Union[bool, Classifier] = False,
     scale_adversarial_loss: Union[float, Literal["auto"]] = "auto",
     **loss_kwargs,
 ):
     """Initialise the training plan and its optional adversarial classifier.

     Parameters
     ----------
     vae_model, n_obs_training, lr, weight_decay, n_steps_kl_warmup,
     n_epochs_kl_warmup, reduce_lr_on_plateau, lr_factor, lr_patience,
     lr_threshold, lr_scheduler_metric, lr_min
         Forwarded unchanged, by keyword, to the parent initialiser.
     adversarial_classifier
         ``True`` builds a ``Classifier`` with ``self.module.n_latent``
         inputs, ``self.module.n_batch`` outputs, 32 hidden units and
         2 layers, returning logits. Any other value (``False`` or a
         ready-made classifier) is stored as-is.
     scale_adversarial_loss
         Stored on the instance as ``self.scale_adversarial_loss``.
     **loss_kwargs
         Extra keyword arguments, forwarded to the parent initialiser.
     """
     super().__init__(
         vae_model=vae_model,
         n_obs_training=n_obs_training,
         lr=lr,
         weight_decay=weight_decay,
         n_steps_kl_warmup=n_steps_kl_warmup,
         n_epochs_kl_warmup=n_epochs_kl_warmup,
         reduce_lr_on_plateau=reduce_lr_on_plateau,
         lr_factor=lr_factor,
         lr_patience=lr_patience,
         lr_threshold=lr_threshold,
         lr_scheduler_metric=lr_scheduler_metric,
         lr_min=lr_min,
         # Previously accepted but silently dropped; pass them on so
         # caller-supplied loss options take effect.
         # NOTE(review): assumes the parent initialiser accepts
         # **loss_kwargs -- confirm against the base class.
         **loss_kwargs,
     )

     if adversarial_classifier is True:
         # One output per batch category of the wrapped module.
         self.n_output_classifier = self.module.n_batch
         self.adversarial_classifier = Classifier(
             n_input=self.module.n_latent,
             n_hidden=32,
             n_labels=self.n_output_classifier,
             n_layers=2,
             logits=True,
         )
     else:
         self.adversarial_classifier = adversarial_classifier
     self.scale_adversarial_loss = scale_adversarial_loss
Пример #4
0
    def __init__(
        self,
        adata: AnnData,
        **model_kwargs,
    ):
        """Create the model and its two-class classifier module.

        Parameters
        ----------
        adata
            Data object handed to the parent initialiser.
        **model_kwargs
            Extra keyword arguments forwarded to ``Classifier``.
        """
        # TODO: catch the user warning raised here and emit a logger warning
        # about non-count data instead.
        super().__init__(adata)

        # Classifier over all variables with two output labels, returning
        # logits.
        self.module = Classifier(
            n_input=self.summary_stats["n_vars"],
            n_labels=2,
            logits=True,
            **model_kwargs,
        )
        self._model_summary_string = "Solo model"
        # locals() is captured as-is, so any local variable introduced above
        # this line would also be handed to _get_init_params.
        # NOTE(review): presumably this records the constructor arguments --
        # confirm against the helper.
        self.init_params_ = self._get_init_params(locals())