Exemplo n.º 1
0
    def __init__(self, n_input: int, n_batch: int = 0, n_labels: int = 0,
                 n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1,
                 dropout_rate: float = 0.1, dispersion: str = "gene",
                 log_variational: bool = True, reconstruction_loss: str = "zinb",
                 y_prior=None, labels_groups: Sequence[int] = None, use_labels_groups: bool = False,
                 classifier_parameters: dict = None):
        """Build the model: parent VAE, label classifier, z1 <-> z2 networks and label prior.

        :param n_input: number of input features, forwarded to the parent constructor.
        :param n_batch: number of batches, forwarded to the parent constructor.
        :param n_labels: number of label categories. Sizes the classifier output, the
            categorical covariate of ``encoder_z2_z1`` / ``decoder_z1_z2`` and the
            default uniform ``y_prior``.
        :param n_hidden: hidden width (parent, classifier default, z1 <-> z2 networks).
        :param n_latent: latent dimensionality (also the classifier input size).
        :param n_layers: number of hidden layers (parent, classifier default, z1 <-> z2 networks).
        :param dropout_rate: dropout rate (parent, classifier default, ``encoder_z2_z1``).
        :param dispersion: forwarded to the parent constructor.
        :param log_variational: forwarded to the parent constructor.
        :param reconstruction_loss: forwarded to the parent constructor.
        :param y_prior: optional tensor stored as a frozen ``y_prior`` parameter.
            Defaults to a uniform prior of shape ``(1, n_labels)``.
        :param labels_groups: optional group id per entry; when used, the ids must be
            exactly ``0 .. n_groups - 1``.
        :param use_labels_groups: if True, also build a group classifier and one frozen
            membership mask per group.
        :param classifier_parameters: optional overrides for the keyword arguments of the
            label ``Classifier``. Defaults are ``n_layers``, ``n_hidden`` and ``dropout_rate``.
        :raises ValueError: if no ``y_prior`` is given and ``n_labels`` is not positive;
            if ``use_labels_groups`` is set without ``labels_groups``; or if the group
            ids are not exactly ``0 .. n_groups - 1``.
        """
        super().__init__(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                         dropout_rate=dropout_rate, n_batch=n_batch, dispersion=dispersion,
                         log_variational=log_variational, reconstruction_loss=reconstruction_loss)

        self.n_labels = n_labels
        # Classifier takes n_latent as input; caller overrides win over the shared defaults.
        # A None default avoids sharing one mutable dict across calls.
        cls_parameters = {"n_layers": n_layers, "n_hidden": n_hidden, "dropout_rate": dropout_rate}
        if classifier_parameters:
            cls_parameters.update(classifier_parameters)
        self.classifier = Classifier(n_latent, n_labels=n_labels, **cls_parameters)

        # Latent-to-latent networks conditioned on the label as a categorical covariate.
        self.encoder_z2_z1 = Encoder(n_latent, n_latent, n_cat_list=[self.n_labels], n_layers=n_layers,
                                     n_hidden=n_hidden, dropout_rate=dropout_rate)
        self.decoder_z1_z2 = Decoder(n_latent, n_latent, n_cat_list=[self.n_labels], n_layers=n_layers,
                                     n_hidden=n_hidden)

        if y_prior is None:
            # The uniform default divides by n_labels. Fail with a clear message instead
            # of a bare ZeroDivisionError (n_labels defaults to 0).
            if n_labels <= 0:
                raise ValueError("n_labels must be positive when no y_prior is given")
            y_prior = (1 / n_labels) * torch.ones(1, n_labels)
        # Stored as a non-trainable Parameter so it follows the module across devices.
        self.y_prior = torch.nn.Parameter(y_prior, requires_grad=False)

        self.use_labels_groups = use_labels_groups
        self.labels_groups = np.array(labels_groups) if labels_groups is not None else None
        if self.use_labels_groups:
            # Explicit raises rather than asserts: asserts are stripped under `python -O`.
            if labels_groups is None:
                raise ValueError("Specify label groups")
            unique_groups = np.unique(self.labels_groups)
            self.n_groups = len(unique_groups)
            if not (unique_groups == np.arange(self.n_groups)).all():
                raise ValueError(
                    "labels_groups must contain exactly the ids 0..n_groups-1, got {}".format(unique_groups)
                )
            self.classifier_groups = Classifier(n_latent, n_hidden, self.n_groups, n_layers, dropout_rate)
            # groups_index[i] is a frozen uint8 mask over the entries of labels_groups,
            # set to 1 where the entry belongs to group i.
            self.groups_index = torch.nn.ParameterList([torch.nn.Parameter(
                torch.tensor((self.labels_groups == i).astype(np.uint8), dtype=torch.uint8), requires_grad=False
            ) for i in range(self.n_groups)])
