Exemplo n.º 1
0
    def generate_cells(self,
                       cells_no,
                       checkpoint=None,
                       sess=None,
                       save_path=None):
        """
        Method that generates cells from the current model.

        Parameters
        ----------
        cells_no : int
            Number of cells to be generated.
        checkpoint : str / None
            Path to the checkpoint directory from which to load the model
            (it is passed to ``tf.train.latest_checkpoint``).
            If None, uses the current model loaded in the session.
            Default is None.
        sess : Session
            The TF Session in use.
            If None, a Session is created for this call and closed before
            the method returns.
            Default is None.
        save_path : str
            Path in which to write the generated cells.
            If None, the cells are only returned and not written.
            Default is None.

        Returns
        -------
        fake_cells : Numpy array
            2-D float32 array with the gene expression matrix of the generated
            cells, truncated to the first ``cells_no`` rows and passed through
            ``rescale`` with this model's scaling settings.
        """

        own_session = sess is None
        if own_session:
            sess = tf.Session()

        try:
            if checkpoint is not None:
                saver = tf.train.Saver()
                saver.restore(sess, tf.train.latest_checkpoint(checkpoint))

            # Each sess.run yields one batch; round up so a trailing partial
            # batch is also produced (the surplus rows are trimmed below).
            batches_no = int(np.ceil(cells_no / self.batch_size))
            fake_cells_tensor = self.generator.fake_outputs
            eval_feed_dict = {self.generator.is_training: False}

            fake_cells = [sess.run([fake_cells_tensor], feed_dict=eval_feed_dict)
                          for _ in range(batches_no)]
        finally:
            # Only close a session this call created; a caller-supplied
            # session stays open for further use.
            if own_session:
                sess.close()

        # Flatten the (batch, 1, rows, genes) nesting into a 2-D matrix.
        fake_cells = np.array(fake_cells, dtype=np.float32)
        fake_cells = fake_cells.reshape((-1, fake_cells_tensor.shape[1]))

        fake_cells = fake_cells[0:cells_no]

        # Keep the returned array: rescale's result is what callers
        # (e.g. read_valid_cells) rely on.
        fake_cells = rescale(fake_cells,
                             scaling=self.scaling,
                             scale_value=self.scale_value)

        if save_path is not None:
            save_generated_cells(fake_cells, save_path)

        return fake_cells
Exemplo n.º 2
0
    def read_valid_cells(self, sess, cells_no):
        """
        Method that reads a given number of cells from the validation set.

        Parameters
        ----------
        sess : Session
            The TF Session in use.
        cells_no : int
            Number of validation cells to read.

        Returns
        -------
        real_cells : numpy array
            2-D float32 matrix with (at most) ``cells_no`` validation cells,
            passed through ``rescale`` with this model's scaling settings.
        """

        # True division before ceil so a trailing partial batch is still read;
        # floor division here would make np.ceil a no-op and drop it.
        batches_no = int(np.ceil(cells_no / self.batch_size))
        real_cells = [sess.run([self.test_cells]) for _ in range(batches_no)]

        # Flatten the (batch, 1, rows, genes) nesting into a 2-D matrix.
        real_cells = np.array(real_cells, dtype=np.float32)
        real_cells = real_cells.reshape((-1, self.test_cells.shape[1]))

        # The last batch may overshoot; keep only the requested number of rows.
        real_cells = real_cells[:cells_no]

        real_cells = rescale(real_cells,
                             scaling=self.scaling,
                             scale_value=self.scale_value)

        return real_cells
Exemplo n.º 3
0
    def generate_cells(self,
                       cells_no,
                       clusters_ratios=None,
                       sess=None,
                       save_path=None,
                       checkpoint=None):
        """
        Method that generates cells from the current model.

        Parameters
        ----------
        cells_no : int or list
            Numbers of cells to be generated.
            If clusters_ratios is provided, should be an int (total number of
            cells).
            If clusters_ratios is None, should be a sequence giving the number
            of cells to generate for each cluster index (zeros are skipped).
        clusters_ratios : numpy array
            Cluster ratios fed to the generator for the conditional
            generation.
            Default is None.
        sess : Session
            The TF Session in use.
            If None, a Session is created for this call and closed before
            the method returns.
            Default is None.
        save_path : str
            Path in which to write the generated cells.
            If None, the cells are only returned and not written.
            Default is None.
        checkpoint : str /None
            Path to the checkpoint directory from which to load the model
            (it is passed to ``tf.train.latest_checkpoint``).
            If None, uses the current model loaded in the session.
            Default is None.

        Returns
        -------
        fake_cells : Numpy array
            2-D Array with the gene expression matrix of the generated cells,
            rescaled exactly once with ``rescale``.
        fake_labels : Numpy array
            1-D array containing the cluster index of the generated cells.

        Raises
        ------
        TypeError
            If clusters_ratios is None and cells_no is a scalar.
        """

        own_session = sess is None
        if own_session:
            sess = tf.Session()

        try:
            if checkpoint is not None:
                saver = tf.train.Saver()
                saver.restore(sess, tf.train.latest_checkpoint(checkpoint))

            cell_parts = [np.empty((0, self.genes_no), dtype=np.float32)]
            label_parts = [np.empty(0, dtype=np.int32)]

            if clusters_ratios is None:
                if np.isscalar(cells_no):
                    raise TypeError(
                        "cells_no must be a sequence of per-cluster cell "
                        "counts when clusters_ratios is None")

                for cluster, cells_per_cluster in enumerate(cells_no):
                    if int(cells_per_cluster) == 0:
                        continue
                    # One-hot ratios: generate cells of this cluster only.
                    one_hot = np.zeros((1, self.clusters_no), dtype=np.float64)
                    one_hot[0][cluster] = 1

                    # The model is already restored in `sess`, and the
                    # recursive call rescales its own output, so neither the
                    # checkpoint restore nor the rescaling is repeated here.
                    fc, fl = self.generate_cells(
                        sess=sess,
                        cells_no=int(cells_per_cluster),
                        clusters_ratios=one_hot)
                    cell_parts.append(fc)
                    label_parts.append(np.ravel(fl))

                fake_cells = np.concatenate(cell_parts, axis=0)
                fake_labels = np.concatenate(label_parts)

            else:
                # One batch per sess.run; round up, then trim the surplus.
                batches_no = int(np.ceil(cells_no / self.batch_size))

                fetches = [self.generator.fake_outputs,
                           self.generator.input_clusters]
                eval_feed_dict = {
                    self.generator.is_training: False,
                    self.generator.clusters_ratios: clusters_ratios,
                }

                for _ in range(batches_no):
                    fc, fl = sess.run(fetches, feed_dict=eval_feed_dict)
                    cell_parts.append(fc)
                    label_parts.append(np.ravel(fl))

                fake_cells = np.concatenate(cell_parts, axis=0)[0:cells_no]
                fake_labels = np.concatenate(label_parts)[0:cells_no]

                # Rescale only here, at the leaf, so every cell is rescaled
                # exactly once even when reached through the per-cluster path.
                fake_cells = rescale(fake_cells,
                                     scaling=self.scaling,
                                     scale_value=self.scale_value)
        finally:
            # Only close a session this call created.
            if own_session:
                sess.close()

        if save_path is not None:
            save_generated_cells(fake_cells, save_path, fake_labels)

        return fake_cells, fake_labels