def __init__(
    self,
    module: BaseModuleClass,
    n_obs_training,
    lr=1e-3,
    weight_decay=1e-6,
    n_steps_kl_warmup: Union[int, None] = None,
    n_epochs_kl_warmup: Union[int, None] = 400,
    reduce_lr_on_plateau: bool = False,
    lr_factor: float = 0.6,
    lr_patience: int = 30,
    lr_threshold: float = 0.0,
    lr_scheduler_metric: Literal[
        "elbo_validation",
        "reconstruction_loss_validation",
        "kl_local_validation",
    ] = "elbo_validation",
    lr_min: float = 0,
    adversarial_classifier: Union[bool, Classifier] = False,
    scale_adversarial_loss: Union[float, Literal["auto"]] = "auto",
    **loss_kwargs,
):
    super().__init__(
        module=module,
        n_obs_training=n_obs_training,
        lr=lr,
        weight_decay=weight_decay,
        n_steps_kl_warmup=n_steps_kl_warmup,
        n_epochs_kl_warmup=n_epochs_kl_warmup,
        reduce_lr_on_plateau=reduce_lr_on_plateau,
        lr_factor=lr_factor,
        lr_patience=lr_patience,
        lr_threshold=lr_threshold,
        lr_scheduler_metric=lr_scheduler_metric,
        lr_min=lr_min,
    )
    if adversarial_classifier is True:
        self.n_output_classifier = self.module.n_batch
        self.adversarial_classifier = Classifier(
            n_input=self.module.n_latent,
            n_hidden=32,
            n_labels=self.n_output_classifier,
            n_layers=2,
            logits=True,
        )
    else:
        self.adversarial_classifier = adversarial_classifier
    self.scale_adversarial_loss = scale_adversarial_loss
def __init__(
    self,
    n_input,
    n_batch,
    n_labels,
    n_hidden=128,
    n_latent=10,
    n_layers=1,
    dropout_rate=0.1,
    y_prior=None,
    dispersion="gene",
    log_variational=True,
    gene_likelihood="zinb",
):
    super().__init__(
        n_input,
        n_batch,
        n_labels,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        dispersion=dispersion,
        log_variational=log_variational,
        gene_likelihood=gene_likelihood,
        use_observed_lib_size=False,
    )
    # Encoder and decoder are conditioned on both the batch and label covariates.
    self.z_encoder = Encoder(
        n_input,
        n_latent,
        n_cat_list=[n_batch, n_labels],
        n_hidden=n_hidden,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
    )
    self.decoder = DecoderSCVI(
        n_latent,
        n_input,
        n_cat_list=[n_batch, n_labels],
        n_layers=n_layers,
        n_hidden=n_hidden,
    )
    # Uniform prior over labels unless a custom prior is provided; not optimized.
    self.y_prior = torch.nn.Parameter(
        y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels),
        requires_grad=False,
    )
    self.classifier = Classifier(
        n_input,
        n_hidden,
        n_labels,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
    )
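The `y_prior` buffer above defaults to a uniform categorical prior over the labels and is frozen so it never receives gradient updates. A minimal sketch of what that default evaluates to, in plain PyTorch with a made-up `n_labels` (for illustration only, not part of the module):

import torch

n_labels = 3  # hypothetical number of labels

# Same default as in the constructor: uniform mass over the label classes,
# stored as a non-trainable parameter.
y_prior = torch.nn.Parameter(
    (1 / n_labels) * torch.ones(1, n_labels),
    requires_grad=False,
)

print(y_prior)                # roughly [[0.3333, 0.3333, 0.3333]]
print(y_prior.sum())          # sums to 1 up to float precision
print(y_prior.requires_grad)  # False -- the prior stays fixed during training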
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    if kwargs["adversarial_classifier"] is True:
        # Fixed two-class adversarial classifier on the latent space.
        self.n_output_classifier = 2
        self.adversarial_classifier = Classifier(
            n_input=self.module.n_latent,
            n_hidden=32,
            n_labels=self.n_output_classifier,
            n_layers=3,
            logits=True,
        )
    else:
        self.adversarial_classifier = kwargs["adversarial_classifier"]
def __init__(
    self,
    adata: AnnData,
    **classifier_kwargs,
):
    # TODO, catch user warning here and logger warning
    # about non count data
    super().__init__(adata)
    self.module = Classifier(
        n_input=self.summary_stats.n_vars,
        n_labels=2,
        logits=True,
        **classifier_kwargs,
    )
    self._model_summary_string = "Solo model"
    self.init_params_ = self._get_init_params(locals())
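The Solo model above is just a two-class `Classifier` over the genes with `logits=True`, so turning its raw outputs into class probabilities (for Solo, doublet vs. singlet) is a plain softmax. The sketch below shows only that post-processing step with a stand-in logits tensor; the values and shapes are illustrative assumptions, not outputs of the real module:

import torch

# Stand-in for what a logits=True classifier would return:
# one row of unnormalized scores per cell, one column per class.
logits = torch.tensor([
    [ 2.1, -0.5],
    [-1.3,  0.9],
    [ 0.2,  0.1],
])

# Because the module emits logits rather than probabilities,
# class probabilities are recovered with a softmax over the class dimension.
probs = torch.nn.functional.softmax(logits, dim=-1)
predicted_class = probs.argmax(dim=-1)  # 0 or 1 per cell

print(probs)
print(predicted_class)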
class AdversarialTrainingPlan(TrainingPlan):
    """
    Train VAEs with an adversarial loss option to encourage latent space mixing.

    Parameters
    ----------
    module
        A module instance from class ``BaseModuleClass``.
    lr
        Learning rate used for optimization with :class:`~torch.optim.Adam`.
    weight_decay
        Weight decay used in :class:`~torch.optim.Adam`.
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
        Only activated when `n_epochs_kl_warmup` is set to None.
    n_epochs_kl_warmup
        Number of epochs to scale weight on KL divergences from 0 to 1.
        Overrides `n_steps_kl_warmup` when both are not `None`.
    reduce_lr_on_plateau
        Whether to monitor validation loss and reduce the learning rate when the
        validation set `lr_scheduler_metric` plateaus.
    lr_factor
        Factor by which to reduce the learning rate.
    lr_patience
        Number of epochs with no improvement after which the learning rate will be reduced.
    lr_threshold
        Threshold for measuring the new optimum.
    lr_scheduler_metric
        Which metric to track for learning rate reduction.
    lr_min
        Minimum learning rate allowed.
    adversarial_classifier
        Whether to use an adversarial classifier in the latent space.
    scale_adversarial_loss
        Scaling factor on the adversarial components of the loss.
        By default, the adversarial loss is scaled from 1 to 0 following the opposite
        of the KL warmup schedule.
    **loss_kwargs
        Keyword args to pass to the loss method of the `module`.
        `kl_weight` should not be passed here and is handled automatically.
    """

    def __init__(
        self,
        module: BaseModuleClass,
        lr=1e-3,
        weight_decay=1e-6,
        n_steps_kl_warmup: Union[int, None] = None,
        n_epochs_kl_warmup: Union[int, None] = 400,
        reduce_lr_on_plateau: bool = False,
        lr_factor: float = 0.6,
        lr_patience: int = 30,
        lr_threshold: float = 0.0,
        lr_scheduler_metric: Literal[
            "elbo_validation",
            "reconstruction_loss_validation",
            "kl_local_validation",
        ] = "elbo_validation",
        lr_min: float = 0,
        adversarial_classifier: Union[bool, Classifier] = False,
        scale_adversarial_loss: Union[float, Literal["auto"]] = "auto",
        **loss_kwargs,
    ):
        super().__init__(
            module=module,
            lr=lr,
            weight_decay=weight_decay,
            n_steps_kl_warmup=n_steps_kl_warmup,
            n_epochs_kl_warmup=n_epochs_kl_warmup,
            reduce_lr_on_plateau=reduce_lr_on_plateau,
            lr_factor=lr_factor,
            lr_patience=lr_patience,
            lr_threshold=lr_threshold,
            lr_scheduler_metric=lr_scheduler_metric,
            lr_min=lr_min,
        )
        if adversarial_classifier is True:
            # Default adversarial classifier: predict the batch from the latent code.
            self.n_output_classifier = self.module.n_batch
            self.adversarial_classifier = Classifier(
                n_input=self.module.n_latent,
                n_hidden=32,
                n_labels=self.n_output_classifier,
                n_layers=2,
                logits=True,
            )
        else:
            self.adversarial_classifier = adversarial_classifier
        self.scale_adversarial_loss = scale_adversarial_loss

    def loss_adversarial_classifier(self, z, batch_index, predict_true_class=True):
        n_classes = self.n_output_classifier
        cls_logits = torch.nn.LogSoftmax(dim=1)(self.adversarial_classifier(z))

        if predict_true_class:
            cls_target = one_hot(batch_index, n_classes)
        else:
            one_hot_batch = one_hot(batch_index, n_classes)
            cls_target = torch.zeros_like(one_hot_batch)
            # keep zeros at the true class and spread uniform mass over the other classes
            cls_target.masked_scatter_(
                ~one_hot_batch.bool(), torch.ones_like(one_hot_batch) / (n_classes - 1)
            )

        l_soft = cls_logits * cls_target
        loss = -l_soft.sum(dim=1).mean()

        return loss

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        kappa = (
            1 - self.kl_weight
            if self.scale_adversarial_loss == "auto"
            else self.scale_adversarial_loss
        )
        batch_tensor = batch[_CONSTANTS.BATCH_KEY]
        if optimizer_idx == 0:
            loss_kwargs = dict(kl_weight=self.kl_weight)
            inference_outputs, _, scvi_loss = self.forward(batch, loss_kwargs=loss_kwargs)
            loss = scvi_loss.loss
            # fool classifier if doing adversarial training
            if kappa > 0 and self.adversarial_classifier is not False:
                z = inference_outputs["z"]
                fool_loss = self.loss_adversarial_classifier(z, batch_tensor, False)
                loss += fool_loss * kappa

            reconstruction_loss = scvi_loss.reconstruction_loss
            self.log("train_loss", loss, on_epoch=True)
            return {
                "loss": loss,
                "reconstruction_loss_sum": reconstruction_loss.sum(),
                "kl_local_sum": scvi_loss.kl_local.sum(),
                "kl_global": scvi_loss.kl_global,
                "n_obs": reconstruction_loss.shape[0],
            }

        # train adversarial classifier
        # this condition will not be met unless self.adversarial_classifier is not False
        if optimizer_idx == 1:
            inference_inputs = self.module._get_inference_input(batch)
            outputs = self.module.inference(**inference_inputs)
            z = outputs["z"]
            loss = self.loss_adversarial_classifier(z.detach(), batch_tensor, True)
            loss *= kappa

            return loss

    def training_epoch_end(self, outputs):
        # only report the loss signature from the first optimizer
        if self.adversarial_classifier:
            super().training_epoch_end(outputs[0])
        else:
            super().training_epoch_end(outputs)

    def configure_optimizers(self):
        params1 = filter(lambda p: p.requires_grad, self.module.parameters())
        optimizer1 = torch.optim.Adam(
            params1, lr=self.lr, eps=0.01, weight_decay=self.weight_decay
        )
        config1 = {"optimizer": optimizer1}
        if self.reduce_lr_on_plateau:
            scheduler1 = ReduceLROnPlateau(
                optimizer1,
                patience=self.lr_patience,
                factor=self.lr_factor,
                threshold=self.lr_threshold,
                min_lr=self.lr_min,
                threshold_mode="abs",
                verbose=True,
            )
            config1.update(
                {
                    "lr_scheduler": scheduler1,
                    "monitor": self.lr_scheduler_metric,
                },
            )

        if self.adversarial_classifier is not False:
            params2 = filter(
                lambda p: p.requires_grad, self.adversarial_classifier.parameters()
            )
            optimizer2 = torch.optim.Adam(
                params2, lr=1e-3, eps=0.01, weight_decay=self.weight_decay
            )
            config2 = {"optimizer": optimizer2}
            # bug in pytorch lightning requires this way to return
            opts = [config1.pop("optimizer"), config2["optimizer"]]
            if "lr_scheduler" in config1:
                config1["scheduler"] = config1.pop("lr_scheduler")
                scheds = [config1]
                return opts, scheds
            else:
                return opts

        return config1
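To make the "fool the classifier" objective in `loss_adversarial_classifier` concrete, the sketch below reproduces just the target construction for `predict_true_class=False`: the true batch gets zero probability mass and the remainder is spread uniformly over the other batches. It uses a hand-rolled one-hot helper in place of scvi's, and the batch indices are made up for illustration:

import torch


def one_hot_batches(batch_index: torch.Tensor, n_classes: int) -> torch.Tensor:
    # Simple stand-in for the one_hot helper used in the training plan.
    return torch.nn.functional.one_hot(
        batch_index.squeeze(-1).long(), n_classes
    ).float()


n_classes = 3
batch_index = torch.tensor([[0], [2], [1]])  # hypothetical batch labels for 3 cells

one_hot_b = one_hot_batches(batch_index, n_classes)
cls_target = torch.zeros_like(one_hot_b)
# Zero probability on the true batch, uniform mass on every other batch.
cls_target.masked_scatter_(
    ~one_hot_b.bool(), torch.ones_like(one_hot_b) / (n_classes - 1)
)

print(cls_target)
# tensor([[0.0000, 0.5000, 0.5000],
#         [0.5000, 0.5000, 0.0000],
#         [0.5000, 0.0000, 0.5000]])

Minimizing the cross-entropy against this target pushes the classifier's predicted distribution away from the true batch, which is the signal added (scaled by `kappa`) to the generative loss in `training_step` when `optimizer_idx == 0`.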