Example #1
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        n_obs: float = 1.0,
    ):
        # generative_outputs is the dict returned by `generative(...)`
        # assume that `n_obs` is the number of training data points
        p_x_c = generative_outputs["p_x_c"]
        gamma = generative_outputs["gamma"]

        # compute Q
        # take the mean over cells and scale by n_obs (instead of summing over n)
        q_per_cell = torch.sum(gamma * -p_x_c, 1)

        # third term is log prob of prior terms in Q
        theta_log = F.log_softmax(self.theta_logit, dim=-1)
        theta_log_prior = Dirichlet(self.dirichlet_concentration)
        theta_log_prob = -theta_log_prior.log_prob(
            torch.exp(theta_log) + THETA_LOWER_BOUND)
        prior_log_prob = theta_log_prob
        delta_log_prior = Normal(self.delta_log_mean,
                                 self.delta_log_log_scale.exp().sqrt())
        delta_log_prob = torch.masked_select(
            delta_log_prior.log_prob(self.delta_log), (self.rho > 0))
        prior_log_prob += -torch.sum(delta_log_prob)

        loss = (torch.mean(q_per_cell) * n_obs + prior_log_prob) / n_obs

        return LossRecorder(loss, q_per_cell, torch.zeros_like(q_per_cell),
                            prior_log_prob)
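
The prior terms above score the softmax-normalized mixture weights under a Dirichlet. A minimal self-contained sketch of that pattern (the concentration and the value of THETA_LOWER_BOUND here are illustrative assumptions, not taken from the example):

import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet

THETA_LOWER_BOUND = 1e-20  # small constant keeping log_prob finite at the simplex boundary

theta_logit = torch.randn(5, requires_grad=True)      # unconstrained parameters
theta = F.softmax(theta_logit, dim=-1)                # a point on the simplex
prior = Dirichlet(torch.ones(5) * 2.0)                # illustrative concentration
penalty = -prior.log_prob(theta + THETA_LOWER_BOUND)  # negative log prior, added to the loss
penalty.backward()                                    # gradients flow back to theta_logit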
Example #2
    def select_action(self, obs):
        concentration, value = self.forward(obs)

        m = Dirichlet(concentration)

        action = m.sample()
        self.saved_actions.append(SavedAction(m.log_prob(action), value))
        return list(action.cpu().numpy())
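
A minimal standalone sketch of the same pattern (the linear actor head, sizes, and the softplus floor are assumptions for illustration; Dirichlet concentrations must be strictly positive):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Dirichlet

policy = nn.Linear(4, 3)                        # toy actor head: observation -> concentration
obs = torch.randn(1, 4)
concentration = F.softplus(policy(obs)) + 1e-3  # keep concentrations strictly positive
m = Dirichlet(concentration)
action = m.sample()                             # shape (1, 3); each row sums to 1
log_prob = m.log_prob(action)                   # shape (1,); used in the policy-gradient update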
Example #3
    def test_dirichlet_shape(self):
        dist = Dirichlet(torch.tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
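
The asserted shapes follow the general torch.distributions convention: the last dimension of the concentration parameter is the event dimension, and everything before it is batch. A standalone illustration (not part of the test suite):

import torch
from torch.distributions import Dirichlet

dist = Dirichlet(torch.ones(5, 4, 2))  # batch_shape (5, 4), event_shape (2,)
x = dist.sample()                      # shape (5, 4, 2)
print(dist.log_prob(x).shape)          # torch.Size([5, 4]): one log-density per batch entry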
Example #4
    def test_dirichlet_log_prob(self):
        num_samples = 10
        alpha = torch.exp(torch.randn(5))
        dist = Dirichlet(alpha)
        x = dist.sample((num_samples,))
        actual_log_prob = dist.log_prob(x)
        for i in range(num_samples):
            expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
            self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)
Example #5
    def log_prob(self, dist: Dirichlet, samples) -> torch.Tensor:
        return dist.log_prob(samples)
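
This thin wrapper simply defers to the distribution object. As a point of reference (my own example, not from the snippet), Dirichlet(torch.ones(3)).log_prob(torch.tensor([0.2, 0.3, 0.5])) returns a scalar tensor equal to log 2, since the uniform Dirichlet has constant density 2 on the 2-simplex.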
Example #6
    def forward(self, sentences, sentence_length,
                input_conversation_length, target_sentences, decode=False):
        """
        Args:
            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        Return:
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        """
        batch_size = input_conversation_length.size(0)
        num_sentences = sentences.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
        # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
        encoder_outputs, encoder_hidden = self.encoder(sentences,
                                                       sentence_length)

        # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]
        encoder_hidden = encoder_hidden.transpose(
            1, 0).contiguous().view(num_sentences + batch_size, -1)

        # pad and pack encoder_hidden
        start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                                        input_conversation_length[:-1] + 1)), 0)
        # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]
        encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
                                      for s, l in zip(start.data.tolist(),
                                                      input_conversation_length.data.tolist())], 0)

        # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat(
            [encoder_hidden_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)])

        # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_input = encoder_hidden[:, :-1, :]

        # context_outputs: [batch_size, max_len, context_size]
        context_outputs, context_last_hidden = self.context_encoder(encoder_hidden_input,
                                                                    input_conversation_length)
        # flatten outputs
        # context_outputs: [num_sentences, context_size]
        context_outputs = torch.cat([context_outputs[i, :l, :]
                                     for i, l in enumerate(input_conversation_length.data)])

        alpha_prior = self.prior(context_outputs)
        eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))  # unused leftover from the Gaussian reparameterization
        if not decode:
            alpha_posterior = self.posterior(
                context_outputs, encoder_hidden_inference_flat)

            # reparameterized sample from the Dirichlet posterior
            # (Gaussian version: z_sent = mu_posterior + torch.sqrt(var_posterior) * eps)
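            # The CPU round-trip below is presumably a workaround: older PyTorch
            # versions lacked CUDA support for the reparameterized Gamma/Dirichlet
            # sampler.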
            if torch.cuda.is_available():
                alpha_posterior = alpha_posterior.cpu()
            
            dirichlet_dist = Dirichlet(alpha_posterior)
            z_sent = dirichlet_dist.rsample()
            if torch.cuda.is_available():
                z_sent = to_var(z_sent)
                alpha_posterior = to_var(alpha_posterior)

            # these two variables (log_q_zx and log_p_z) are not strictly necessary here
            log_q_zx = dirichlet_dist.log_prob(z_sent.cpu()).sum()
            log_p_z = Dirichlet(alpha_prior.cpu()).log_prob(z_sent.cpu()).sum()
            if torch.cuda.is_available():
                log_q_zx = log_q_zx.cuda()
                log_p_z = log_p_z.cuda()
            # kl_div: [num_sentences]
            kl_div = dirichlet_kl_div(alpha_posterior, alpha_prior)
            kl_div = torch.sum(kl_div)
        else:
            # (Gaussian version: z_sent = mu_prior + torch.sqrt(var_prior) * eps)
            if torch.cuda.is_available():
                alpha_prior = alpha_prior.cpu()
            dirichlet_dist = Dirichlet(alpha_prior)
            z_sent = dirichlet_dist.rsample()
            if torch.cuda.is_available():
                z_sent = z_sent.cuda()
                alpha_prior = alpha_prior.cuda()
            
            kl_div = None
            log_p_z = dirichlet_logpdf(z_sent, alpha_prior).sum()
            log_q_zx = None
        
        self.z_sent = z_sent
        latent_context = torch.cat([context_outputs, z_sent], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1,
                                         self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        # train: [batch_size, seq_len, vocab_size]
        # eval: [batch_size, seq_len]
        if not decode:

            decoder_outputs = self.decoder(target_sentences,
                                           init_h=decoder_init,
                                           decode=decode)

            return decoder_outputs, kl_div, log_p_z, log_q_zx

        else:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)

            return prediction, kl_div, log_p_z, log_q_zx
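
The helpers dirichlet_kl_div and dirichlet_logpdf used above are project-specific and not shown. A minimal sketch of equivalent computations with stock PyTorch (assuming both arguments are concentration tensors of matching shape):

import torch
from torch.distributions import Dirichlet
from torch.distributions.kl import kl_divergence

def dirichlet_kl_div(alpha_posterior: torch.Tensor,
                     alpha_prior: torch.Tensor) -> torch.Tensor:
    # KL(Dir(alpha_posterior) || Dir(alpha_prior)), one value per batch row
    return kl_divergence(Dirichlet(alpha_posterior), Dirichlet(alpha_prior))

def dirichlet_logpdf(z: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # log density of points z on the simplex under Dir(alpha)
    return Dirichlet(alpha).log_prob(z)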
Example #7
class CellTypeModel(AstirModel):
    """Class to perform statistical inference to assign cells to cell types.

    :param dset: the input gene expression dataframe
    :param random_seed: the random seed for parameter initialization, defaults to 1234
    :param dtype: the data type of parameters, should be the same as `dset`, defaults to
        torch.float64
    :param device: the device on which computation runs, defaults to torch.device("cpu")
    """
    def __init__(
            self,
            dset: Optional[SCDataset] = None,
            random_seed: int = 1234,
            dtype: torch.dtype = torch.float64,
            device: torch.device = torch.device("cpu"),
    ) -> None:
        super().__init__(dset, random_seed, dtype, device)

        if dset is not None:
            self._param_init()

    def _param_init(self) -> None:
        """Initializes parameters and design matrices."""
        if self._dset is None:
            raise Exception("the dataset is not provided")
        G = self._dset.get_n_features()
        C = self._dset.get_n_classes()

        self._recog = TypeRecognitionNet(self._dset.get_n_classes(),
                                         self._dset.get_n_features()).to(
                                             self._device, dtype=self._dtype)

        # Establish data
        self._data: Dict[str, torch.Tensor] = {
            # "log_alpha": torch.log(torch.ones(C + 1, dtype=self._dtype) / (C + 1)).to(
            #     self._device
            # ),
            "rho": self._dset.get_marker_mat().to(self._device),
        }

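        # Symmetric Dirichlet prior over the C + 1 mixture proportions;
        # each component gets concentration C + 1.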
        self._alpha_prior = Dirichlet(
            torch.ones(C + 1, dtype=self._dtype).to(self._device) * (C + 1))

        # Initialize mu, log_delta
        delta_init_mean = torch.log(
            torch.log(torch.tensor(3.0, dtype=self._dtype))
        )  # log(log(3)): exponentiating twice recovers the multiplier of 3
        t = torch.distributions.Normal(
            # delta_init_mean.clone().detach().to(self._dtype),
            torch.tensor(0, dtype=self._dtype),
            torch.tensor(0.1, dtype=self._dtype),
        )
        log_delta_init = t.sample((G, C + 1))

        mu_init = torch.log(
            torch.tensor(self._dset.get_mu_init(),
                         dtype=self._dtype)).to(self._device)
        # mu_init = torch.log(self._dset.get_mu()).to(self._device)
        # mu_init = mu_init - (
        #     self._data["rho"] * torch.exp(log_delta_init).to(self._device)
        # ).mean(1)

        mu_init = mu_init.reshape(-1, 1)

        # Create initialization dictionary
        initializations = {
            "mu":
            mu_init,
            "log_sigma":
            torch.log(self._dset.get_sigma()).to(self._device),
            "log_delta":
            log_delta_init,
            "p":
            torch.zeros((G, C + 1), dtype=self._dtype, device=self._device),
            "alpha_logits":
            torch.ones(C + 1, dtype=self._dtype, device=self._device),
        }
        P = self._dset.get_design().shape[1]
        # Add additional columns of mu for anything in the design matrix
        initializations["mu"] = torch.cat(
            [
                initializations["mu"],
                torch.zeros(
                    (G, P - 1), dtype=self._dtype, device=self._device),
            ],
            1,
        )
        # Create trainable variables
        self._variables: Dict[str, torch.Tensor] = {}
        for (n, v) in initializations.items():
            self._variables[n] = Variable(v.clone()).to(self._device)
            self._variables[n].requires_grad = True

    def load_hdf5(self, hdf5_name: str) -> None:
        """Initializes Cell Type Model from a hdf5 file type

        :param hdf5_name: file path
        """
        self._assignment = pd.read_hdf(hdf5_name,
                                       "celltype_model/celltype_assignments")
        with h5py.File(hdf5_name, "r") as f:
            grp = f["celltype_model"]
            param = grp["parameters"]
            self._variables = {
                "mu": torch.tensor(np.array(param["mu"])),
                "log_sigma": torch.tensor(np.array(param["log_sigma"])),
                "log_delta": torch.tensor(np.array(param["log_delta"])),
                "p": torch.tensor(np.array(param["p"])),
                "alpha_logits": torch.tensor(np.array(param["alpha_logits"])),
            }
            self._data = {
                "rho": torch.tensor(np.array(param["rho"])),
            }
            self._losses = torch.tensor(np.array(grp["losses"]["losses"]))

            rec = grp["recog_net"]
            hidden1_W = torch.tensor(np.array(rec["hidden_1.weight"]))
            hidden2_W = torch.tensor(np.array(rec["hidden_2.weight"]))
            state_dict = {
                "hidden_1.weight": hidden1_W,
                "hidden_1.bias": torch.tensor(np.array(rec["hidden_1.bias"])),
                "hidden_2.weight": hidden2_W,
                "hidden_2.bias": torch.tensor(np.array(rec["hidden_2.bias"])),
            }
            state_dict = OrderedDict(state_dict)
            self._recog = TypeRecognitionNet(
                hidden2_W.shape[0] - 1, hidden1_W.shape[1],
                hidden1_W.shape[0]).to(device=self._device, dtype=self._dtype)
            self._recog.load_state_dict(state_dict)
            self._recog.eval()

    def _forward(self, Y: torch.Tensor, X: torch.Tensor,
                 design: torch.Tensor) -> torch.Tensor:
        """One forward pass.

        :param Y: a sample from the dataset
        :param X: normalized sample data
        :param design: the corresponding row of design matrix
        :return: the cost (elbo) of the current pass
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        G = self._dset.get_n_features()
        C = self._dset.get_n_classes()
        N = Y.shape[0]

        Y_spread = Y.view(N, 1, G).repeat(1, C + 1, 1)

        delta_tilde = torch.exp(self._variables["log_delta"])
        mean = delta_tilde * self._data["rho"]
        mean2 = torch.mm(design, self._variables["mu"].T)  # (N x P) @ (P x G) -> N x G
        mean2 = mean2.view(-1, G, 1).repeat(1, 1, C + 1)
        mean = mean + mean2

        # now do the variance modelling
        p = torch.sigmoid(self._variables["p"])

        sigma = torch.exp(self._variables["log_sigma"])
        v1 = (self._data["rho"] * p).T * sigma
        v2 = torch.pow(sigma, 2) * (1 - torch.pow(self._data["rho"] * p, 2)).T

        v1 = v1.view(1, C + 1, G, 1).repeat(N, 1, 1, 1)  # trailing 1 is the rank of cov_factor
        v2 = v2.view(1, C + 1, G).repeat(N, 1, 1) + 1e-6

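        # LowRankMultivariateNormal models covariance as v1 @ v1.T + diag(v2)
        # (rank-1 plus diagonal), avoiding a full G x G covariance per class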
        dist = LowRankMultivariateNormal(loc=torch.exp(mean).permute(0, 2, 1),
                                         cov_factor=v1,
                                         cov_diag=v2)

        log_p_y_on_c = dist.log_prob(Y_spread)

        gamma, log_gamma = self._recog.forward(X)
        log_alpha = F.log_softmax(self._variables["alpha_logits"], dim=0)
        alpha = F.softmax(self._variables["alpha_logits"], dim=0)
        mix_prior = self._alpha_prior.log_prob(alpha)

        elbo = (gamma *
                (log_p_y_on_c + log_alpha - log_gamma)).sum() + mix_prior

        return -elbo

    def fit(
        self,
        max_epochs: int = 50,
        learning_rate: float = 1e-3,
        batch_size: int = 128,
        delta_loss: float = 1e-3,
        delta_loss_batch: int = 10,
        msg: str = "",
    ) -> None:
        """Runs train loops until the convergence reaches delta_loss for
        delta_loss_batch sizes or for max_epochs number of times

        :param max_epochs: number of train loop iterations, defaults to 50
        :param learning_rate: the learning rate, defaults to 0.01
        :param batch_size: the batch size, defaults to 128
        :param delta_loss: stops iteration once the loss rate reaches
            delta_loss, defaults to 0.001
        :param delta_loss_batch: the batch size to consider delta loss,
            defaults to 10
        :param msg: iterator bar message, defaults to empty string
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        # Make dataloader
        dataloader = DataLoader(self._dset,
                                batch_size=min(batch_size, len(self._dset)),
                                shuffle=True)

        # Run training loop
        losses: List[torch.Tensor] = []
        per = torch.tensor(1)

        # Construct optimizer
        opt_params = list(self._variables.values()) + list(
            self._recog.parameters())

        optimizer = torch.optim.Adam(opt_params, lr=learning_rate)

        _, exprs_X, _ = self._dset[:]  # calls dset.__getitem__ on the full slice

        iterator = trange(
            max_epochs,
            desc="training restart" + msg,
            unit="epochs",
            bar_format=
            "{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]",
        )
        for ep in iterator:
            L = None
            loss = torch.tensor(0.0, dtype=self._dtype)
            for batch in dataloader:
                Y, X, design = batch
                optimizer.zero_grad()
                L = self._forward(Y, X, design)
                L.backward()
                optimizer.step()
                with torch.no_grad():
                    loss = loss + L
            if len(losses) > 0:
                per = abs((loss - losses[-1]) / losses[-1])
            losses.append(loss)
            iterator.set_postfix_str("current loss: " +
                                     str(round(float(loss), 1)))

            if per <= delta_loss:
                self._is_converged = True
                iterator.close()
                break

        # Save output
        self._assignment = pd.DataFrame(
            self._recog.forward(exprs_X)[0].detach().cpu().numpy())
        self._assignment.columns = self._dset.get_classes() + ["Other"]
        self._assignment.index = self._dset.get_cell_names()

        if self._losses.shape[0] == 0:
            self._losses = torch.tensor(losses)
        else:
            self._losses = torch.cat((self._losses.view(
                self._losses.shape[0]), torch.tensor(losses)),
                                     dim=0)

    def predict(self, new_dset: SCDataset) -> pd.DataFrame:
        """Feed `new_dset` to the recognition net to get a prediction.

        :param new_dset: the dataset to be predicted
        :return: the resulting cell type assignment probabilities
        """
        _, exprs_X, _ = new_dset[:]
        g = pd.DataFrame(
            self._recog.forward(exprs_X)[0].detach().cpu().numpy())
        return g

    def get_recognet(self) -> TypeRecognitionNet:
        """Getter for the recognition net.

        :return: the trained recognition net
        """
        return self._recog

    def _most_likely_celltype(
        self,
        row: pd.DataFrame,
        threshold: float,
        cell_types: List[str],
        assignment_type: str,
    ) -> str:
        """Given a row of the assignment matrix, return the most likely cell type

        :param row: the row of cell assignment matrix to be evaluated
        :param threshold: the higher bound of the maximun probability to classify a cell as `Unknown`
        :param cell_types: the names of cell types, in the same order as the features of the row
        :param assignment_type: See
        :meth:`astir.CellTypeModel.get_celltypes` for full documentation
        :return: the most likely cell type of this cell
        """
        row = row.values
        max_prob = np.max(row)

        if assignment_type == "threshold":
            if max_prob < threshold:
                return "Unknown"
        elif assignment_type == "max":
            if sum(row == max_prob) > 1:
                return "Unknown"

        return cell_types[np.argmax(row)]

    def get_celltypes(
        self,
        threshold: float = 0.7,
        assignment_type: str = "threshold",
        prob_assign: Optional[pd.DataFrame] = None,
    ) -> pd.DataFrame:
        """Get the most likely cell types. A cell is assigned to a cell type
        if the probability is greater than threshold.
        If no cell types have a probability higher than threshold,
        then "Unknown" is returned.

        :param assignment_type: either 'threshold' or 'max'. If threshold,
            type assignment is based on whether the probability threshold is
            above prob_assignment. If 'max', type assignment is based on the max
            probability value or "unknown" if there are multiple max
            probabilities. Defaults to 'threshold'.
        :param threshold: the probability threshold above which a cell is
            assigned to a cell type, defaults to 0.7
        :return: a data frame with most likely cell types for each
        """
        if prob_assign is None:
            type_probability = self.get_assignment()
        else:
            type_probability = prob_assign

        if assignment_type != "threshold" and assignment_type != "max":
            warnings.warn("Wrong assignment type. Defaults the assignment "
                          "type to threshold.")
            assignment_type = "threshold"

        if assignment_type == "max" and prob_assign is not None:
            warnings.warn("Assignment type is 'max' but probability "
                          "threshold value was passed in. Probability "
                          "threshold value will be ignored.")

        cell_types = list(type_probability.columns)

        cell_type_assignments = type_probability.apply(
            self._most_likely_celltype,
            axis=1,
            assignment_type=assignment_type,
            threshold=threshold,
            cell_types=cell_types,
        )
        cell_type_assignments = pd.DataFrame(cell_type_assignments)
        cell_type_assignments.columns = ["cell_type"]

        return cell_type_assignments

    def _compare_marker_between_types(
        self,
        curr_type: str,
        celltype_to_compare: str,
        marker: str,
        cell_types: List[str],
        alpha: float = 0.05,
    ) -> Optional[dict]:
        """For two cell types and a protein, ensure marker
        is expressed at higher level for curr_type than celltype_to_compare

        :param curr_type: the cell type to assess
        :param celltype_to_compare: all the cell types that shouldn't highly express this marker
        :param marker: the marker protein for curr_type
        :param cell_types: list of cell types assigned for cells
        :param alpha:
        :return:
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        current_marker_ind = np.array(self._dset.get_features()) == marker

        cells_x = np.array(cell_types) == curr_type
        cells_y = np.array(cell_types) == celltype_to_compare

        # x - expression of the marker in cells assigned to curr_type
        # y - expression of the marker in cells assigned to celltype_to_compare
        x = self._dset.get_exprs().detach().cpu().numpy()[cells_x,
                                                          current_marker_ind]
        y = self._dset.get_exprs().detach().cpu().numpy()[cells_y,
                                                          current_marker_ind]

        stat = np.nan
        pval = np.inf
        note: Optional[str] = "Only 1 cell in a type: comparison not possible"

        if len(x) > 1 and len(y) > 1:
            tt = stats.ttest_ind(x, y)
            stat = tt.statistic
            pval = tt.pvalue
            note = None

        if not (stat > 0 and pval < alpha):
            rdict = {
                "current_marker": marker,
                "curr_type": curr_type,
                "celltype_to_compare": celltype_to_compare,
                "mean_A": x.mean(),
                "mean_Y": y.mean(),
                "p-val": pval,
                "note": note,
            }

            return rdict

        return None

    def plot_clustermap(
        self,
        plot_name: str = "celltype_protein_cluster.png",
        threshold: float = 0.7,
        figsize: Tuple[float, float] = (7.0, 5.0),
        prob_assign: Optional[pd.DataFrame] = None,
    ) -> None:
        """Save the heatmap of protein content in cells with cell types labeled.

        :param plot_name: name of the plot; a file extension (e.g. .png or .jpg) is required, defaults to "celltype_protein_cluster.png"
        :param threshold: the probability threshold above which a cell is assigned to a cell type, defaults to 0.7
        :param figsize: the size of the figure, defaults to (7.0, 5.0)
        :param prob_assign: an optional precomputed assignment probability matrix
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        expr_df = self._dset.get_exprs_df()
        scaler = StandardScaler()
        for feature in expr_df.columns:
            expr_df[feature] = scaler.fit_transform(
                expr_df[feature].values.reshape(
                    (expr_df[feature].shape[0], 1)))

        expr_df["cell_type"] = self.get_celltypes(threshold=threshold,
                                                  prob_assign=prob_assign)
        expr_df = expr_df.sort_values(by=["cell_type"])
        types = expr_df.pop("cell_type")
        types_uni = types.unique()

        lut = dict(zip(types_uni, sns.color_palette("BrBG", len(types_uni))))
        col_colors = pd.DataFrame(types.map(lut))
        cm = sns.clustermap(
            expr_df.T,
            xticklabels=False,
            cmap="vlag",
            col_cluster=False,
            col_colors=col_colors,
            figsize=figsize,
        )

        for t in types_uni:
            cm.ax_col_dendrogram.bar(0, 0, color=lut[t], label=t, linewidth=0)
        cm.ax_col_dendrogram.legend(title="Cell Types",
                                    loc="center",
                                    ncol=3,
                                    bbox_to_anchor=(0.8, 0.8))
        cm.savefig(plot_name, dpi=150)

    def diagnostics(self, cell_type_assignments: list,
                    alpha: float) -> pd.DataFrame:
        """Run diagnostics on cell type assignments

        See :meth:`astir.Astir.diagnostics_celltype` for full documentation
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        problems = []

        # Want to construct a data frame that models rho with
        # cell type names on the columns and feature names on the rows
        g_df = pd.DataFrame(self._data["rho"].detach().cpu().numpy())
        g_df.columns = self._dset.get_classes() + ["Other"]
        g_df.index = self._dset.get_features()

        for curr_type in self._dset.get_classes():
            if curr_type not in cell_type_assignments:
                continue

            current_markers = g_df.index[g_df[curr_type] == 1]

            for current_marker in current_markers:
                # find all the cell types that shouldn't highly express this marker
                celltypes_to_compare = g_df.columns[g_df.loc[current_marker] == 0]

                for celltype_to_compare in celltypes_to_compare:
                    if celltype_to_compare not in cell_type_assignments:
                        continue

                    is_problem = self._compare_marker_between_types(
                        curr_type,
                        celltype_to_compare,
                        current_marker,
                        cell_type_assignments,
                        alpha,
                    )

                    if is_problem is not None:
                        problems.append(is_problem)

        col_names = [
            "feature",
            "should be expressed higher in",
            "than",
            "mean cell type 1",
            "mean cell type 2",
            "p-value",
            "note",
        ]
        df_issues = None
        if len(problems) > 0:
            df_issues = pd.DataFrame(problems)
            df_issues.columns = col_names
        else:
            df_issues = pd.DataFrame(columns=col_names)

        return df_issues