Exemplo n.º 2
0
    def __init__(self,
                 n_input,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 n_batch=0,
                 y_prior=None,
                 use_cuda=False):
        """Build the semi-supervised VAE: encoders, decoder, label classifier and z1 <-> z2 networks.

        :param n_input: number of input features.
        :param n_labels: number of label categories.
        :param n_hidden: hidden width of the encoders, decoder and classifier.
        :param n_latent: latent dimensionality.
        :param n_layers: number of hidden layers.
        :param dropout_rate: dropout rate for the encoders, decoder and classifier.
        :param n_batch: number of batches. A value of 1 is treated as 0, i.e. the batch
            covariate is switched off.
        :param y_prior: optional label prior tensor. Defaults to uniform over ``n_labels``.
        :param use_cuda: if True and CUDA is available, move the module and ``y_prior`` to GPU.
        """
        super(SVAEC, self).__init__()
        self.n_labels = n_labels
        self.n_input = n_input

        # Uniform label prior unless one is supplied. This is a plain tensor attribute
        # (not a registered Parameter/buffer), hence the explicit .cuda() call below.
        self.y_prior = y_prior if y_prior is not None else (
            1 / self.n_labels) * torch.ones(self.n_labels)
        # Automatically deactivate the batch covariate if useless (a single batch).
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.z_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=n_latent,
                                 n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        # One-dimensional latent encoder (presumably the library-size latent; confirm).
        self.l_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=1,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)
        # Pass the normalized batch count so a single batch also disables the
        # decoder's batch covariate, as the comment above intends.
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_hidden=n_hidden,
                                   n_layers=n_layers,
                                   dropout_rate=dropout_rate,
                                   n_batch=self.n_batch)

        self.dispersion = 'gene'
        # One randomly initialized trainable value per input feature.
        self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

        # Classifier takes n_latent as input
        self.classifier = Classifier(n_latent, n_hidden, self.n_labels,
                                     n_layers, dropout_rate)
        # Latent-to-latent networks conditioned on the label category.
        self.encoder_z2_z1 = Encoder(n_input=n_latent,
                                     n_cat=self.n_labels,
                                     n_latent=n_latent,
                                     n_layers=n_layers)
        self.decoder_z1_z2 = Decoder(n_latent,
                                     n_latent,
                                     n_cat=self.n_labels,
                                     n_layers=n_layers)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
            # y_prior is not registered on the module, so self.cuda() does not move it.
            self.y_prior = self.y_prior.cuda()
