Example #1
def create_puppet_workflow(regression_class=_RegressionWorkflowMixin,
                           base_class=workflow.WorkflowBase,
                           result_processor_class=NoOutputRP):

    puppet_parent = workflow._factory_build_inferelator(regression=regression_class, workflow=base_class)

    class PuppetClass(puppet_parent):
        """
        Standard workflow, except that it takes all of the data as references passed to __init__ instead of as
        filenames on disk or as environment variables, and returns the model AUPR and edge counts without
        writing files (unless told to)
        """

        write_network = True
        network_file_name = None
        pr_curve_file_name = None
        initialize_mp = False

        def __init__(self, data, prior_data, gs_data):
            self.data = data
            self.priors_data = prior_data
            self.gold_standard = gs_data
            super(PuppetClass, self).__init__()

        def startup_run(self):
            # Skip all of the data loading
            self.process_priors_and_gold_standard()

        def create_output_dir(self, *args, **kwargs):
            pass

    return PuppetClass
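
A minimal usage sketch for the factory above; `MyRegression`, `expression_data`, `priors`, and `gold_standard` are hypothetical stand-ins for a concrete regression mixin and data objects loaded elsewhere:

# Hypothetical driver code, not part of the original example.
PuppetWorkflow = create_puppet_workflow(regression_class=MyRegression)
puppet = PuppetWorkflow(expression_data, priors, gold_standard)
puppet.write_network = False  # keep everything in memory; nothing is written to disk
puppet.run()  # per the class docstring, AUPR and edge counts come back instead of output files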
Example #2
def setUp(self):
    self.workflow = workflow._factory_build_inferelator(
        regression=FakeRegression, workflow=tfa_workflow.TFAWorkFlow)()
    self.workflow.input_dir = os.path.join(my_dir, "../../data/dream4")
    self.workflow.expression_matrix_file = default.DEFAULT_EXPRESSION_FILE
    self.workflow.tf_names_file = default.DEFAULT_TFNAMES_FILE
    self.workflow.meta_data_file = default.DEFAULT_METADATA_FILE
    self.workflow.priors_file = default.DEFAULT_PRIORS_FILE
    self.workflow.gold_standard_file = default.DEFAULT_GOLDSTANDARD_FILE
    self.workflow.get_data()
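
A brief note: `FakeRegression` is presumably a stub regression class defined elsewhere in the test module; passing it to the factory lets the test wire up and load the bundled DREAM4 data files without fitting a real model.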
Example #3
class DownsampleDataWorkflow(workflow._factory_build_inferelator(regression="stars", workflow="single-cell")):

    sample_ratio = None
    sample_seed = 1000

    def startup_run(self):
        super(DownsampleDataWorkflow, self).startup_run()

        if self.sample_ratio == 1.:
            return

        rgen = np.random.default_rng(self.sample_seed)

        n_keep = int(self.data.num_obs * self.sample_ratio)
        n_keep = 1 if n_keep < 1 else n_keep

        self.data.get_random_samples(n_keep, random_gen=rgen, inplace=True, with_replacement=False)
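
A sketch of configuring this workflow; the paths, file names, and ratio below are illustrative, not from the original:

# Hypothetical driver code for the single-cell downsampling workflow.
wkf = DownsampleDataWorkflow()
wkf.input_dir = "data/single_cell"                 # illustrative path
wkf.expression_matrix_file = "expression.tsv.gz"   # illustrative file name
wkf.sample_ratio = 0.5   # keep half of the observations
wkf.sample_seed = 42     # same seed -> same subsample on every run
wkf.run()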
Example #4
class DownsampleDataWorkflow(
        workflow._factory_build_inferelator(regression="amusr",
                                            workflow="multitask")):

    sample_ratio = None
    sample_seed = 1000

    def startup_run(self):
        super(DownsampleDataWorkflow, self).startup_run()

        if self.sample_ratio == 1.:
            return

        rgen = np.random.default_rng(self.sample_seed)

        for tobj in self._task_objects:
            n_keep = int(tobj.data.num_obs * self.sample_ratio)
            n_keep = 1 if n_keep < 1 else n_keep

            tobj.data.get_random_samples(n_keep,
                                         random_gen=rgen,
                                         inplace=True,
                                         with_replacement=False)
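
Both downsampling variants follow the same pattern: a single np.random.default_rng seeded with sample_seed makes the draw reproducible across runs, and the floor of one observation (n_keep = 1 if n_keep < 1) guarantees that even a tiny sample_ratio never empties a dataset. The multitask version simply applies the ratio per task, so each task keeps the same fraction of its own observations.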
Example #5
def test_abstractness(self):
    self.workflow = workflow._factory_build_inferelator(regression='base',
                                                        workflow=tfa_workflow.TFAWorkFlow)()
    with self.assertRaises(NotImplementedError):
        self.workflow.run_bootstrap([])
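
Building the workflow with regression='base' leaves the regression hooks abstract, so calling run_bootstrap directly must raise NotImplementedError; the test above pins down exactly that contract.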
Example #6
def create_task_data_class(workflow_class="single-cell"):
    task_parent = workflow._factory_build_inferelator(regression="base", workflow=workflow_class)

    class TaskData(task_parent):
        """
        TaskData is a workflow object which only loads and preprocesses data from files.
        """

        task_name = None
        tasks_from_metadata = False
        meta_data_task_column = None

        task_workflow_class = str(workflow_class)

        str_attrs = ["input_dir", "expression_matrix_file", "tf_names_file", "meta_data_file", "priors_file"]

        def __str__(self):
            """
            Create a printable report of the settings in this TaskData object

            :return: Settings in str_attrs in a printable string
            :rtype: str
            """

            task_str = "{n}:\n\tWorkflow Class: {cl}\n".format(n=self.task_name, cl=self.task_workflow_class)
            for attr in self.str_attrs:
                try:
                    task_str += "\t{attr}: {val}\n".format(attr=attr, val=getattr(self, attr))
                except AttributeError:
                    task_str += "\t{attr}: Nonexistant\n".format(attr=attr)
            return task_str

        def __init__(self):
            if self._file_format_settings is None:
                self._file_format_settings = dict()

        def initialize_multiprocessing(self):
            """
            Don't do anything with multiprocessing in this object
            """
            pass

        def startup(self):
            raise NotImplementedError

        def startup_run(self):
            raise NotImplementedError

        def get_data(self):
            """
            Load all the data and then return a list of references to TaskData objects.
            There will be multiple objects returned if tasks_from_metadata is set.
            If tasks_from_metadata is not set, the list contains only this task (self).

            :return: List of TaskData objects with loaded data
            :rtype: list(TaskData)
            """
            Debug.vprint("Loading data for task {task_name}".format(task_name=self.task_name))
            super(TaskData, self).get_data()

            if self.tasks_from_metadata:
                return self.separate_tasks_by_metadata()
            else:
                return [self]

        def validate_data(self):
            """
            Don't validate data in TaskData. The parent workflow will check.
            """

            pass

        def set_run_parameters(self):
            """
            Set parameters used during runtime
            """

            warnings.warn("Task-specific `num_bootstraps` and `random_seed` is not supported. Set on parent workflow.")

        def process_priors_and_gold_standard(self, gold_standard=None, cv_flag=None, cv_axis=None, shuffle_priors=None,
                                             add_prior_noise=None):
            """
            Make sure that the priors for this task are correct

            This will remove circularity from the task priors based on the parent gold standard
            """

            gold_standard = self.gold_standard if gold_standard is None else gold_standard
            cv_flag = self.split_gold_standard_for_crossvalidation if cv_flag is None else cv_flag
            cv_axis = self.cv_split_axis if cv_axis is None else cv_axis
            shuffle_priors = self.shuffle_prior_axis if shuffle_priors is None else shuffle_priors
            add_prior_noise = self.add_prior_noise if add_prior_noise is None else add_prior_noise

            # Remove circularity from the gold standard
            if cv_flag:
                self.priors_data, _ = self.prior_manager._remove_prior_circularity(self.priors_data, gold_standard,
                                                                                   split_axis=cv_axis)

            if self.tf_names is not None:
                self.priors_data = self.prior_manager.filter_to_tf_names_list(self.priors_data, self.tf_names)

            # Filter priors and expression to a list of genes
            self.filter_to_gene_list()

            # Shuffle prior labels
            if shuffle_priors is not None:
                self.priors_data = self.prior_manager.shuffle_priors(self.priors_data, shuffle_priors, self.random_seed)

            if add_prior_noise is not None:
                self.priors_data = self.prior_manager.add_prior_noise(self.priors_data, add_prior_noise,
                                                                      self.random_seed)

            if min(self.priors_data.shape) == 0:
                raise ValueError("Priors for task {n} have an axis of length 0".format(n=self.task_name))

        def separate_tasks_by_metadata(self, meta_data_column=None):
            """
            Take a single expression matrix and break it into multiple dataframes based on meta_data. Return a list of
            TaskData objects which have the task-specific data loaded into them.

            :param meta_data_column: Meta_data column which corresponds to task ID
            :type meta_data_column: str
            :return new_task_objects: List of the TaskData objects with only one task's data each
            :rtype: list(TaskData)

            """

            if self.data is None:
                raise ValueError("No data has been loaded prior to `separate_tasks_by_metadata`")

            meta_data_column = meta_data_column if meta_data_column is not None else self.meta_data_task_column
            if meta_data_column is None:
                raise ValueError("tasks_from_metadata is set but meta_data_task_column is not")
            elif meta_data_column not in self.data.meta_data:
                msg = "meta_data_task_column is not found in task {t}".format(t=str(self))
                raise ValueError(msg)

            new_task_objects = list()
            tasks = self.data.meta_data[meta_data_column].unique().tolist()
            Debug.vprint("Creating {n} tasks from metadata column {col}".format(n=len(tasks), col=meta_data_column),
                         level=0)

            # Remove the data reference from self so deepcopy below doesn't duplicate the full dataset
            data = self.data
            self.data = None

            for task in tasks:
                # Copy this object
                task_obj = copy.deepcopy(self)

                # Get an index of the stuff to keep
                task_idx = data.meta_data[meta_data_column] == task

                # Reset expression matrix, metadata, and task_name in the copy
                task_obj.data = data.subset_copy(row_index=task_idx)
                task_obj.data.name = task
                task_obj.task_name = task
                new_task_objects.append(task_obj)

            Debug.vprint("Separated data into {ntask} tasks".format(ntask=len(new_task_objects)), level=0)

            return new_task_objects

    return TaskData
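
A hedged sketch of using the generated TaskData class to split one dataset into per-task objects; the paths and the metadata column name are hypothetical:

# Hypothetical usage: load one expression matrix, then split it by a metadata column.
TaskData = create_task_data_class(workflow_class="single-cell")
task = TaskData()
task.input_dir = "data/experiment"            # illustrative path
task.expression_matrix_file = "expression.tsv.gz"
task.meta_data_file = "meta_data.tsv"
task.tasks_from_metadata = True
task.meta_data_task_column = "Condition"      # illustrative column name
tasks = task.get_data()  # one TaskData per unique value in the Condition column
for t in tasks:
    print(t)  # __str__ reports the task name, workflow class, and file settings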