Exemplo n.º 3
0
    def __init__(self,
                 n_input,
                 n_batch,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 y_prior=None,
                 logreg_classifier=False,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb"):
        """Set up SVAEC: the parent VAE plus a label classifier and z1 <-> z2 networks.

        All architecture and likelihood settings are forwarded to the parent
        constructor. On top of it this builds:

        * ``classifier``: a ``LinearLogRegClassifier`` when ``logreg_classifier``
          is true, otherwise an MLP ``Classifier``. Both take ``n_latent`` inputs.
        * ``encoder_z2_z1`` / ``decoder_z1_z2``: latent-to-latent networks
          conditioned on the label as a categorical covariate.
        * ``y_prior``: a frozen (``requires_grad=False``) parameter holding the label
          prior. It is uniform with shape ``(n_labels,)`` unless ``y_prior`` is given.
        """
        super(SVAEC, self).__init__(
            n_input,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            n_batch=n_batch,
            dispersion=dispersion,
            log_variational=log_variational,
            reconstruction_loss=reconstruction_loss,
        )

        self.n_labels = n_labels
        self.n_latent_layers = 2

        # The label classifier reads the n_latent-dimensional code.
        self.classifier = (
            LinearLogRegClassifier(n_latent, self.n_labels)
            if logreg_classifier
            else Classifier(n_latent, n_hidden, self.n_labels, n_layers, dropout_rate)
        )

        # Both latent-to-latent networks share the same layer hyper-parameters.
        z_net_kwargs = dict(n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate)
        self.encoder_z2_z1 = Encoder(n_latent, n_latent, n_cat_list=[self.n_labels], **z_net_kwargs)
        self.decoder_z1_z2 = Decoder(n_latent, n_latent, n_cat_list=[self.n_labels], **z_net_kwargs)

        if y_prior is None:
            y_prior = (1 / self.n_labels) * torch.ones(self.n_labels)
        self.y_prior = torch.nn.Parameter(y_prior, requires_grad=False)
Exemplo n.º 4
0
    def __init__(self,
                 n_input,
                 n_batch,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 y_prior=None,
                 logreg_classifier=False,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 labels_groups=None,
                 use_labels_groups=False):
        """Build SVAEC: parent VAE, label classifier, z1 <-> z2 networks, label prior and optional label groups.

        :param n_input: number of input features, forwarded to the parent constructor.
        :param n_batch: number of batches, forwarded to the parent constructor.
        :param n_labels: number of label categories.
        :param n_hidden: hidden width (parent, classifiers, z1 <-> z2 networks).
        :param n_latent: latent dimensionality (also the classifiers' input size).
        :param n_layers: number of hidden layers.
        :param dropout_rate: dropout rate.
        :param y_prior: optional label prior tensor, stored as a frozen parameter.
            Defaults to a uniform prior of shape ``(1, n_labels)``.
        :param logreg_classifier: use ``LinearLogRegClassifier`` instead of the MLP ``Classifier``.
        :param dispersion: forwarded to the parent constructor.
        :param log_variational: forwarded to the parent constructor.
        :param reconstruction_loss: forwarded to the parent constructor.
        :param labels_groups: optional group id per entry; when used, the ids must be
            exactly ``0 .. n_groups - 1``.
        :param use_labels_groups: if True, also build a group classifier and one frozen
            membership mask per group.
        :raises ValueError: if ``use_labels_groups`` is set without ``labels_groups``, or
            if the group ids are not exactly ``0 .. n_groups - 1``.
        """
        super(SVAEC, self).__init__(n_input,
                                    n_hidden=n_hidden,
                                    n_latent=n_latent,
                                    n_layers=n_layers,
                                    dropout_rate=dropout_rate,
                                    n_batch=n_batch,
                                    dispersion=dispersion,
                                    log_variational=log_variational,
                                    reconstruction_loss=reconstruction_loss)

        self.n_labels = n_labels
        self.n_latent_layers = 2
        # Classifier takes n_latent as input
        if logreg_classifier:
            self.classifier = LinearLogRegClassifier(n_latent, self.n_labels)
        else:
            self.classifier = Classifier(n_latent, n_hidden, self.n_labels,
                                         n_layers, dropout_rate)

        # Latent-to-latent networks conditioned on the label as a categorical covariate.
        self.encoder_z2_z1 = Encoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)
        self.decoder_z1_z2 = Decoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)

        # Frozen label prior (uniform unless supplied).
        self.y_prior = torch.nn.Parameter(y_prior if y_prior is not None else
                                          (1 / n_labels) *
                                          torch.ones(1, n_labels),
                                          requires_grad=False)
        self.use_labels_groups = use_labels_groups
        self.labels_groups = np.array(
            labels_groups) if labels_groups is not None else None
        if self.use_labels_groups:
            # Explicit raises rather than asserts: asserts are stripped under `python -O`,
            # which would let invalid group ids silently produce wrong masks below.
            if labels_groups is None:
                raise ValueError("Specify label groups")
            unique_groups = np.unique(self.labels_groups)
            self.n_groups = len(unique_groups)
            if not (unique_groups == np.arange(self.n_groups)).all():
                raise ValueError(
                    "labels_groups must contain exactly the ids 0..n_groups-1, "
                    "got {}".format(unique_groups))
            self.classifier_groups = Classifier(n_latent, n_hidden,
                                                self.n_groups, n_layers,
                                                dropout_rate)
            # groups_index[i] is a frozen uint8 mask over the entries of
            # labels_groups, set to 1 where the entry belongs to group i.
            self.groups_index = torch.nn.ParameterList([
                torch.nn.Parameter(torch.tensor(
                    (self.labels_groups == i).astype(np.uint8),
                    dtype=torch.uint8),
                                   requires_grad=False)
                for i in range(self.n_groups)
            ])