Code example #1
class MorphologyWorkflow(WorkflowBase):
    # compute cell features or nucleus features?
    compute_cell_features = luigi.BoolParameter()

    # paths to raw data and segmentations
    # if the raw path is None, we don't compute intensity features
    raw_path = luigi.Parameter(default=None)
    # we always need the nucleus segmentation
    nucleus_segmentation_path = luigi.Parameter()
    # we only need the cell segmentation if we compute cell morphology features
    cell_segmentation_path = luigi.Parameter(default=None)
    # we only need the chromatin segmentation if we compute nucleus features
    chromatin_segmentation_path = luigi.Parameter(default=None)

    # the scale used for computation, relative to the raw scale
    scale = luigi.IntParameter(default=3)

    # the input tables paths for the default table, the
    # nucleus mapping table and the region mapping table
    in_table_path = luigi.Parameter()
    # only need the mapping paths for the nucleus features
    nucleus_mapping_path = luigi.Parameter(default=None)
    region_mapping_path = luigi.Parameter(default=None)

    # minimum and maximum sizes for objects / bounding box
    min_size = luigi.IntParameter()
    max_size = luigi.IntParameter(default=None)
    max_bb = luigi.IntParameter()

    output_path = luigi.Parameter()

    def requires(self):
        prefix = 'cells' if self.compute_cell_features else 'nuclei'
        out_prefix = os.path.join(self.tmp_folder, 'sub_table_%s' % prefix)
        morpho_task = getattr(morpho_tasks, self._get_task_name('Morphology'))
        dep = morpho_task(
            tmp_folder=self.tmp_folder,
            config_dir=self.config_dir,
            dependency=self.dependency,
            max_jobs=self.max_jobs,
            compute_cell_features=self.compute_cell_features,
            raw_path=self.raw_path,
            nucleus_segmentation_path=self.nucleus_segmentation_path,
            cell_segmentation_path=self.cell_segmentation_path,
            chromatin_segmentation_path=self.chromatin_segmentation_path,
            in_table_path=self.in_table_path,
            output_prefix=out_prefix,
            nucleus_mapping_path=self.nucleus_mapping_path,
            region_mapping_path=self.region_mapping_path,
            min_size=self.min_size,
            max_size=self.max_size,
            max_bb=self.max_bb)
        dep = MergeTables(output_prefix=out_prefix,
                          output_path=self.output_path,
                          max_jobs=self.max_jobs,
                          dependency=dep)

        return dep

    @staticmethod
    def get_config():
        configs = super(MorphologyWorkflow, MorphologyWorkflow).get_config()
        configs.update(
            {'morphology': morpho_tasks.MorphologyLocal.default_task_config()})
        return configs
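
The parameter comments above spell out which inputs each mode needs. A hypothetical invocation for nucleus features might look like the sketch below; all paths and numbers are made up, and tmp_folder, config_dir and max_jobs are assumed to be parameters of the WorkflowBase parent, as suggested by their use in requires().

import luigi

task = MorphologyWorkflow(
    tmp_folder='./tmp_morphology',               # assumed WorkflowBase parameter
    config_dir='./configs',                      # assumed WorkflowBase parameter
    max_jobs=8,                                  # assumed WorkflowBase parameter
    compute_cell_features=False,                 # nucleus features
    raw_path='raw.n5',                           # None would skip intensity features
    nucleus_segmentation_path='nuclei.n5',
    chromatin_segmentation_path='chromatin.n5',  # needed for nucleus features
    in_table_path='tables/default.csv',
    nucleus_mapping_path='tables/nucleus_mapping.csv',
    region_mapping_path='tables/region_mapping.csv',
    min_size=100,
    max_bb=10000,
    output_path='tables/morphology_nuclei.csv')
luigi.build([task], local_scheduler=True)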
Code example #2
File: sge.py  Project: tianshuli-ext-aa/luigi
class SGEJobTask(luigi.Task):
    """Base class for executing a job on SunGrid Engine

    Override ``work()`` (rather than ``run()``) with your job code.

    Parameters:

    - n_cpu: Number of CPUs (or "slots") to allocate for the Task. This
          value is passed as ``qsub -pe {pe} {n_cpu}``
    - parallel_env: SGE parallel environment name. The default is "orte",
          the parallel environment installed with MIT StarCluster. If you
          are using a different cluster environment, check with your
          sysadmin for the right pe to use. This value is passed as {pe}
          to the qsub command above.
    - shared_tmp_dir: Shared drive accessible from all nodes in the cluster.
          Task classes and dependencies are pickled to a temporary folder on
          this drive. The default is ``/home``, the NFS share location setup
          by StarCluster
    - job_name_format: String that can be passed in to customize the job name
        string passed to qsub; e.g. "Task123_{task_family}_{n_cpu}...".
    - job_name: Exact job name to pass to qsub.
    - run_locally: Run locally instead of on the cluster.
    - poll_time: the length of time to wait in order to poll qstat
    - dont_remove_tmp_dir: Instead of deleting the temporary directory, keep it.
    - no_tarball: Don't create a tarball of the luigi project directory.  Can be
        useful to reduce I/O requirements when the luigi directory is accessible
        from cluster nodes already.

    """

    n_cpu = luigi.IntParameter(default=2, significant=False)
    shared_tmp_dir = luigi.Parameter(default='/home', significant=False)
    parallel_env = luigi.Parameter(default='orte', significant=False)
    job_name_format = luigi.Parameter(
        significant=False,
        default=None,
        description="A string that can be "
        "formatted with class variables to name the job with qsub.")
    job_name = luigi.Parameter(significant=False,
                               default=None,
                               description="Explicit job name given via qsub.")
    run_locally = luigi.BoolParameter(
        significant=False, description="run locally instead of on the cluster")
    poll_time = luigi.IntParameter(
        significant=False,
        default=POLL_TIME,
        description="specify the wait time to poll qstat for the job status")
    dont_remove_tmp_dir = luigi.BoolParameter(
        significant=False,
        description="don't delete the temporary directory used (for debugging)"
    )
    no_tarball = luigi.BoolParameter(
        significant=False,
        description="don't tarball (and extract) the luigi project files")

    def __init__(self, *args, **kwargs):
        super(SGEJobTask, self).__init__(*args, **kwargs)
        if self.job_name:
            # use explicitly provided job name
            pass
        elif self.job_name_format:
            # define the job name with the provided format
            self.job_name = self.job_name_format.format(
                task_family=self.task_family, **self.__dict__)
        else:
            # default to the task family
            self.job_name = self.task_family

    def _fetch_task_failures(self):
        if not os.path.exists(self.errfile):
            logger.info('No error file')
            return []
        with open(self.errfile, "r") as f:
            errors = f.readlines()
        if errors == []:
            return errors
        if errors[0].strip() == 'stdin: is not a tty':  # SGE complains when we submit through a pipe
            errors.pop(0)
        return errors

    def _init_local(self):

        # Set up temp folder in shared directory (trim to max filename length)
        base_tmp_dir = self.shared_tmp_dir
        random_id = '%016x' % random.getrandbits(64)
        folder_name = self.task_id + '-' + random_id
        self.tmp_dir = os.path.join(base_tmp_dir, folder_name)
        max_filename_length = os.fstatvfs(0).f_namemax
        self.tmp_dir = self.tmp_dir[:max_filename_length]
        logger.info("Tmp dir: %s", self.tmp_dir)
        os.makedirs(self.tmp_dir)

        # Dump the code to be run into a pickle file
        logging.debug("Dumping pickled class")
        self._dump(self.tmp_dir)

        if not self.no_tarball:
            # Make sure that all the class's dependencies are tarred and available
            # This is not necessary if luigi is importable from the cluster node
            logging.debug("Tarballing dependencies")
            # Grab luigi and the module containing the code to be run
            packages = [luigi] + [__import__(self.__module__, None, None, 'dummy')]
            luigi.hadoop.create_packages_archive(
                packages, os.path.join(self.tmp_dir, "packages.tar"))

    def run(self):
        if self.run_locally:
            self.work()
        else:
            self._init_local()
            self._run_job()
            # The procedure:
            # - Pickle the class
            # - Tarball the dependencies
            # - Construct a qsub argument that runs a generic runner function with the path to the pickled class
            # - Runner function loads the class from pickle
            # - Runner class untars the dependencies
            # - Runner function hits the button on the class's work() method

    def work(self):
        """Override this method, rather than ``run()``,  for your actual work."""
        pass

    def _dump(self, out_dir=''):
        """Dump instance to file."""
        with self.no_unpicklable_properties():
            self.job_file = os.path.join(out_dir, 'job-instance.pickle')
            if self.__module__ == '__main__':
                d = pickle.dumps(self)
                module_name = os.path.basename(sys.argv[0]).rsplit('.', 1)[0]
                d = d.replace('(c__main__', "(c" + module_name)
                open(self.job_file, "w").write(d)
            else:
                pickle.dump(self, open(self.job_file, "w"))

    def _run_job(self):

        # Build a qsub argument that will run sge_runner.py on the directory we've specified
        runner_path = sge_runner.__file__
        if runner_path.endswith("pyc"):
            runner_path = runner_path[:-3] + "py"
        job_str = 'python {0} "{1}" "{2}"'.format(
            runner_path, self.tmp_dir, os.getcwd()
        )  # enclose tmp_dir in quotes to protect from special escape chars
        if self.no_tarball:
            job_str += ' "--no-tarball"'

        # Build qsub submit command
        self.outfile = os.path.join(self.tmp_dir, 'job.out')
        self.errfile = os.path.join(self.tmp_dir, 'job.err')
        submit_cmd = _build_qsub_command(job_str, self.task_family,
                                         self.outfile, self.errfile,
                                         self.parallel_env, self.n_cpu)
        logger.debug('qsub command: \n' + submit_cmd)

        # Submit the job and grab job ID
        output = subprocess.check_output(submit_cmd, shell=True)
        self.job_id = _parse_qsub_job_id(output)
        logger.debug("Submitted job to qsub with response:\n" + output)

        self._track_job()

        # Now delete the temporaries, if they're there.
        if (self.tmp_dir and os.path.exists(self.tmp_dir)
                and not self.dont_remove_tmp_dir):
            logger.info('Removing temporary directory %s' % self.tmp_dir)
            subprocess.call(["rm", "-rf", self.tmp_dir])

    def _track_job(self):
        while True:
            # Sleep for a little bit
            time.sleep(self.poll_time)

            # See what the job's up to
            # ASSUMPTION
            qstat_out = subprocess.check_output(['qstat'])
            sge_status = _parse_qstat_state(qstat_out, self.job_id)
            if sge_status == 'r':
                logger.info('Job is running...')
            elif sge_status == 'qw':
                logger.info('Job is pending...')
            elif 'E' in sge_status:
                logger.error('Job has FAILED:\n' +
                             '\n'.join(self._fetch_task_failures()))
                break
            elif sge_status == 't' or sge_status == 'u':
                # Then the job could either be failed or done.
                errors = self._fetch_task_failures()
                if not errors:
                    logger.info('Job is done')
                else:
                    logger.error('Job has FAILED:\n' + '\n'.join(errors))
                break
            else:
                logger.info('Job status is UNKNOWN!')
                logger.info('Status is : %s' % sge_status)
                raise Exception(
                    "job status isn't one of ['r', 'qw', 'E*', 't', 'u']: %s" %
                    sge_status)
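
The docstring above says to override work() rather than run(); a minimal subclass might look like the sketch below. The extra parameter, the output path and the use of luigi.LocalTarget are illustrative assumptions, not part of the original module.

class MyJobTask(SGEJobTask):
    """Toy SGE job that writes a marker file from the worker node."""
    i = luigi.Parameter()

    def work(self):
        # Executed on the SGE node (or in-process when run_locally is set).
        with open(self.output().path, 'w') as f:
            f.write('done\n')

    def output(self):
        return luigi.LocalTarget(os.path.join(self.shared_tmp_dir, 'testfile_%s' % self.i))

Scheduling MyJobTask like any other luigi task then goes through run(), which pickles the instance, submits it via qsub and polls qstat as implemented above.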
Code example #3
class WithDefaultFalse(luigi.Task):
    x = luigi.BoolParameter(default=False)
Code example #4
class _GenSamples(AutoLocalOutputMixin(base_path=LUIGI_COMPLETED_TARGETS_DIR),
                  LoadInputDictMixin, DeleteDepsRecursively, luigi.Task, ABC):
    dataset_settings = luigi.DictParameter()
    random_seed = luigi.Parameter()
    generate_positive_samples = luigi.BoolParameter()
    num_samples = luigi.IntParameter()

    def requires(self):
        if not self.generate_in_batch:
            GS = self.gen_sample_type
            reqs = [
                GS(dataset_settings=self.dataset_settings,
                   random_seed=self.random_seed,
                   generate_positive_sample=self.generate_positive_samples,
                   sample_number=sample_num)
                for sample_num in range(self.num_samples)
            ]
            reqs = {'samples': reqs}
        elif self.generate_in_batch and self.num_samples > MIN_SAMPLES:
            self.n_prev = np.floor(self.num_samples / SAMPLES_BASE).astype(int)
            reqs = {
                'prev':
                self.__class__(  # because the class will have
                    # been specialized by the factory GenSamples()
                    dataset_settings=self.dataset_settings,
                    random_seed=self.random_seed,
                    generate_positive_samples=self.generate_positive_samples,
                    num_samples=self.n_prev)
            }
        else:
            self.n_prev = 0
            reqs = {}
        self.reqs_ = reqs

        if self.in_memory:
            return {}
        return reqs

    def run(self):
        prev = {'X': None, 'y': None}
        if not self.generate_in_batch:
            samples = self.compute_or_load_requirements()['samples']
        else:  # self.generate_in_batch
            if self.num_samples > MIN_SAMPLES:
                prev = self.compute_or_load_requirements()['prev']

            f_GS = self.gen_sample_type(
                dataset_settings=self.dataset_settings,
                random_seed=self.random_seed,
                generate_positive_sample=self.generate_positive_samples,
            ).gen_sample

            samples = []
            for sn in range(self.n_prev, self.num_samples):
                # set_progress_percentage is a blocking network IO operation... lol
                # self.set_progress_percentage(100*(sn-self.n_prev)/n_to_make)
                samples.append(f_GS(sample_number=sn))
        X, y = zip(*samples)
        if prev['X'] is not None:
            X = self.x_concatenator((prev['X'], self.x_concatenator(X)))
        else:
            X = self.x_concatenator(X)
        if prev['y'] is not None:
            y = self.y_concatenator((prev['y'], self.y_concatenator(y)))
        else:
            y = self.y_concatenator(y)
        output = {'X': X, 'y': y}
        self.output_ = output

        if not self.dont_write_output:
            with self.output().open('w') as f:
                dill.dump(output, f, 2)
Code example #5
class LambdaInvocationTask(workflow_tasks.PuppetTask):
    lambda_invocation_name = luigi.Parameter()
    manifest_file_path = luigi.Parameter()

    puppet_account_id = luigi.Parameter()
    should_use_sns = luigi.BoolParameter()
    should_use_product_plans = luigi.BoolParameter()
    include_expanded_from = luigi.BoolParameter()
    single_account = luigi.Parameter()
    is_dry_run = luigi.BoolParameter()

    def params_for_results_display(self):
        return {
            "puppet_account_id": self.puppet_account_id,
            "manifest_file_path": self.manifest_file_path,
            "lambda_invocation_name": self.lambda_invocation_name,
        }

    def requires(self):
        return {
            "manifest": manifest_tasks.ManifestTask(
                manifest_file_path=self.manifest_file_path,
                puppet_account_id=self.puppet_account_id,
            ),
        }

    def run(self):
        manifest = manifest_utils.Manifest(self.load_from_input("manifest"))

        lambda_invocation = manifest.get("lambda-invocations").get(self.lambda_invocation_name)

        dependencies = list()
        for dependency in lambda_invocation.get('depends_on', []):
            if isinstance(dependency, str):
                dependencies.append(
                    LaunchTask(
                        launch_name=dependency,
                        manifest_file_path=self.manifest_file_path,
                        puppet_account_id=self.puppet_account_id,
                        should_use_sns=self.should_use_sns,
                        should_use_product_plans=self.should_use_product_plans,
                        include_expanded_from=self.include_expanded_from,
                        single_account=self.single_account,
                        is_dry_run=self.is_dry_run,
                        execution_mode="hub",
                    )
                )
            else:
                dependency_type = dependency.get('type', 'launch')
                if dependency_type == "launch":
                    dependencies.append(
                        LaunchTask(
                            launch_name=dependency.get('name'),
                            manifest_file_path=self.manifest_file_path,
                            puppet_account_id=self.puppet_account_id,
                            should_use_sns=self.should_use_sns,
                            should_use_product_plans=self.should_use_product_plans,
                            include_expanded_from=self.include_expanded_from,
                            single_account=self.single_account,
                            is_dry_run=self.is_dry_run,
                            execution_mode="hub",
                        )
                    )
                elif dependency_type == "lambda-invocation":
                    dependencies.append(
                        self.__class__(
                            lambda_invocation_name=dependency.get('name'),
                            manifest_file_path=self.manifest_file_path,
                            puppet_account_id=self.puppet_account_id,
                            should_use_sns=self.should_use_sns,
                            should_use_product_plans=self.should_use_product_plans,
                            include_expanded_from=self.include_expanded_from,
                            single_account=self.single_account,
                            is_dry_run=self.is_dry_run,
                        )
                    )
        yield dependencies

        task_defs = manifest.get_task_defs_from_details(
            self.puppet_account_id,
            True,
            self.lambda_invocation_name,
            {},
            "lambda-invocations",
        )

        common_params = {
            "lambda_invocation_name": self.lambda_invocation_name,

            "function_name": lambda_invocation.get("function_name"),
            "qualifier": lambda_invocation.get("qualifier", "$LATEST"),
            "invocation_type": lambda_invocation.get("invocation_type"),

            "puppet_account_id": self.puppet_account_id,

            "parameters": lambda_invocation.get("parameters", {}),

            "launch_parameters": lambda_invocation.get('parameters', {}),
            "manifest_parameters": manifest.get('parameters', {}),
        }

        for task_def in task_defs:
            task_def_parameters = {
                "account_id": task_def.get('account_id'),
                "region": task_def.get('region'),
                "account_parameters": task_def.get('account_parameters'),
            }
            task_def_parameters.update(common_params)
            yield InvokeLambdaTask(**task_def_parameters)

        self.write_output(self.params_for_results_display())
Code example #6
File: hists.py  Project: cms-btv-pog/jet-tagging-sf
class GetScaleFactorWeights(DatasetTask, GridWorkflow, law.LocalWorkflow):

    iteration = WriteHistograms.iteration
    file_merging = WriteHistograms.file_merging

    b_tagger = WriteHistograms.b_tagger
    optimize_binning = WriteHistograms.optimize_binning
    category_tags = WriteHistograms.category_tags

    normalize_cerrs = luigi.BoolParameter()

    def __init__(self, *args, **kwargs):
        super(GetScaleFactorWeights, self).__init__(*args, **kwargs)
        # set shifts
        if self.dataset_inst.is_data:
            raise Exception(
                "GetScaleFactorWeights task should only run for MC.")

        if self.normalize_cerrs:
            self.shifts = format_shifts(["c_stats1", "c_stats2"])
        else:
            jes_sources = self.config_inst.get_aux("jes_sources_{}".format(
                self.config_inst.get_aux("jes_scheme")))
            self.shifts = {"nominal"} | format_shifts(jes_sources, prefix="jes") | \
                format_shifts(["lf", "hf", "lf_stats1", "lf_stats2", "hf_stats1", "hf_stats2"])

    def workflow_requires(self):
        from analysis.tasks.measurement import BundleScaleFactors

        reqs = super(GetScaleFactorWeights, self).workflow_requires()

        if not self.cancel_jobs and not self.cleanup_jobs:
            reqs["meta"] = MergeMetaData.req(
                self,
                version=self.get_version(MergeMetaData),
                _prefer_cli=["version"])
            reqs["pu"] = CalculatePileupWeights.req(self)
            if not self.pilot:
                reqs["tree"] = MergeTrees.req(
                    self,
                    cascade_tree=-1,
                    version=self.get_version(MergeTrees),
                    _prefer_cli=["version"])

            reqs["sf"] = BundleScaleFactors.req(
                self,
                iteration=self.iteration,
                fix_normalization=False,
                version=self.get_version(BundleScaleFactors),
                include_cshifts=self.normalize_cerrs,
                _prefer_cli=["version"])

        return reqs

    def requires(self):
        from analysis.tasks.measurement import BundleScaleFactors

        reqs = {
            "tree":
            MergeTrees.req(self,
                           cascade_tree=self.branch,
                           branch=0,
                           version=self.get_version(MergeTrees),
                           _prefer_cli=["version", "workflow"]),
            "meta":
            MergeMetaData.req(self,
                              version=self.get_version(MergeMetaData),
                              _prefer_cli=["version"]),
        }
        reqs["pu"] = CalculatePileupWeights.req(self)
        reqs["sf"] = BundleScaleFactors.req(
            self,
            iteration=self.iteration,
            fix_normalization=False,
            version=self.get_version(BundleScaleFactors),
            include_cshifts=self.normalize_cerrs,
            _prefer_cli=["version"])
        return reqs

    def store_parts(self):
        c_err_part = "c_errors" if self.normalize_cerrs else "b_and_udsg"
        binning_part = "optimized" if self.optimize_binning else "default"

        return super(GetScaleFactorWeights, self).store_parts() + (self.b_tagger,) \
            + (self.iteration,) + (binning_part,) + (c_err_part,)

    def output(self):
        return self.wlcg_target("stats_{}.json".format(self.branch))

    def get_jec_identifier(self, shift):
        if shift.startswith("jes"):
            return "_" + shift
        else:
            return ""

    def get_scale_factors(self, sfs, shift):
        sf_hists = {}
        for category in sfs.GetListOfKeys():
            category_dir = sfs.Get(category.GetName())
            hist = category_dir.Get("sf")
            # decouple from open file
            hist.SetDirectory(0)

            sf_hists[category.GetName()] = hist

        btag_var = self.config_inst.get_aux("btaggers")[
            self.b_tagger]["variable"]
        identifier = self.get_jec_identifier(shift)

        def get_value(entry):
            scale_factors = []
            for jet_idx in range(1, 5):
                jet_pt = getattr(entry,
                                 "jet{}_pt{}".format(jet_idx, identifier))[0]
                jet_eta = getattr(entry,
                                  "jet{}_eta{}".format(jet_idx, identifier))[0]
                jet_flavor = getattr(
                    entry, "jet{}_flavor{}".format(jet_idx, identifier))[0]
                jet_btag = getattr(
                    entry, "jet{}_{}{}".format(jet_idx, btag_var,
                                               identifier))[0]

                # stop when number of jets is exceeded
                if jet_flavor < -999.:
                    break

                # find category in which the scale factor of the jet was computed to get correct histogram
                if abs(jet_flavor) == 5:
                    region = "hf"
                elif abs(jet_flavor) == 4:
                    region = "c"
                else:
                    region = "lf"

                if region == "c" and not self.normalize_cerrs:
                    continue
                elif region != "c" and self.normalize_cerrs:
                    continue

                category = self.category_getter.get_category(
                    jet_pt, abs(jet_eta), region)

                # get scale factor
                sf_hist = sf_hists[category.name]
                bin_idx = sf_hist.FindBin(jet_btag)
                scale_factor = sf_hist.GetBinContent(bin_idx)

                scale_factors.append((category, scale_factor))

            return scale_factors

        return get_value

    @law.decorator.notify
    def run(self):
        import ROOT

        inp = self.input()
        outp = self.output()
        outp.parent.touch(0o0770)

        self.category_getter = CategoryGetter(self.config_inst, self.b_tagger)

        # get processes
        if len(self.dataset_inst.processes) != 1:
            raise NotImplementedError(
                "only datasets with exactly one linked process can be"
                " handled, got {}".format(len(self.dataset_inst.processes)))
        processes = list(self.dataset_inst.processes.values())
        process = processes[0]

        # prepare dict for outputs
        # shift -> category -> sum weights/ sum weighted sfs
        output_data = {
            shift: defaultdict(lambda: defaultdict(float))
            for shift in self.shifts
        }

        # open the input file and get the tree
        with inp["tree"].load("READ", cache=False) as input_file:
            tree = input_file.Get("tree")
            self.publish_message("{} events in tree".format(tree.GetEntries()))

            # identifier for jec shifted variables
            for shift in self.shifts:
                jec_identifier = self.get_jec_identifier(shift)

                # pt aliases for jets
                for obj in ["jet1", "jet2", "jet3", "jet4"]:
                    tree.SetAlias(
                        "{0}_pt{1}".format(obj, jec_identifier),
                        "({0}_px{1}**2 + {0}_py{1}**2)**0.5".format(
                            obj, jec_identifier))
                # b-tagging alias
                btag_var = self.config_inst.get_aux("btaggers")[
                    self.b_tagger]["variable"]
                for obj in ["jet1", "jet2", "jet3", "jet4"]:
                    variable = self.config_inst.get_variable("{0}_{1}".format(
                        obj, btag_var))
                    tree.SetAlias(
                        variable.name + jec_identifier,
                        variable.expression.format(
                            **{"jec_identifier": jec_identifier}))
            # pt aliases for leptons
            for obj in ["lep1", "lep2"]:
                tree.SetAlias("{0}_pt".format(obj),
                              "({0}_px**2 + {0}_py**2)**0.5".format(obj))

            scale_factor_getters = {}
            with inp["sf"].load() as sf_file:
                for shift in self.shifts:
                    scale_factor_getters[shift] = self.get_scale_factors(
                        sf_file.Get(shift), shift)

            # get info to scale event weight to lumi
            x_sec = process.get_xsec(self.config_inst.campaign.ecm).nominal
            sum_weights = inp["meta"].load()["event_weights"]["sum"]
            lumi_factor = x_sec / sum_weights

            input_file.cd()
            with TreeExtender(tree) as te:
                # unpack all branches
                te.unpack_branch("*")
                for i, entry in enumerate(te):
                    if (i % 1000) == 0:
                        print "entry {}".format(i)
                    # get event weight
                    gen_weight = entry.gen_weight[0]
                    channel_id = entry.channel[0]
                    channel = self.config_inst.get_channel(channel_id)
                    lumi = self.config_inst.get_aux("lumi")[channel]

                    evt_weight = gen_weight * lumi * lumi_factor

                    for shift in self.shifts:
                        # event has to pass base selection
                        jec_identifier = self.get_jec_identifier(shift)
                        if getattr(entry, "jetmet_pass{}".format(
                                jec_identifier))[0] != 1:
                            continue

                        # calculate per-jet b-tagging weights
                        scale_factors = scale_factor_getters[shift](entry)
                        # save sum for latter normalization
                        for category, sf_value in scale_factors:
                            output_data[shift][category.name][
                                "sum_sf"] += sf_value * evt_weight
                            output_data[shift][
                                category.name]["sum_weights"] += evt_weight

        # save outputs
        self.output().dump(output_data, formatter="json", indent=4)
Code example #7
class TensorFlowTask(luigi.Task):
    """Luigi wrapper for a TensorFlow task. To use, extend this class and provide values for the
    following properties:

    model_package = None        The name of the python package containing your model.
    model_name = None           The name of the python module containing your model.
                                Ex: if the model is in /foo/models/main.py, you would set
                                model_package = "models" and model_name = "main"
    gcp_project = None          The Google Cloud project id to run with ml-engine
    region = None               The GCP region if running with ml-engine, e.g. europe-west1
    model_name_suffix = None    A string suffix representing the model name, which will be appended
                                to the job name.
    runtime_version = None      The Google Cloud ML Engine runtime version for this job. Defaults to
                                the latest stable version. See
                                https://cloud.google.com/ml/docs/concepts/runtime-version-list for a
                                list of accepted versions.
    scale_tier = None           Specifies the machine types, the number of replicas for workers and
                                parameter servers. SCALE_TIER must be one of:
                                    basic, basic-gpu, basic-tpu, custom, premium-1, standard-1.

    Also, you can specify command line arguments for your trainer by overriding the
    `def tf_task_args(self)` method.
    """

    # Task properties
    model_name = luigi.Parameter(description="Name of the python model file")
    model_package = luigi.Parameter(description="Python package containing your model")
    model_package_path = luigi.Parameter(description="Absolute path to the model package")
    gcp_project = luigi.Parameter(description="GCP project", default=None)
    region = luigi.Parameter(description="GCP region", default=None)
    model_name_suffix = luigi.Parameter(description="String which will be appended to the job"
                                                    " name. Useful for finding jobs in the"
                                                    " ml-engine UI.", default=None)

    # Task parameters
    cloud = luigi.BoolParameter(description="Run on ml-engine")
    blocking = luigi.BoolParameter(default=True, description="Run in stream-logs/blocking mode")
    job_dir = luigi.Parameter(description="A job directory, used to store snapshots, logs and any "
                                          "other artifacts. A trailing '/' is required for "
                                          "'gs://' paths.")
    ml_engine_conf = luigi.Parameter(default=None,
                                     description="An ml-engine YAML configuration file.")
    tf_debug = luigi.BoolParameter(default=False, description="Run tf on debug mode")
    runtime_version = luigi.Parameter(default=None,
                                      description="The Google Cloud ML Engine runtime version "
                                      "for this job.")
    scale_tier = luigi.Parameter(default=None,
                                 description="Specifies the machine types, the number of replicas "
                                             "for workers and parameter servers.")

    def __init__(self, *args, **kwargs):
        super(TensorFlowTask, self).__init__(*args, **kwargs)

    def tf_task_args(self):
        """A list of args to pass to the tf main module."""
        return []

    def run(self):
        cmd = self._mk_cmd()
        logger.info("Running:\n```\n%s\n```", cmd)
        ret = subprocess.call(cmd, shell=True)
        if ret != 0:
            logger.error("Training failed. Aborting.")
            sys.exit(ret)
        logger.info("Training successful. Marking as done.")
        self._success_hook()

    def output(self):
        if is_gcs_path(self.get_job_dir()):
            return GCSFlagTarget(self.get_job_dir())
        else:
            # assume local filesystem otherwise
            return LocalTarget(self.get_job_dir())

    # TODO(rav): look into luigi hooks
    def _success_hook(self):
        success_file = self.get_job_dir().rstrip("/") + "/_SUCCESS"
        if is_gcs_path(self.get_job_dir()):
            from luigi.contrib.gcs import GCSClient
            client = GCSClient()
            client.put_string("", success_file)
        else:
            # assume local filesystem otherwise
            open(success_file, "a").close()

    def _mk_cmd(self):
        cmd = ["gcloud ml-engine"]
        if self.cloud:
            cmd.extend(self._mk_cloud_params())
        else:
            cmd.append("local train")

        cmd.extend(self._get_model_args())

        if self.tf_debug:
            cmd += ["--verbosity=debug"]

        cmd.extend(self._get_job_args())
        return " ".join(cmd)

    def get_job_dir(self):
        """Get job directory used to store snapshots, logs, final output and any other artifacts."""
        return self.job_dir

    def _mk_cloud_params(self):
        params = []
        if self.gcp_project:
            params.append("--project=%s" % self.gcp_project)
        import uuid
        params.append("jobs submit training %s_%s_%s" % (getpass.getuser(),
                                                         self.__class__.__name__,
                                                         str(uuid.uuid4()).replace("-", "_")))
        if self.region:
            params.append("--region=%s" % self.region)
        if self.ml_engine_conf:
            params.append("--config=%s" % self.ml_engine_conf)
        params.append("--job-dir=%s" % self.get_job_dir())
        if self.blocking:
            params.append("--stream-logs")  # makes the execution "blocking"
        if self.runtime_version:
            params.append("--runtime-version=%s" % self.runtime_version)
        if self.scale_tier:
            params.append("--scale-tier=%s" % self.scale_tier)
        return params

    def _get_model_args(self):
        args = []
        if self.model_package_path:
            args.append("--package-path=%s" % self.model_package_path)
        if self.model_name:
            module_name = self.model_name
            if self.model_package:
                module_name = "{package}.{module}".format(package=self.model_package,
                                                          module=module_name)
            args.append("--module-name=" + module_name)
        return args

    def _get_job_args(self):
        args = ["--"]
        args.extend(self._get_input_args())
        if not self.cloud:
            args.append("--job-dir=%s" % self.get_job_dir())
        args.extend(self.tf_task_args())
        return args

    def _get_input_args(self):
        job_input = self.input()
        if isinstance(job_input, luigi.Target):
            job_input = {"input": job_input}
        if len(job_input) == 0:  # default requires()
            return []
        if not isinstance(job_input, dict):
            raise ValueError("Input (requires()) must be dict type")
        input_args = []
        for (name, targets) in job_input.items():
            uris = [self._get_uri(target) for target in luigi.task.flatten(targets)]
            if isinstance(targets, dict):
                # If targets is a dict that means it had multiple outputs. In this case make the
                # input args "<input key>-<task output key>"
                names = ["%s-%s" % (name, key) for key in targets.keys()]
            else:
                names = [name] * len(uris)
            for (arg_name, uri) in zip(names, uris):
                input_args.append("--%s=%s" % (arg_name, uri))

        return input_args

    @staticmethod
    def _get_uri(target):
        if hasattr(target, "uri"):
            return target.uri()
        elif isinstance(target, (GCSTarget, GCSFlagTarget)):
            return target.path
        else:
            raise ValueError("Unsupported input Target type: %s" % target.__class__.__name__)
Code example #8
class Baz(luigi.Task):
    bool = luigi.BoolParameter()

    def run(self):
        Baz._val = self.bool
Code example #9
def testBoolConfigOutranksDefault(self):
    p = luigi.BoolParameter(default=True,
                            config_path=dict(section="foo", name="bar"))
    self.assertEqual(False, _value(p))
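
This test (apparently from luigi's own parameter test suite) relies on the configuration value taking precedence over the declared default=True; it would correspond to a luigi.cfg entry along the lines of:

[foo]
bar=false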
Code example #10
File: task.py  Project: hiro-o918/gokart
class TaskOnKart(luigi.Task):
    """
    This is a wrapper class of luigi.Task.

    The key methods of a TaskOnKart are:

    * :py:meth:`make_target` - this makes an output target with a relative file path.
    * :py:meth:`make_model_target` - this makes an output target for models which generate multiple files when saved.
    * :py:meth:`load` - this loads the input files of this task.
    * :py:meth:`dump` - this saves an object as the output of this task.
    """

    workspace_directory = luigi.Parameter(
        default='./resources/',
        description=
        'A directory to set outputs on. Please use a path that starts with s3:// when you use s3.',
        significant=False)  # type: str
    local_temporary_directory = luigi.Parameter(
        default='./resources/tmp/',
        description='A directory to save temporary files.',
        significant=False)  # type: str
    rerun = luigi.BoolParameter(
        default=False,
        description=
        'If this is true, this task will run even if all output files exist.',
        significant=False)
    strict_check = luigi.BoolParameter(
        default=False,
        description=
        'If this is true, this task will not run only if all input and output files exist.',
        significant=False)
    modification_time_check = luigi.BoolParameter(
        default=False,
        description=
        'If this is true, this task will not run only if all input and output files exist,'
        ' and all input files are modified before the output files are modified.',
        significant=False)
    serialized_task_definition_check = luigi.BoolParameter(
        default=False,
        description=
        'If this is true, this task will not run only if all input and output files exist,'
        ' and this task class is modified.',
        significant=False)
    ignore_serializing_task_definition_error = luigi.BoolParameter(
        default=False,
        description=
        'If this is true, this task ignores errors while serializing this task.'
        ' This parameter is effective only when `serialized_task_check` is `True`',
        significant=False)
    delete_unnecessary_output_files = luigi.BoolParameter(
        default=False,
        description='If this is true, delete unnecessary output files.',
        significant=False)
    significant = luigi.BoolParameter(
        default=True,
        description=
        'If this is false, this task is not treated as a part of dependent tasks for the unique id.',
        significant=False)
    fix_random_seed_methods = luigi.ListParameter(
        default=['random.seed', 'numpy.random.seed'],
        description='Fix random seed method list.',
        significant=False)
    fix_random_seed_value = luigi.IntParameter(
        default=None,
        description='Fix random seed method value.',
        significant=False)

    redis_host = luigi.OptionalParameter(
        default=None,
        description='Task lock check is deactivated, when None.',
        significant=False)
    redis_port = luigi.OptionalParameter(
        default=None,
        description='Task lock check is deactivated, when None.',
        significant=False)
    redis_timeout = luigi.IntParameter(
        default=180,
        description='Redis lock will be released after `redis_timeout` seconds',
        significant=False)
    redis_fail_on_collision: bool = luigi.BoolParameter(
        default=False,
        description=
        'True for failing the task immediately when the cache is locked, instead of waiting for the lock to be released',
        significant=False)
    fail_on_empty_dump: bool = ExplicitBoolParameter(
        default=False,
        description='Fail when task dumps empty DF',
        significant=False)

    def __init__(self, *args, **kwargs):
        self._add_configuration(kwargs, 'TaskOnKart')
        # This dict is dumped into "workspace_directory/log/task_log/" when this task finishes with success.
        self.task_log = dict()
        self.task_unique_id = None
        super(TaskOnKart, self).__init__(*args, **kwargs)
        self._rerun_state = self.rerun
        self._lock_at_dump = True

    def output(self):
        return self.make_target()

    def requires(self):
        tasks = self.make_task_instance_dictionary()
        return tasks or []  # when tasks is an empty dict, this returns an empty list.

    def make_task_instance_dictionary(self) -> Dict[str, 'TaskOnKart']:
        return {
            key: var
            for key, var in vars(self).items() if self.is_task_on_kart(var)
        }

    @staticmethod
    def is_task_on_kart(value):
        return isinstance(
            value,
            TaskOnKart) or (isinstance(value, tuple) and bool(value)
                            and all([isinstance(v, TaskOnKart)
                                     for v in value]))

    @classmethod
    def _add_configuration(cls, kwargs, section):
        config = luigi.configuration.get_config()
        class_variables = dict(TaskOnKart.__dict__)
        class_variables.update(dict(cls.__dict__))
        if section not in config:
            return
        for key, value in dict(config[section]).items():
            if key not in kwargs and key in class_variables:
                kwargs[key] = class_variables[key].parse(value)

    def complete(self) -> bool:
        if self._rerun_state:
            for target in luigi.task.flatten(self.output()):
                target.remove()
            self._rerun_state = False
            return False

        is_completed = all(
            [t.exists() for t in luigi.task.flatten(self.output())])

        if self.strict_check or self.modification_time_check:
            requirements = luigi.task.flatten(self.requires())
            inputs = luigi.task.flatten(self.input())
            is_completed = is_completed and all([
                task.complete() for task in requirements
            ]) and all([i.exists() for i in inputs])

        if not self.modification_time_check or not is_completed or not self.input():
            return is_completed

        return self._check_modification_time()

    def _check_modification_time(self):
        common_path = set(t.path()
                          for t in luigi.task.flatten(self.input())) & set(
                              t.path()
                              for t in luigi.task.flatten(self.output()))
        input_tasks = [
            t for t in luigi.task.flatten(self.input())
            if t.path() not in common_path
        ]
        output_tasks = [
            t for t in luigi.task.flatten(self.output())
            if t.path() not in common_path
        ]

        input_modification_time = max(
            [target.last_modification_time()
             for target in input_tasks]) if input_tasks else None
        output_modification_time = min(
            [target.last_modification_time()
             for target in output_tasks]) if output_tasks else None

        if input_modification_time is None or output_modification_time is None:
            return True

        # "=" must be required in the following statements, because some tasks use input targets as output targets.
        return input_modification_time <= output_modification_time

    def clone(self, cls=None, **kwargs):
        if cls is None:
            cls = self.__class__

        new_k = {}
        for param_name, param_class in cls.get_params():
            if param_name in {
                    'rerun', 'strict_check', 'modification_time_check'
            }:
                continue

            if param_name in kwargs:
                new_k[param_name] = kwargs[param_name]
            elif hasattr(self, param_name):
                new_k[param_name] = getattr(self, param_name)

        return cls(**new_k)

    def make_target(self,
                    relative_file_path: str = None,
                    use_unique_id: bool = True,
                    processor: Optional[FileProcessor] = None) -> TargetOnKart:
        formatted_relative_file_path = relative_file_path if relative_file_path is not None else os.path.join(
            self.__module__.replace(".", "/"), f"{type(self).__name__}.pkl")
        file_path = os.path.join(self.workspace_directory,
                                 formatted_relative_file_path)
        unique_id = self.make_unique_id() if use_unique_id else None

        redis_params = make_redis_params(
            file_path=file_path,
            unique_id=unique_id,
            redis_host=self.redis_host,
            redis_port=self.redis_port,
            redis_timeout=self.redis_timeout,
            redis_fail_on_collision=self.redis_fail_on_collision)
        return gokart.target.make_target(file_path=file_path,
                                         unique_id=unique_id,
                                         processor=processor,
                                         redis_params=redis_params)

    def make_large_data_frame_target(
        self,
        relative_file_path: str = None,
        use_unique_id: bool = True,
        max_byte=int(2**26)) -> TargetOnKart:
        formatted_relative_file_path = relative_file_path if relative_file_path is not None else os.path.join(
            self.__module__.replace(".", "/"), f"{type(self).__name__}.zip")
        file_path = os.path.join(self.workspace_directory,
                                 formatted_relative_file_path)
        unique_id = self.make_unique_id() if use_unique_id else None
        redis_params = make_redis_params(
            file_path=file_path,
            unique_id=unique_id,
            redis_host=self.redis_host,
            redis_port=self.redis_port,
            redis_timeout=self.redis_timeout,
            redis_fail_on_collision=self.redis_fail_on_collision)
        return gokart.target.make_model_target(
            file_path=file_path,
            temporary_directory=self.local_temporary_directory,
            unique_id=unique_id,
            save_function=gokart.target.LargeDataFrameProcessor(
                max_byte=max_byte).save,
            load_function=gokart.target.LargeDataFrameProcessor.load,
            redis_params=redis_params)

    def make_model_target(self,
                          relative_file_path: str,
                          save_function: Callable[[Any, str], None],
                          load_function: Callable[[str], Any],
                          use_unique_id: bool = True):
        """
        Make a target for models which generate multiple files when saving, e.g. gensim.Word2Vec, TensorFlow, and so on.

        :param relative_file_path: A file path to save.
        :param save_function: A function to save a model. This takes a model object and a file path.
        :param load_function: A function to load a model. This takes a file path and returns a model object.
        :param use_unique_id: If this is true, add a unique id to the file base name.
        """
        file_path = os.path.join(self.workspace_directory, relative_file_path)
        assert relative_file_path[-3:] == 'zip', \
            f'extension must be zip, but {relative_file_path} is passed.'
        unique_id = self.make_unique_id() if use_unique_id else None
        redis_params = make_redis_params(
            file_path=file_path,
            unique_id=unique_id,
            redis_host=self.redis_host,
            redis_port=self.redis_port,
            redis_timeout=self.redis_timeout,
            redis_fail_on_collision=self.redis_fail_on_collision)
        return gokart.target.make_model_target(
            file_path=file_path,
            temporary_directory=self.local_temporary_directory,
            unique_id=unique_id,
            save_function=save_function,
            load_function=load_function,
            redis_params=redis_params)

    def load(self, target: Union[None, str, TargetOnKart] = None) -> Any:
        def _load(targets):
            if isinstance(targets, list) or isinstance(targets, tuple):
                return [_load(t) for t in targets]
            if isinstance(targets, dict):
                return {k: _load(t) for k, t in targets.items()}
            return targets.load()

        data = _load(self._get_input_targets(target))
        if target is None and isinstance(data, dict) and len(data) == 1:
            return list(data.values())[0]
        return data

    def load_generator(self,
                       target: Union[None, str, TargetOnKart] = None) -> Any:
        def _load(targets):
            if isinstance(targets, list) or isinstance(targets, tuple):
                for t in targets:
                    yield from _load(t)
            elif isinstance(targets, dict):
                for k, t in targets.items():
                    yield from {k: _load(t)}
            else:
                yield targets.load()

        return _load(self._get_input_targets(target))

    def load_data_frame(self,
                        target: Union[None, str, TargetOnKart] = None,
                        required_columns: Optional[Set[str]] = None,
                        drop_columns: bool = False) -> pd.DataFrame:
        def _flatten_recursively(dfs):
            if isinstance(dfs, list):
                return pd.concat([_flatten_recursively(df) for df in dfs])
            else:
                return dfs

        dfs = self.load(target=target)
        if isinstance(dfs, dict) and len(dfs) == 1:
            dfs = list(dfs.values())[0]

        data = _flatten_recursively(dfs)

        required_columns = required_columns or set()
        if data.empty and len(data.index) == 0 and len(required_columns -
                                                       set(data.columns)) > 0:
            return pd.DataFrame(columns=required_columns)
        assert required_columns.issubset(
            set(data.columns)
        ), f'data must have columns {required_columns}, but actually have only {data.columns}.'
        if drop_columns:
            data = data[required_columns]
        return data

    def dump(self, obj, target: Union[None, str, TargetOnKart] = None) -> None:
        PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace)
        if self.fail_on_empty_dump and isinstance(obj, pd.DataFrame):
            assert not obj.empty
        self._get_output_target(target).dump(obj,
                                             lock_at_dump=self._lock_at_dump)

    def make_unique_id(self):
        self.task_unique_id = self.task_unique_id or self._make_hash_id()
        return self.task_unique_id

    def _make_hash_id(self):
        def _to_str_params(task):
            if isinstance(task, TaskOnKart):
                return str(task.make_unique_id()) if task.significant else None
            return task.to_str_params(only_significant=True)

        dependencies = [
            _to_str_params(task)
            for task in luigi.task.flatten(self.requires())
        ]
        dependencies = [d for d in dependencies if d is not None]
        dependencies.append(self.to_str_params(only_significant=True))
        dependencies.append(self.__class__.__name__)
        if self.serialized_task_definition_check:
            try:
                dependencies.append(str(cloudpickle.dumps(self.__class__)))
            except Exception as err:
                if not self.ignore_serializing_task_definition_error:
                    msg = f'{self.__class__.__name__} is not serializable,' \
                          ' you can ignore this error by setting `ignore_serializing_task_error` to `True`'
                    raise TypeError(msg) from err
                logger.warning(
                    f'{self.__class__.__name__} is not serializable, so modification cannot be detected'
                )

        return hashlib.md5(str(dependencies).encode()).hexdigest()

    def _get_input_targets(
        self, target: Union[None, str, TargetOnKart]
    ) -> Union[TargetOnKart, List[TargetOnKart]]:
        if target is None:
            return self.input()
        if isinstance(target, str):
            return self.input()[target]
        return target

    def _get_output_target(
            self, target: Union[None, str, TargetOnKart]) -> TargetOnKart:
        if target is None:
            return self.output()
        if isinstance(target, str):
            return self.output()[target]
        return target

    def get_info(self, only_significant=False):
        params_str = {}
        params = dict(self.get_params())
        for param_name, param_value in self.param_kwargs.items():
            if (not only_significant) or params[param_name].significant:
                if type(params[param_name]) == gokart.TaskInstanceParameter:
                    params_str[param_name] = type(
                        param_value
                    ).__name__ + '-' + param_value.make_unique_id()
                else:
                    params_str[param_name] = params[param_name].serialize(
                        param_value)
        return params_str

    def _get_task_log_target(self):
        return self.make_target(f'log/task_log/{type(self).__name__}.pkl')

    def get_task_log(self) -> Dict:
        target = self._get_task_log_target()
        if self.task_log:
            return self.task_log
        if target.exists():
            return self.load(target)
        return dict()

    @luigi.Task.event_handler(luigi.Event.SUCCESS)
    def _dump_task_log(self):
        self.task_log['file_path'] = [
            target.path() for target in luigi.task.flatten(self.output())
        ]
        self.dump(self.task_log, self._get_task_log_target())

    def _get_task_params_target(self):
        return self.make_target(f'log/task_params/{type(self).__name__}.pkl')

    def get_task_params(self) -> Dict:
        target = self._get_task_log_target()
        if target.exists():
            return self.load(target)
        return dict()

    @luigi.Task.event_handler(luigi.Event.START)
    def _set_random_seed(self):
        random_seed = self._get_random_seed()
        seed_methods = self.try_set_seed(self.fix_random_seed_methods,
                                         random_seed)
        self.dump({
            'seed': random_seed,
            'seed_methods': seed_methods
        }, self._get_random_seeds_target())

    def _get_random_seeds_target(self):
        return self.make_target(f'log/random_seed/{type(self).__name__}.pkl')

    @staticmethod
    def try_set_seed(methods: List[str], random_seed: int) -> List[str]:
        success_methods = []
        for method_name in methods:
            try:
                for i, x in enumerate(method_name.split('.')):
                    if i == 0:
                        m = import_module(x)
                    else:
                        m = getattr(m, x)
                m(random_seed)
                success_methods.append(method_name)
            except ModuleNotFoundError:
                pass
            except AttributeError:
                pass
        return success_methods

    def _get_random_seed(self):
        if self.fix_random_seed_value:
            return self.fix_random_seed_value
        return int(self.make_unique_id(), 16) % (2 ** 32 - 1)  # maximum value accepted by numpy.random.seed

    @luigi.Task.event_handler(luigi.Event.START)
    def _dump_task_params(self):
        self.dump(self.to_str_params(only_significant=True),
                  self._get_task_params_target())

    def _get_processing_time_target(self):
        return self.make_target(
            f'log/processing_time/{type(self).__name__}.pkl')

    def get_processing_time(self) -> str:
        target = self._get_processing_time_target()
        if target.exists():
            return self.load(target)
        return 'unknown'

    @luigi.Task.event_handler(luigi.Event.PROCESSING_TIME)
    def _dump_processing_time(self, processing_time):
        self.dump(processing_time, self._get_processing_time_target())

    @classmethod
    def restore(cls, unique_id):
        params = TaskOnKart().make_target(
            f'log/task_params/{cls.__name__}_{unique_id}.pkl',
            use_unique_id=False).load()
        return cls.from_str_params(params)

    @luigi.Task.event_handler(luigi.Event.FAILURE)
    def _log_unique_id(self, exception):
        logger.info(
            f'FAILURE:\n    task name={type(self).__name__}\n    unique id={self.make_unique_id()}'
        )

    @luigi.Task.event_handler(luigi.Event.START)
    def _dump_module_versions(self):
        self.dump(self._get_module_versions(),
                  self._get_module_versions_target())

    def _get_module_versions_target(self):
        return self.make_target(
            f'log/module_versions/{type(self).__name__}.txt')

    def _get_module_versions(self) -> str:
        module_versions = []
        for x in set([
                name.split('.')[0] for name, value in globals().items()
                if isinstance(value, types.ModuleType) and '_' not in name
        ]):
            module = import_module(x)
            if '__version__' in dir(module):
                if type(module.__version__) == str:
                    version = module.__version__.split(" ")[0]
                else:
                    version = '.'.join([str(v) for v in module.__version__])
                module_versions.append(f'{x}=={version}')
        return '\n'.join(module_versions)

    def __repr__(self):
        """
        Build a task representation like `MyTask(param1=1.5, param2='5', data_task=DataTask(id=35tyi))`
        """
        params = self.get_params()
        param_values = self.get_param_values(params, [], self.param_kwargs)

        # Build up task id
        repr_parts = []
        param_objs = dict(params)
        for param_name, param_value in param_values:
            param_obj = param_objs[param_name]
            if param_obj.significant:
                repr_parts.append(
                    f'{param_name}={self._make_representation(param_obj, param_value)}'
                )

        task_str = f'{self.get_task_family()}({", ".join(repr_parts)})'
        return task_str

    def _make_representation(self, param_obj: luigi.Parameter, param_value):
        if isinstance(param_obj, TaskInstanceParameter):
            return f'{param_value.get_task_family()}({param_value.make_unique_id()})'
        if isinstance(param_obj, ListTaskInstanceParameter):
            return f"[{', '.join(f'{v.get_task_family()}({v.make_unique_id()})' for v in param_value)}]"
        return param_obj.serialize(param_value)
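
The seed-fixing helper above resolves dotted method names such as 'numpy.random.seed' at run time via import_module and getattr, silently skipping libraries that are not installed. A minimal standalone sketch of that resolution step (the method name used below is only an example of what fix_random_seed_methods might contain):

from importlib import import_module

def resolve_dotted(name: str):
    # 'numpy.random.seed' -> import numpy, then walk .random -> .seed
    parts = name.split('.')
    obj = import_module(parts[0])
    for attr in parts[1:]:
        obj = getattr(obj, attr)
    return obj

seed_fn = resolve_dotted('random.seed')
seed_fn(57)  # equivalent to random.seed(57); a missing library would raise and be skipped above
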
コード例 #11
0
class SpawnTestDockerEnvironment(StoppableTask):
    logger = logging.getLogger('luigi-interface')

    environment_name = luigi.Parameter()
    reuse_database_setup = luigi.BoolParameter(False, significant=False)
    reuse_database = luigi.BoolParameter(False, significant=False)
    reuse_test_container = luigi.BoolParameter(False, significant=False)
    database_port_forward = luigi.OptionalParameter(None, significant=False)
    bucketfs_port_forward = luigi.OptionalParameter(None, significant=False)
    max_start_attempts = luigi.IntParameter(2, significant=False)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._prepare_outputs()
        self.test_container_name = f"""test_container_{self.environment_name}"""
        self.db_container_name = f"""db_container_{self.environment_name}"""
        self.network_name = f"""db_network_{self.environment_name}"""

    def _prepare_outputs(self):
        self._environment_info_target = luigi.LocalTarget(
            "%s/info/environment/%s/environment_info" %
            (build_config().output_directory, self.environment_name))
        if self._environment_info_target.exists():
            self._environment_info_target.remove()

    def output(self):
        return {
            ENVIRONMENT_INFO: self._environment_info_target,
        }

    def run_task(self):
        test_environment_info = yield from self.attempt_database_start()
        test_environment_info_dict = test_environment_info.to_dict()
        yield from self.setup_test_database(test_environment_info_dict)
        self.write_output(test_environment_info)

    def attempt_database_start(self):
        is_database_ready = False
        attempt = 0
        while not is_database_ready and attempt < self.max_start_attempts:
            database_info, is_database_ready, test_container_info = \
                yield from self.start_database(attempt)
            attempt += 1
        if not is_database_ready and not attempt < self.max_start_attempts:
            raise Exception(
                f"Maximum attempts {attempt} to start the database reached.")
        test_environment_info = \
            EnvironmentInfo(name=self.environment_name,
                            database_info=database_info,
                            test_container_info=test_container_info)
        return test_environment_info

    def start_database(self, attempt):
        network_info_dict = yield from self.create_network(attempt)
        database_info, database_info_dict, \
        test_container_info, test_container_info_dict = \
            yield from self.spawn_database_and_test_container(network_info_dict, attempt)
        is_database_ready = yield from self.wait_for_database(
            database_info_dict, test_container_info_dict, attempt)
        return database_info, is_database_ready, test_container_info

    def create_network(self, attempt):
        docker_network_output = \
            yield PrepareDockerNetworkForTestEnvironment(
                environment_name=self.environment_name,
                test_container_name=self.test_container_name,
                db_container_name=self.db_container_name,
                network_name=self.network_name,
                reuse=self.reuse_database,
                attempt=attempt
            )
        network_info, network_info_dict = \
            self.get_network_info(docker_network_output)
        return network_info_dict

    def wait_for_database(self, database_info_dict, test_container_info_dict,
                          attempt):
        database_ready_target = \
            yield WaitForTestDockerDatabase(environment_name=self.environment_name,
                                            test_container_info_dict=test_container_info_dict,
                                            database_info_dict=database_info_dict,
                                            attempt=attempt)
        with database_ready_target.open("r") as file:
            is_database_ready = file.read() == str(True)
        return is_database_ready

    def setup_test_database(self, test_environment_info_dict):
        # TODO check if database is setup
        yield [
            UploadExaJDBC(
                environment_name=self.environment_name,
                test_environment_info_dict=test_environment_info_dict,
                reuse_uploaded=self.reuse_database_setup),
            UploadVirtualSchemaJDBCAdapter(
                environment_name=self.environment_name,
                test_environment_info_dict=test_environment_info_dict,
                reuse_uploaded=self.reuse_database_setup),
            PopulateEngineSmallTestDataToDatabase(
                environment_name=self.environment_name,
                test_environment_info_dict=test_environment_info_dict,
                reuse_data=self.reuse_database_setup)
        ]

    def spawn_database_and_test_container(self, network_info_dict, attempt):
        database_and_test_container_output = \
            yield {
                "test_container": SpawnTestContainer(
                    environment_name=self.environment_name,
                    test_container_name=self.test_container_name,
                    network_info_dict=network_info_dict,
                    ip_address_index_in_subnet=1,
                    reuse_test_container=self.reuse_test_container,
                    attempt=attempt),
                "database": SpawnTestDockerDatabase(
                    environment_name=self.environment_name,
                    db_container_name=self.db_container_name,
                    network_info_dict=network_info_dict,
                    ip_address_index_in_subnet=0,
                    database_port_forward=self.database_port_forward,
                    bucketfs_port_forward=self.bucketfs_port_forward,
                    reuse_database=self.reuse_database,
                    attempt=attempt
                )
            }
        test_container_info, test_container_info_dict = \
            self.get_test_container_info(database_and_test_container_output)
        database_info, database_info_dict = \
            self.get_database_info(database_and_test_container_output)
        return database_info, database_info_dict, \
               test_container_info, test_container_info_dict

    def get_network_info(self, network_info_target):
        network_info = \
            DependencyDockerNetworkInfoCollector().get_from_sinlge_input(network_info_target)
        network_info_dict = network_info.to_dict()
        return network_info, network_info_dict

    def get_test_container_info(self, input: Dict[str, Dict[str,
                                                            LocalTarget]]):
        container_info_of_dependencies = \
            DependencyContainerInfoCollector().get_from_dict_of_inputs(input)
        test_container_info = container_info_of_dependencies["test_container"]
        test_container_info_dict = test_container_info.to_dict()
        return test_container_info, test_container_info_dict

    def get_database_info(self, input: Dict[str, Dict[str, LocalTarget]]):
        database_info_of_dependencies = \
            DependencyDatabaseInfoCollector().get_from_dict_of_inputs(input)
        database_info = database_info_of_dependencies["database"]
        database_info_dict = database_info.to_dict()
        return database_info, database_info_dict

    def write_output(self, environment_info: EnvironmentInfo):
        with self.output()[ENVIRONMENT_INFO].open("w") as file:
            file.write(environment_info.to_json())
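
SpawnTestDockerEnvironment drives its helpers with luigi's dynamic dependencies: yielding a task (or a dict of tasks) from run suspends the current task until the dependency is complete, and the yield expression evaluates to that dependency's output target(s). A minimal sketch of the pattern with two hypothetical tasks:

import luigi

class ProduceNumber(luigi.Task):
    def output(self):
        return luigi.LocalTarget('number.txt')

    def run(self):
        with self.output().open('w') as f:
            f.write('42')

class ConsumeNumber(luigi.Task):
    def output(self):
        return luigi.LocalTarget('doubled.txt')

    def run(self):
        number_target = yield ProduceNumber()  # dynamic dependency, resolved while running
        with number_target.open('r') as f:
            value = int(f.read())
        with self.output().open('w') as f:
            f.write(str(value * 2))
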
コード例 #12
0
class ParentIdCollectTask(luigi.Task):
    '''Download tar file of csvs and append parent_ids to the organizations table.

    Args:
        date (datetime): Datetime used to label the outputs
        _routine_id (str): String used to label the AWS task
        db_config_env (str): The output database environment variable
        db_config_path (str): The output database configuration
        insert_batch_size (int): number of rows to insert into the db in a batch
    '''
    date = luigi.DateParameter()
    _routine_id = luigi.Parameter()
    test = luigi.BoolParameter()
    db_config_env = luigi.Parameter()
    db_config_path = luigi.Parameter()
    insert_batch_size = luigi.IntParameter(default=500)

    def requires(self):
        yield HealthLabelTask(date=self.date,
                              _routine_id=self._routine_id,
                              test=self.test,
                              insert_batch_size=self.insert_batch_size,
                              db_config_env=self.db_config_env,
                              bucket='nesta-crunchbase-models',
                              vectoriser_key='vectoriser.pickle',
                              classifier_key='clf.pickle')

    def output(self):
        '''Points to the output database engine'''
        db_config = misctools.get_config(self.db_config_path, "mysqldb")
        db_config["database"] = 'dev' if self.test else 'production'
        db_config["table"] = "Crunchbase <dummy>"  # Note, not a real table
        update_id = "CrunchbaseParentIdCollect_{}".format(self.date)
        return MySqlTarget(update_id=update_id, **db_config)

    def run(self):
        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)

        # collect file
        logging.info(f"Collecting org_parents from crunchbase tar")
        org_parents = get_files_from_tar(['org_parents'])[0]
        logging.info(f"{len(org_parents)} parent ids in crunchbase export")

        # collect previously processed orgs
        logging.info("Extracting previously processed organisations")
        with db_session(self.engine) as session:
            processed_orgs = session.query(Organization.id,
                                           Organization.parent_id).all()
        all_orgs = {org for (org, _) in processed_orgs}
        logging.info(f"{len(all_orgs)} total orgs in database")
        processed_orgs = {
            org
            for (org, parent_id) in processed_orgs if parent_id is not None
        }
        logging.info(f"{len(processed_orgs)} previously processed orgs")

        # reformat into a list of dicts, removing orgs that already have a parent_id
        # or are missing from the database
        org_parents = org_parents[['uuid', 'parent_uuid']]
        org_parents.columns = ['id', 'parent_id']
        org_parents = org_parents[org_parents['id'].isin(all_orgs)]
        org_parents = org_parents[~org_parents['id'].isin(processed_orgs)]
        org_parents = org_parents.to_dict(orient='records')
        logging.info(f"{len(org_parents)} organisations to update in MYSQL")

        # insert parent_ids into db in batches
        for count, batch in enumerate(
                split_batches(org_parents, self.insert_batch_size), 1):
            with db_session(self.engine) as session:
                session.bulk_update_mappings(Organization, batch)
            logging.info(
                f"{count} batch{'es' if count > 1 else ''} written to db")
            if self.test and count > 1:
                logging.info("Breaking after 2 batches while in test mode")
                break

        # mark as done
        logging.warning("Task complete")
        self.output().touch()
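
The batched insert above relies on a split_batches helper imported from the project's utilities; a minimal sketch of what such a helper could look like (illustrative only, not the project's actual implementation):

from typing import Iterable, Iterator, List

def split_batches(items: Iterable, batch_size: int) -> Iterator[List]:
    """Yield successive lists of at most batch_size items."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

# enumerate(split_batches(rows, 500), 1) then yields (1, first_batch), (2, second_batch), ...
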
コード例 #13
0
class SpawnTestContainer(StoppableTask):
    environment_name = luigi.Parameter()
    test_container_name = luigi.Parameter()
    network_info_dict = luigi.DictParameter(significant=False)
    ip_address_index_in_subnet = luigi.IntParameter(significant=False)
    attempt = luigi.IntParameter(1)
    reuse_test_container = luigi.BoolParameter(False, significant=False)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._client = docker_client_config().get_client()

        if self.ip_address_index_in_subnet < 0:
            raise Exception(
                "ip_address_index_in_subnet needs to be greater than 0 got %s"
                % self.ip_address_index_in_subnet)
        self._prepare_outputs()

    def _prepare_outputs(self):
        self._test_container_info_target = luigi.LocalTarget(
            "%s/info/environment/%s/test-container/%s/%s/container_info"
            % (build_config().output_directory,
               self.environment_name,
               self.test_container_name,
               self.attempt))
        if self._test_container_info_target.exists():
            self._test_container_info_target.remove()

    def output(self):
        return {CONTAINER_INFO: self._test_container_info_target}

    def requires_tasks(self):
        return {"test_container_image": DockerTestContainerBuild(),
                "export_directory": CreateExportDirectory()}

    def run_task(self):
        network_info = DockerNetworkInfo.from_dict(self.network_info_dict)
        subnet = netaddr.IPNetwork(network_info.subnet)
        ip_address = str(subnet[2 + self.ip_address_index_in_subnet])
        container_info = None
        if network_info.reused and self.reuse_test_container:
            container_info = self.try_to_reuse_test_container(ip_address, network_info)
        if container_info is None:
            container_info = self.create_test_container(ip_address, network_info)
        with self.output()[CONTAINER_INFO].open("w") as file:
            file.write(container_info.to_json())

    def try_to_reuse_test_container(self, ip_address: str, network_info: DockerNetworkInfo) -> ContainerInfo:
        self.logger.info("Task %s: Try to reuse test container %s",
                         self.__repr__(), self.test_container_name)
        container_info = None
        try:
            container_info = self.get_container_info(ip_address, network_info)
        except Exception as e:
            self.logger.warning("Task %s: Tried to reuse test container %s, but got Exeception %s. "
                                "Fallback to create new database.", self.__repr__(), self.test_container_name, e)
        return container_info

    def create_test_container(self, ip_address, network_info) -> ContainerInfo:
        test_container_image_info = self.get_test_container_image_info(self.input())
        # A later task which uses the test_container needs the exported container,
        # but to access the exported container from inside the test_container,
        # we need to mount the release directory into the test_container.
        exports_host_path = pathlib.Path(self.get_release_directory()).absolute()
        tests_host_path = pathlib.Path("./tests").absolute()
        test_container = \
            self._client.containers.create(
                image=test_container_image_info.get_target_complete_name(),
                name=self.test_container_name,
                network_mode=None,
                command="sleep infinity",
                detach=True,
                volumes={
                    exports_host_path: {
                        "bind": "/exports",
                        "mode": "ro"
                    },
                    tests_host_path: {
                        "bind": "/tests_src",
                        "mode": "ro"
                    }
                })
        self._client.networks.get(network_info.network_name).connect(test_container, ipv4_address=ip_address)
        test_container.start()
        test_container.exec_run(cmd="cp -r /tests_src /tests")
        container_info = self.get_container_info(ip_address, network_info)
        return container_info

    def get_container_info(self, ip_address, network_info: DockerNetworkInfo) -> ContainerInfo:
        test_container = self._client.containers.get(self.test_container_name)
        if test_container.status != "running":
            raise Exception(f"Container {self.test_container_name} not running")
        container_info = ContainerInfo(container_name=self.test_container_name,
                                       ip_address=ip_address,
                                       network_info=network_info)
        return container_info

    def get_release_directory(self):
        return pathlib.Path(self.input()["export_directory"].path).absolute().parent

    def get_test_container_image_info(self, input: Dict[str, LocalTarget]) -> ImageInfo:
        with input["test_container_image"].open("r") as f:
            jsonpickle.set_preferred_backend('simplejson')
            jsonpickle.set_encoder_options('simplejson', sort_keys=True, indent=4)
            decoded_infos = jsonpickle.decode(f.read())
        return decoded_infos["test-container"]["test-container"]
コード例 #14
0
        return significance + class_name + default + description

    luigi.parameter.Parameter.__repr__ = parameter_repr

    def assertIn(needle, haystack):
        """
        We test repr of Parameter objects, since it'll be used for readthedocs
        """
        assert needle in haystack

    # TODO: find a better place to put this!
    assertIn('IntParameter', repr(luigi.IntParameter()))
    assertIn('defaults to 37', repr(luigi.IntParameter(default=37)))
    assertIn('hi mom', repr(luigi.IntParameter(description='hi mom')))
    assertIn('Insignificant BoolParameter',
             repr(luigi.BoolParameter(significant=False)))
except ImportError:
    pass


def _warn_node(self, msg, node):
    """
    Mute warnings that are like ``WARNING: nonlocal image URI found: https://img. ...``

    Solution was found by googling, copied it from SO:

    http://stackoverflow.com/questions/12772927/specifying-an-online-image-in-sphinx-restructuredtext-format
    """
    if not msg.startswith('nonlocal image URI found:'):
        self._warnfunc(msg, '%s:%s' % get_source_line(node))
コード例 #15
0
class WritePainteraMetadata(luigi.Task):
    tmp_folder = luigi.Parameter()
    path = luigi.Parameter()
    raw_key = luigi.Parameter()
    label_group = luigi.Parameter()
    scale_factors = luigi.ListParameter()
    label_scale = luigi.IntParameter()
    is_label_multiset = luigi.BoolParameter()
    resolution = luigi.ListParameter()
    offset = luigi.ListParameter()
    max_id = luigi.IntParameter()
    dependency = luigi.TaskParameter()

    def _write_log(self, msg):
        log_file = self.output().path
        with open(log_file, 'a') as f:
            f.write('%s: %s\n' % (str(datetime.now()), msg))

    def requires(self):
        return self.dependency

    def _write_downsampling_factors(self, group):
        # get the actual scales we have in the segmentation
        scale_factors = [[1, 1, 1]] + list(
            self.scale_factors[self.label_scale + 1:])
        effective_scale = [1, 1, 1]
        # write the scale factors
        for scale, scale_factor in enumerate(scale_factors):

            # don't write attrs for the original dataset
            if scale == 0:
                continue

            ds = group['s%i' % scale]
            effective_scale = [
                sf * eff for sf, eff in zip(scale_factor, effective_scale)
            ]
            # we need to reverse the scale factors because paintera has axis order
            # XYZ and we have axis order ZYX
            ds.attrs['downsamplingFactors'] = effective_scale[::-1]

    def run(self):

        # compute the correct resolutions for raw data and labels
        label_resolution = [
            res * sf for res, sf in zip(self.resolution, self.scale_factors[
                self.label_scale])
        ]
        raw_resolution = self.resolution

        with z5py.File(self.path) as f:
            # write metadata for the top-level label group
            label_group = f[self.label_group]
            label_group.attrs['painteraData'] = {'type': 'label'}
            label_group.attrs['maxId'] = self.max_id
            # add the metadata referencing the label to block lookup
            scale_ds_pattern = os.path.join(self.label_group,
                                            'label-to-block-mapping', 's%d')
            label_group.attrs["labelBlockLookup"] = {
                "type": "n5-filesystem",
                "root": os.path.abspath(os.path.realpath(self.path)),
                "scaleDatasetPattern": scale_ds_pattern
            }
            # write metadata for the label-data group
            data_group = f[os.path.join(self.label_group, 'data')]
            data_group.attrs['maxId'] = self.max_id
            data_group.attrs['multiScale'] = True
            # we reverse resolution and offset because java n5 uses axis
            # convention XYZ and we use ZYX
            data_group.attrs['offset'] = self.offset[::-1]
            data_group.attrs['resolution'] = label_resolution[::-1]
            data_group.attrs['isLabelMultiset'] = self.is_label_multiset
            self._write_downsampling_factors(data_group)
            # add metadata for unique labels group
            unique_group = f[os.path.join(self.label_group, 'unique-labels')]
            unique_group.attrs['multiScale'] = True
            self._write_downsampling_factors(unique_group)
            # add metadata for label to block mapping
            mapping_group = f[os.path.join(self.label_group,
                                           'label-to-block-mapping')]
            mapping_group.attrs['multiScale'] = True
            self._write_downsampling_factors(mapping_group)
            # add metadata for the raw data
            raw_group = f[self.raw_key]
            raw_group.attrs['resolution'] = raw_resolution[::-1]
        self._write_log('wrote metadata successfully')

    def output(self):
        return luigi.LocalTarget(
            os.path.join(self.tmp_folder, 'write_paintera_metadata.log'))
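
_write_downsampling_factors accumulates the per-scale factors into effective factors relative to the label scale and reverses them into paintera's XYZ axis order before writing the attribute. A short worked example with hypothetical scale factors:

scale_factors = [[1, 1, 1], [2, 2, 2], [2, 2, 2]]  # hypothetical, as sliced above
effective_scale = [1, 1, 1]
for scale, scale_factor in enumerate(scale_factors):
    if scale == 0:
        continue  # no attributes are written for the original dataset
    effective_scale = [sf * eff for sf, eff in zip(scale_factor, effective_scale)]
    print('s%i' % scale, effective_scale[::-1])  # what ends up in 'downsamplingFactors'
# s1 [2, 2, 2]
# s2 [4, 4, 4]
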
コード例 #16
0
class ContainerTask(luigi.Task):
    """Base class to run containers."""

    no_remove_finished = luigi.BoolParameter(
        default=False, description="Don't remove containers "
                                   "that finished successfully.")

    CLIENT = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._client = self.get_client()
        self._container = None
        self.u_name = None
        self._retry_count = 0
        self._log = []

    def get_client(self) -> ContainerClient:
        """Retrieve execution engine client.

        The client object will be saved to the private attribute `_client`.
        Override this method in case your client needs some special
        initialization.

        Returns
        -------
            client: object
                client to manage containers
        """
        return self.CLIENT()

    def get_logs(self):
        """Return container logs.

        This method returns container logs as a list of lines.
        It will only work if `log_stream` method is implemented.

        It is especially useful if a task does not create a file in a
        filesystem e.g. it might do some database operation in which
        case the log can be written to a file which would also serve
        as a flag that the task finished successfully. For this case
        you will have to override the `run` method to save the logs
        in the end.

        Returns
        -------
            logs: list of str
        """
        return self._log

    @property
    def name(self):
        """We name the resource with luigi's id.

        This id is based on a hash of a tasks parameters and helps avoid running
        the same task twice. If a task with this names already exists and failed
        it will append a 'retry-<NUMBER>' to the name.
        """
        task_id = self.task_id
        if len(task_id) > 53:
            name_components = task_id.split('_')
            name, param_hash = name_components[0], name_components[-1]
            return '-'.join([name[:43], param_hash]).lower()
        else:
            return task_id.lower().replace('_', '-')

    @property
    def command(self):
        """The command to be executed by the container."""
        raise NotImplementedError("Docker task must specify command")

    @property
    def image(self):
        """Which image to use to create the container."""
        raise NotImplementedError("Docker tasks must specify image")

    @property
    def labels(self):
        return dict(luigi_retries=str(self._retry_count),
                    luigi_task_id=self.name)

    @property
    def configuration(self):
        """Container configuration dictionary.

        Should return a dictionary as accepted by docker-py's run method;
        see https://docker-py.readthedocs.io/en/stable/containers.html for more
        information on which keys are accepted.

        Subclasses for other execution engines should translate it into the
        corresponding engine's format.
        """
        default = {'labels': self.labels}
        default.update({
            'environment': CONTAINER_TASK_ENV
        })
        default.update({
            'volumes': CONTAINER_TASK_VOLUMES
        })
        default.update({
            'network': CONTAINER_TASK_NET
        })
        return default

    def _set_name(self):
        if self._retry_count:
            self.u_name = '{}-retry-{}' \
                .format(self.name, self._retry_count + 1)
        else:
            self.u_name = self.name

    def run(self):
        """Actually submit and run task as a container."""
        try:
            self._run_and_track_task()
        finally:
            if self._container:
                self._client.stop_container(self._container)

    def _run_and_track_task(self):
        self._retry_count = self._client.get_retry_count(self.name)
        self._set_name()
        self._container = self._client.run_container(
            self.image,
            self.u_name,
            self.command,
            self.configuration,
        )
        self._log = []
        log_stream = self._client.log_generator(self._container)
        if log_stream is not None:
            for line in log_stream:
                self._log.append(line.decode().strip())
                getLogger('luigi-interface').info(self._log[-1])
                self.set_status_message('\n'.join(self._log))
        exit_info = self._client.get_exit_info(self._container)
        if exit_info[0] == 0:
            if not self.no_remove_finished:
                self._client.remove_container(self._container)
            self._container = None
        else:
            raise RuntimeError(
                "Container exited with status code:"
                " {} and msg: {}".format(exit_info[0],
                                         exit_info[1]))
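
ContainerTask delegates everything engine-specific to a client object exposing the methods called above (get_retry_count, run_container, log_generator, get_exit_info, remove_container, stop_container). The sketch below wires in a hypothetical local stand-in client and a concrete task to show how the pieces fit together; it is illustrative only, and it overrides configuration to avoid the CONTAINER_TASK_* globals that the snippet above expects to exist. Persisting get_logs() in run follows the suggestion in the get_logs docstring.

import subprocess
import luigi

class LocalShellClient:
    """Hypothetical stand-in for a ContainerClient that just runs the command locally."""

    def get_retry_count(self, name):
        return 0

    def run_container(self, image, name, command, configuration):
        return subprocess.Popen(command, stdout=subprocess.PIPE)

    def log_generator(self, proc):
        return iter(proc.stdout.readline, b'')

    def get_exit_info(self, proc):
        return proc.wait(), ''

    def remove_container(self, proc):
        pass

    def stop_container(self, proc):
        proc.terminate()

class HelloWorldTask(ContainerTask):
    CLIENT = LocalShellClient

    @property
    def image(self):
        return "alpine:3.19"  # ignored by the stand-in client

    @property
    def command(self):
        return ["echo", "hello from a container task"]

    @property
    def configuration(self):
        return {"labels": self.labels}  # skip the module-level env/volume/network globals

    def output(self):
        return luigi.LocalTarget("hello_world.log")

    def run(self):
        super().run()
        # the command writes no file of its own, so persist the collected log as the done marker
        with self.output().open("w") as f:
            f.write("\n".join(self.get_logs()))

# e.g. luigi.build([HelloWorldTask()], local_scheduler=True)
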
コード例 #17
0
ファイル: hists.py プロジェクト: cms-btv-pog/jet-tagging-sf
class WriteHistograms(DatasetTask, AnalysisSandboxTask, GridWorkflow,
                      law.LocalWorkflow):

    iteration = luigi.IntParameter(default=0,
                                   description="iteration of the scale factor "
                                   "calculation, starting at zero, default: 0")
    final_it = luigi.BoolParameter(
        description="Flag for the final iteration of the scale factor "
        "calculation.")
    variable_tags = CSVParameter(
        default=[],
        description="Only consider variables with one or more of "
        "the given tags. Use all if empty.")
    category_tags = CSVParameter(
        default=[],
        description="Only consider categories whose top-level "
        "category has one or more of the given tags. Use all if empty.")
    used_shifts = CSVParameter(
        default=[]
    )  # needs to be named differently from the wrapper task parameter
    binning = CSVParameter(
        default=[],
        cls=luigi.FloatParameter,
        description="Overwrite default binning "
        "of variables. If exactly three values are provided, they are interpreted as a tuple of (n_bins, min, max)."
    )

    b_tagger = luigi.Parameter(default="deepcsv",
                               description="Name of the b-tagger to use.")
    optimize_binning = luigi.BoolParameter(
        description="Use optimized discriminant binning.")

    file_merging = "trees"

    workflow_run_decorators = [law.decorator.notify]

    sandbox = "singularity::/cvmfs/singularity.opensciencegrid.org/cmssw/cms:rhel7-m20200612"
    req_sandbox = "slc7"

    def __init__(self, *args, **kwargs):
        super(WriteHistograms, self).__init__(*args, **kwargs)
        # set shifts
        if self.dataset_inst.is_data:
            shifts = {"nominal"}
        else:
            jes_sources = self.config_inst.get_aux("jes_sources_{}".format(
                self.config_inst.get_aux("jes_scheme")))
            shifts = {"nominal"} | format_shifts(jes_sources, prefix="jes")
            if self.iteration > 0:
                shifts = shifts | format_shifts([
                    "lf", "hf", "lf_stats1", "lf_stats2", "hf_stats1",
                    "hf_stats2"
                ])
                if self.final_it:  # add c shifts
                    shifts = shifts | format_shifts(["c_stats1", "c_stats2"])

        if len(self.used_shifts) == 0:
            self.shifts = shifts
        elif any([shift not in shifts for shift in self.used_shifts]):
            raise ValueError("Unknown shift in {}".format(self.used_shifts))
        else:
            self.shifts = self.used_shifts

    def workflow_requires(self):
        from analysis.tasks.measurement import BundleScaleFactors

        reqs = super(WriteHistograms, self).workflow_requires()

        if not self.cancel_jobs and not self.cleanup_jobs:
            reqs["meta"] = MergeMetaData.req(
                self,
                version=self.get_version(MergeMetaData),
                _prefer_cli=["version"])
            if self.dataset_inst.is_mc:
                reqs["pu"] = CalculatePileupWeights.req(self)
            if not self.pilot:
                reqs["tree"] = MergeTrees.req(
                    self,
                    cascade_tree=-1,
                    version=self.get_version(MergeTrees),
                    _prefer_cli=["version"])
            if self.iteration > 0:
                reqs["sf"] = BundleScaleFactors.req(
                    self,
                    iteration=self.iteration - 1,
                    fix_normalization=self.final_it,
                    include_cshifts=self.final_it,
                    version=self.get_version(BundleScaleFactors),
                    _prefer_cli=["version"])
            if self.optimize_binning:
                from analysis.tasks.util import OptimizeBinning  # prevent circular import
                reqs["binning"] = OptimizeBinning.req(
                    self,
                    version=self.get_version(OptimizeBinning),
                    _prefer_cli=["version"])

        return reqs

    def requires(self):
        from analysis.tasks.measurement import BundleScaleFactors

        reqs = {
            "tree":
            MergeTrees.req(self,
                           cascade_tree=self.branch,
                           branch=0,
                           version=self.get_version(MergeTrees),
                           _prefer_cli=["version", "workflow"]),
            "meta":
            MergeMetaData.req(self,
                              version=self.get_version(MergeMetaData),
                              _prefer_cli=["version"]),
        }
        if self.dataset_inst.is_mc:
            reqs["pu"] = CalculatePileupWeights.req(self)
        if self.iteration > 0:
            reqs["sf"] = BundleScaleFactors.req(
                self,
                iteration=self.iteration - 1,
                fix_normalization=self.final_it,
                include_cshifts=self.final_it,
                version=self.get_version(BundleScaleFactors),
                _prefer_cli=["version"])
        if self.optimize_binning:
            from analysis.tasks.util import OptimizeBinning  # prevent circular import
            reqs["binning"] = OptimizeBinning.req(
                self,
                version=self.get_version(OptimizeBinning),
                _prefer_cli=["version"])
        return reqs

    def store_parts(self):
        binning_part = "optimized" if self.optimize_binning else "default"
        variable_part = "_".join(
            self.variable_tags) if self.variable_tags else "all"
        shift_part = "_".join(self.used_shifts) if self.used_shifts else "all"
        return super(WriteHistograms, self).store_parts() + (self.b_tagger,) + (self.iteration,) \
            + (variable_part,) + (shift_part,) + (binning_part,)

    def output(self):
        return self.wlcg_target("hists_{}.root".format(self.branch))

    def get_jec_identifier(self, shift):
        if shift.startswith("jes"):
            return "_" + shift
        else:
            return ""

    def get_pileup_weighter(self, inp):
        with inp.load() as pu_file:
            pu_hist = pu_file.Get("pileup_weights")
            pu_values = [
                pu_hist.GetBinContent(i)
                for i in range(1,
                               pu_hist.GetNbinsX() + 1)
            ]
            pu_values = [
                value if (value < 1000) else 1. for value in pu_values
            ]  # TODO: Temporary, due to high pu weights in 2018 data

        def add_branch(extender):
            extender.add_branch("pu_weight", unpack="pu")

        def add_value(entry):
            # some events have inf pileup, skip them
            weight = 1.
            pu = entry.pu[0]
            if np.isfinite(pu):
                pu_idx = int(pu) - 1
                if 0 <= pu_idx < len(pu_values):
                    weight = pu_values[pu_idx]
            entry.pu_weight[0] = weight

        return add_branch, add_value

    def get_scale_factor_weighter(self, inp, shift, nominal_sfs=None):
        sf_hists = {}
        input_files = [inp]
        # c scale factor files have no histograms for hf/lf, so use nominal ones
        if nominal_sfs is not None:
            input_files.append(nominal_sfs)

        for input_file in input_files:
            with input_file.load() as sfs:
                shift_dir = sfs.Get(shift)
                for category in shift_dir.GetListOfKeys():
                    category_dir = shift_dir.Get(category.GetName())
                    hist = category_dir.Get("sf")
                    # decouple from open file
                    hist.SetDirectory(0)

                    if category.GetName() not in sf_hists:
                        sf_hists[category.GetName()] = hist
                    else:
                        raise KeyError("Duplicate category {} in scale factor "
                                       "weighter.".format(category.GetName()))

        btag_var = self.config_inst.get_aux("btaggers")[
            self.b_tagger]["variable"]
        identifier = self.get_jec_identifier(shift)

        def add_branch(extender):
            unpack_vars = sum([[
                "jet{}_pt{}".format(idx, identifier), "jet{}_flavor{}".format(
                    idx, identifier), "jet{}_eta{}".format(idx, identifier),
                "jet{}_{}{}".format(idx, btag_var, identifier)
            ] for idx in range(1, 5)], [])
            extender.add_branch("scale_factor_lf_{}".format(shift),
                                unpack=unpack_vars)
            extender.add_branch("scale_factor_c_{}".format(shift),
                                unpack=unpack_vars)
            extender.add_branch("scale_factor_hf_{}".format(shift),
                                unpack=unpack_vars)

        def add_value(entry):
            scale_factor_lf = 1.
            scale_factor_c = 1.
            scale_factor_hf = 1.
            for jet_idx in range(1, 5):
                jet_pt = getattr(entry,
                                 "jet{}_pt{}".format(jet_idx, identifier))[0]
                jet_eta = getattr(entry,
                                  "jet{}_eta{}".format(jet_idx, identifier))[0]
                jet_flavor = getattr(
                    entry, "jet{}_flavor{}".format(jet_idx, identifier))[0]
                jet_btag = getattr(
                    entry, "jet{}_{}{}".format(jet_idx, btag_var,
                                               identifier))[0]

                # stop when number of jets is exceeded
                if jet_flavor < -999.:
                    break

                # find category in which the scale factor of the jet was computed to get correct histogram
                if abs(jet_flavor) == 5:
                    region = "hf"
                elif abs(jet_flavor) == 4:
                    region = "c"
                else:
                    region = "lf"

                # nominal c scale factors are 1
                if region == "c" and not shift.startswith("c_stat"):
                    continue

                category = self.category_getter.get_category(
                    jet_pt, abs(jet_eta), region)

                # get scale factor
                sf_hist = sf_hists[category.name]
                bin_idx = sf_hist.FindBin(jet_btag)
                scale_factor = sf_hist.GetBinContent(bin_idx)
                scale_factor = max([0., scale_factor])

                if abs(jet_flavor) == 5:
                    scale_factor_hf *= scale_factor
                elif abs(jet_flavor) == 4:
                    scale_factor_c *= scale_factor
                else:
                    scale_factor_lf *= scale_factor

            getattr(entry,
                    "scale_factor_lf_{}".format(shift))[0] = scale_factor_lf
            getattr(entry,
                    "scale_factor_c_{}".format(shift))[0] = scale_factor_c
            getattr(entry,
                    "scale_factor_hf_{}".format(shift))[0] = scale_factor_hf

        return add_branch, add_value

    @law.decorator.notify
    def run(self):
        import ROOT

        inp = self.input()
        outp = self.output()
        outp.parent.touch(0o0770)

        self.category_getter = CategoryGetter(self.config_inst, self.b_tagger)

        # get child categories
        categories = []

        for category in self.config_inst.categories:
            # only consider top-level categories with at least one given tag if specified
            if len(self.category_tags) > 0 and not category.has_tag(
                    self.category_tags, mode=any):
                continue
            # for intermediate iterations, skip merged categories not used for measurement
            # (to improve performance)
            if not self.final_it:
                if category.has_tag("merged") and not category.get_aux(
                        "phase_space") == "measure":
                    continue
            # recurse through all children of category, add leaf categories
            for cat, children in walk_categories(category):
                if not children:
                    # only use categories matching the task config
                    if cat.get_aux("config", None) != self.config_inst.name:
                        continue
                    # only use categories for the chosen b-tag algorithm
                    if cat.has_tag(self.b_tagger):
                        channel = cat.get_aux("channel")
                        categories.append((channel, cat))

        categories = list(set(categories))

        # get processes
        if len(self.dataset_inst.processes) != 1:
            raise NotImplementedError(
                "only datasets with exactly one linked process can be"
                " handled, got {}".format(len(self.dataset_inst.processes)))
        processes = list(self.dataset_inst.processes.values())

        # build a progress callback
        progress = self.create_progress_callback(len(categories))

        # open the output file
        with outp.localize("w") as tmp:
            with tmp.dump("RECREATE") as output_file:
                with self.publish_step(
                        "creating root output file directories ..."):
                    process_dirs = {}
                    for _, category in categories:
                        output_file.cd()
                        category_dir = output_file.mkdir(category.name)
                        for process in processes:
                            category_dir.cd()
                            process_dir = category_dir.mkdir(process.name)
                            process_dir.Write()
                            process_dirs[(category.name,
                                          process.name)] = process_dir

                # open the input file and get the tree
                # as we need to extend the tree with custom weights, we do not cache the file
                with inp["tree"].load("UPDATE", cache=False) as input_file:
                    tree = input_file.Get("tree")
                    self.publish_message("{} events in tree".format(
                        tree.GetEntries()))

                    # identifier for jec shifted variables
                    for shift in self.shifts:
                        jec_identifier = self.get_jec_identifier(shift)

                        # pt aliases for jets
                        for obj in ["jet1", "jet2", "jet3", "jet4"]:
                            tree.SetAlias(
                                "{0}_pt{1}".format(obj, jec_identifier),
                                "({0}_px{1}**2 + {0}_py{1}**2)**0.5".format(
                                    obj, jec_identifier))
                        # b-tagging alias
                        btag_var = self.config_inst.get_aux("btaggers")[
                            self.b_tagger]["variable"]
                        for obj in ["jet1", "jet2", "jet3", "jet4"]:
                            variable = self.config_inst.get_variable(
                                "{0}_{1}".format(obj, btag_var))
                            tree.SetAlias(
                                variable.name + jec_identifier,
                                variable.expression.format(
                                    **{"jec_identifier": jec_identifier}))
                    # pt aliases for leptons
                    for obj in ["lep1", "lep2"]:
                        tree.SetAlias(
                            "{0}_pt".format(obj),
                            "({0}_px**2 + {0}_py**2)**0.5".format(obj))

                    # extend the tree
                    if self.dataset_inst.is_mc:
                        with self.publish_step(
                                "extending the input tree with weights ..."):
                            weighters = []

                            # pileup weight
                            weighters.append(
                                self.get_pileup_weighter(inp["pu"]))

                            # weights from previous iterations
                            if self.iteration > 0:
                                # b-tagging scale factors
                                for shift in self.shifts:
                                    nominal_sfs = inp["sf"]["nominal"]["sf"] if shift.startswith("c_stat") \
                                        else None
                                    weighters.append(
                                        self.get_scale_factor_weighter(
                                            inp["sf"],
                                            shift,
                                            nominal_sfs=nominal_sfs))

                            input_file.cd()
                            with TreeExtender(tree) as te:
                                for add_branch, _ in weighters:
                                    add_branch(te)
                                for i, entry in enumerate(te):
                                    if (i % 1000) == 0:
                                        print "event {}".format(i)
                                    for _, add_value in weighters:
                                        add_value(entry)

                        # read in total number of events
                        sum_weights = inp["meta"].load(
                        )["event_weights"]["sum"]

                    # get category-dependent binning if optimized binning is used
                    # only for b-tagging discriminants
                    if self.optimize_binning:
                        category_binnings = inp["binning"].load()

                    for i, (channel, category) in enumerate(categories):
                        self.publish_message(
                            "writing histograms in category {} ({}/{})".format(
                                category.name, i + 1, len(categories)))

                        # get the region (HF / LF)
                        # not all child categories have regions associated, e.g. the phase space
                        # inclusive regions ("measure", "closure")
                        region = category.get_aux("region", None)

                        # set weights that are common for all shifts
                        base_weights = []
                        if self.dataset_inst.is_mc:
                            base_weights.append("gen_weight")
                            # lumi weight
                            lumi = self.config_inst.get_aux("lumi")[channel]
                            x_sec = process.get_xsec(
                                self.config_inst.campaign.ecm).nominal
                            lumi_weight = lumi * x_sec / sum_weights
                            base_weights.append(str(lumi_weight))

                            # pu weight
                            base_weights.append("pu_weight")

                        for process in processes:
                            # change into the correct directory
                            process_dirs[(category.name, process.name)].cd()
                            for shift in self.shifts:
                                jec_identifier = self.get_jec_identifier(shift)

                                # weights
                                weights = base_weights[:]
                                if self.dataset_inst.is_mc:
                                    # channel scale weight
                                    if self.iteration > 0:
                                        # b-tag scale factor weights
                                        phase_space = category.get_aux(
                                            "phase_space", None)
                                        # In measurement categories,
                                        # apply scale factors only for contamination
                                        if phase_space == "measure" and not self.final_it:
                                            weights.append(
                                                "scale_factor_c_{}".format(
                                                    shift))
                                            if region == "hf":
                                                weights.append(
                                                    "scale_factor_lf_{}".
                                                    format(shift))
                                            elif region == "lf":
                                                weights.append(
                                                    "scale_factor_hf_{}".
                                                    format(shift))
                                            elif region == "cont":
                                                weights.append(
                                                    "scale_factor_lf_{}".
                                                    format(shift))
                                                weights.append(
                                                    "scale_factor_hf_{}".
                                                    format(shift))
                                            else:
                                                raise ValueError(
                                                    "Unexpected region {}".
                                                    format(region))
                                        else:
                                            weights.append(
                                                "scale_factor_lf_{}".format(
                                                    shift))
                                            weights.append(
                                                "scale_factor_c_{}".format(
                                                    shift))
                                            weights.append(
                                                "scale_factor_hf_{}".format(
                                                    shift))

                                # totalWeight alias
                                while len(weights) < 2:
                                    weights.insert(0, "1")
                                tree.SetAlias(
                                    "totalWeight",
                                    join_root_selection(weights, op="*"))

                                # actual projecting
                                for variable in self.config_inst.variables:
                                    # save variable binning to reset at end of loop
                                    base_variable_binning = variable.binning

                                    if variable.has_tag("skip_all"):
                                        continue
                                    if region and variable.has_tag(
                                            "skip_{}".format(region)):
                                        continue
                                    # if variable tags is given, require at least one
                                    if len(self.variable_tags
                                           ) > 0 and not variable.has_tag(
                                               self.variable_tags, mode=any):
                                        continue
                                    # do not write one b-tag discriminant in the category of another
                                    if variable.get_aux(
                                            "b_tagger",
                                            self.b_tagger) != self.b_tagger:
                                        continue

                                    # if number of bins is specified, overwrite variable binning
                                    if self.binning:
                                        self.binning = list(self.binning)
                                        # if a tuple of (n_bins, x_min, x_max) is given, ensure that n_bins is an integer
                                        if len(self.binning) == 3:
                                            self.binning[0] = int(
                                                self.binning[0])
                                            self.binning = tuple(self.binning)

                                        variable.binning = self.binning

                                    # use optimized binning for b-tag discriminants if provided
                                    if self.optimize_binning and variable.get_aux(
                                            "can_optimize_bins", False):
                                        binning_category = category.get_aux(
                                            "binning_category", category)
                                        # overwrite binning if specialized binning is defined for this category
                                        variable.binning = category_binnings.get(
                                            binning_category.name,
                                            variable.binning)

                                    hist = ROOT.TH1F(
                                        "{}_{}".format(variable.name, shift),
                                        variable.full_title(root=True),
                                        variable.n_bins,
                                        array.array("d", variable.bin_edges))
                                    hist.Sumw2()

                                    # build the full selection string, including the total event weight
                                    selection = [
                                        category.selection,
                                        "jetmet_pass{jec_identifier} == 1",
                                        "{} != -10000".format(
                                            variable.expression),
                                    ]
                                    if variable.selection:
                                        selection.append(variable.selection)
                                    selection = join_root_selection(
                                        selection).format(
                                            **
                                            {"jec_identifier": jec_identifier})
                                    selection = join_root_selection(
                                        selection, "totalWeight", op="*")

                                    # project and write the histogram
                                    tree.Project(
                                        "{}_{}".format(variable.name, shift),
                                        variable.expression.format(
                                            **
                                            {"jec_identifier": jec_identifier
                                             }), selection)
                                    hist.Write()
                                    variable.binning = base_variable_binning

                        progress(i)
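
For orientation, the histogram projection above boils down to a small amount of PyROOT: book a TH1F with explicit bin edges, enable Sumw2, project the tree with a weighted selection, and write the result. A minimal standalone sketch (file, tree and branch names are placeholders, not taken from the original task):

import array

import ROOT

# open the input file and fetch the tree (placeholder names)
infile = ROOT.TFile.Open("input.root")
tree = infile.Get("tree")

# book a histogram with explicit bin edges and per-bin sum-of-weights errors
bin_edges = [0.0, 50.0, 100.0, 200.0, 400.0]
hist = ROOT.TH1F("jet_pt_nominal", ";jet p_{T} [GeV];Entries",
                 len(bin_edges) - 1, array.array("d", bin_edges))
hist.Sumw2()

# fill the histogram: selection string multiplied by the event weight
tree.Project("jet_pt_nominal", "jet_pt", "(jet_pt > 0) * totalWeight")

# write the histogram into a fresh output file
outfile = ROOT.TFile.Open("hists.root", "RECREATE")
hist.Write()
outfile.Close()
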
Code Example #18
class PrepareDockerNetworkForTestEnvironment(StoppableTask):
    logger = logging.getLogger('luigi-interface')

    environment_name = luigi.Parameter()
    network_name = luigi.Parameter()
    test_container_name = luigi.Parameter(significant=False)
    db_container_name = luigi.Parameter(significant=False)
    reuse = luigi.BoolParameter(False, significant=False)
    attempt = luigi.IntParameter(-1)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._client = docker_client_config().get_client()
        self._low_level_client = docker_client_config().get_low_level_client()
        self._prepare_outputs()

    def _prepare_outputs(self):
        self._network_info_target = luigi.LocalTarget(
            "%s/info/environment/%s/network/%s/%s/network_info" %
            (build_config().output_directory, self.environment_name,
             self.network_name, self.attempt))
        if self._network_info_target.exists():
            self._network_info_target.remove()

    def output(self):
        return {DOCKER_NETWORK_INFO: self._network_info_target}

    def run_task(self):
        self.network_info = None
        if self.reuse:
            self.logger.info("Task %s: Try to reuse network %s",
                             self.__repr__(), self.network_name)
            try:
                self.network_info = self.reuse_network()
            except Exception as e:
                self.logger.warning(
                    "Task %s: Tried to reuse network %s, but got Exeception %s. "
                    "Fallback to create new network.", self.__repr__(),
                    self.network_name, e)
        if self.network_info is None:
            self.network_info = self.create_docker_network()
        self.write_output(self.network_info)

    def write_output(self, network_info: DockerNetworkInfo):
        with self.output()[DOCKER_NETWORK_INFO].open("w") as file:
            file.write(network_info.to_json())

    def reuse_network(self) -> DockerNetworkInfo:
        self.remove_container(self.test_container_name)
        return self.get_network_info(reused=True)

    def get_network_info(self, reused: bool):
        network_properties = self._low_level_client.inspect_network(
            self.network_name)
        network_config = network_properties["IPAM"]["Config"][0]
        return DockerNetworkInfo(network_name=self.network_name,
                                 subnet=network_config["Subnet"],
                                 gateway=network_config["Gateway"],
                                 reused=reused)

    def create_docker_network(self) -> DockerNetworkInfo:
        self.remove_container(self.test_container_name)
        self.remove_container(self.db_container_name)
        self.remove_network(self.network_name)
        network = self._client.networks.create(
            name=self.network_name,
            driver="bridge",
        )
        network_info = self.get_network_info(reused=False)
        subnet = network_info.subnet
        gateway = network_info.gateway
        ipam_pool = docker.types.IPAMPool(subnet=subnet, gateway=gateway)
        ipam_config = docker.types.IPAMConfig(pool_configs=[ipam_pool])
        self.remove_network(
            self.network_name)  # TODO race condition possible, add retry
        network = self._client.networks.create(name=self.network_name,
                                               driver="bridge",
                                               ipam=ipam_config)
        return network_info

    def remove_network(self, network_name):
        try:
            self._client.networks.get(network_name).remove()
            self.logger.info("Task %s: Removed network %s", self.__repr__(),
                             network_name)
        except docker.errors.NotFound:
            pass

    def remove_container(self, container_name: str):
        try:
            container = self._client.containers.get(container_name)
            container.remove(force=True)
            self.logger.info("Task %s: Removed container %s", self.__repr__(),
                             container_name)
        except docker.errors.NotFound:
            pass
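
The create/inspect/remove/recreate sequence in create_docker_network above exists to pin the subnet and gateway that Docker chose for the bridge network. Stripped of the luigi machinery, the same docker-py calls look roughly like this (the network name is a placeholder and a local Docker daemon is assumed):

import docker

client = docker.from_env()
low_level = docker.APIClient()

# let Docker pick a subnet for a throwaway bridge network, then read it back
network = client.networks.create(name="example-net", driver="bridge")
ipam_cfg = low_level.inspect_network("example-net")["IPAM"]["Config"][0]
subnet, gateway = ipam_cfg["Subnet"], ipam_cfg["Gateway"]

# recreate the network with the subnet and gateway pinned explicitly
network.remove()
ipam_pool = docker.types.IPAMPool(subnet=subnet, gateway=gateway)
ipam_config = docker.types.IPAMConfig(pool_configs=[ipam_pool])
client.networks.create(name="example-net", driver="bridge", ipam=ipam_config)
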
Code Example #19
class HTCondorWorkflow(law.htcondor.HTCondorWorkflow):
    """
    Custom htcondor workflow with good default configs for the CERN batch system.
    """

    poll_interval = luigi.FloatParameter(
        default=0.5,
        significant=False,
        description="time between "
        "status polls in minutes, default: 0.5")
    transfer_logs = luigi.BoolParameter(
        default=True,
        significant=False,
        description="transfer job "
        "logs to the output directory, default: True")
    only_missing = luigi.BoolParameter(
        default=True,
        significant=False,
        description="skip tasks "
        "that are considered complete, default: True")
    max_runtime = law.DurationParameter(
        default=2.0,
        unit="h",
        significant=False,
        description="maximum runtime, default unit is hours, default: 2")
    cmst3 = luigi.BoolParameter(default=False,
                                significant=False,
                                description="use the CMS T3 "
                                "HTCondor quota for jobs, default: False")

    htcondor_skip_cmssw = False

    workflow_run_decorators = [
        law.decorator.notify, law.wlcg.ensure_voms_proxy
    ]

    def htcondor_bootstrap_file(self):
        return os.path.expandvars("$HGC_BASE/files/remote_bootstrap.sh")

    def htcondor_wrapper_file(self):
        return os.path.expandvars(
            "$HGC_BASE/files/cern_htcondor_bash_wrapper.sh")

    def htcondor_output_directory(self):
        return self.local_target(dir=True)

    def htcondor_use_local_scheduler(self):
        return True

    def htcondor_workflow_requires(self):
        reqs = law.htcondor.HTCondorWorkflow.htcondor_workflow_requires(self)

        reqs["repo"] = UploadRepo.req(self, replicas=5)
        reqs["software"] = UploadSoftware.req(self, replicas=5)

        if not self.htcondor_skip_cmssw:
            reqs["cmssw"] = UploadCMSSW.req(self, replicas=5)

        return reqs

    def htcondor_job_config(self, config, job_num, branches):
        reqs = self.htcondor_workflow_requires()

        # add input files
        config.input_files.append(
            law.util.law_src_path("contrib/wlcg/scripts/law_wlcg_tools.sh"))

        # helper to get all possible variants of a directory url
        def uris(req_key):
            uris = reqs[req_key].output().dir.uri(cmd="filecopy",
                                                  return_all=True)
            return ",".join(uris)

        # add render variables
        config.render_variables["hgc_grid_user"] = os.getenv("HGC_GRID_USER")
        config.render_variables["hgc_repo_uri"] = uris("repo")
        config.render_variables["hgc_repo_name"] = os.path.basename(
            reqs["repo"].get_repo_path())
        config.render_variables["hgc_repo_checksum"] = reqs["repo"].checksum
        config.render_variables["hgc_software_uri"] = uris("software")
        config.render_variables["hgc_skip_cmssw"] = str(
            int(self.htcondor_skip_cmssw))

        if not self.htcondor_skip_cmssw:
            config.render_variables["hgc_cmssw_uri"] = uris("cmssw")
            config.render_variables[
                "hgc_cmssw_version"] = self.cmssw_sandbox.env["CMSSW_VERSION"]
            config.render_variables["hgc_cmssw_checksum"] = reqs[
                "cmssw"].checksum
            config.render_variables["hgc_scram_arch"] = self.cmssw_sandbox.env[
                "SCRAM_ARCH"]

        # add X509_USER_PROXY to input files
        user_proxy = law.wlcg.get_voms_proxy_file()
        if user_proxy and os.path.exists(user_proxy):
            config.input_files.append(user_proxy)

        # render variables
        config.render_variables["hgc_remote_type"] = "HTCONDOR"
        config.render_variables["hgc_env_path"] = os.getenv("PATH")

        # custom content
        config.custom_content.append(
            ("requirements", "(OpSysAndVer =?= \"CentOS7\")"))
        config.custom_content.append(("log", "/dev/null"))
        config.custom_content.append(
            ("+MaxRuntime", int(math.floor(self.max_runtime * 3600)) - 1))
        if self.cmst3:
            config.custom_content.append(
                ("+AccountingGroup", "group_u_CMST3.all"))

        return config
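
A concrete task using this mixin is not shown here; as a rough, hypothetical sketch (assuming law's usual create_branch_map/output/run contract and a base task that provides local_target), it could look like:

import law
import luigi


class ExampleHTCondorTask(HTCondorWorkflow, law.LocalWorkflow):
    """Toy branch workflow: each branch writes one small text file."""

    n_files = luigi.IntParameter(default=10)

    def create_branch_map(self):
        # one branch per output file
        return {i: i for i in range(self.n_files)}

    def output(self):
        return self.local_target("file_{}.txt".format(self.branch))

    def run(self):
        # ensure the output directory exists, then write via the target's path
        self.output().parent.touch()
        with open(self.output().path, "w") as f:
            f.write("branch {} done\n".format(self.branch))
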
Code Example #20
class NodeAssignmentTask(luigi.Task):
    path = luigi.Parameter()
    out_key = luigi.Parameter()
    config_path = luigi.Parameter()
    max_jobs = luigi.IntParameter()
    tmp_folder = luigi.Parameter()
    dependency = luigi.TaskParameter()
    # FIXME default does not work; this still needs to be specified
    time_estimate = luigi.IntParameter(default=10)
    run_local = luigi.BoolParameter(default=False)

    def requires(self):
        return self.dependency

    def run(self):
        from .. import util

        # copy the script to the temp folder and replace the shebang
        file_dir = os.path.dirname(os.path.abspath(__file__))
        script_path = os.path.join(self.tmp_folder, 'node_assignment.py')
        util.copy_and_replace(os.path.join(file_dir, 'node_assignment.py'),
                              script_path)

        # find the number of blocks
        with open(self.config_path) as f:
            config = json.load(f)
            block_shape = config['block_shape']
            n_threads = config['n_threads']

        f = z5py.File(self.path)
        ds = f[self.out_key]
        shape = ds.shape
        blocking = nifty.tools.blocking([0, 0, 0], shape, block_shape)
        n_blocks = blocking.numberOfBlocks

        n_jobs = min(n_blocks, self.max_jobs)

        # prepare the job
        config_path = os.path.join(self.tmp_folder,
                                   'node_assignment_config.json')
        with open(config_path, 'w') as f:
            json.dump({'n_threads': n_threads}, f)
        # submit the job
        command = '%s %s %i %s' % (script_path, self.tmp_folder, n_jobs,
                                   config_path)
        log_file = os.path.join(self.tmp_folder, 'logs', 'log_node_assignment')
        err_file = os.path.join(self.tmp_folder, 'error_logs',
                                'err_node_assignment.err')
        bsub_command = 'bsub -n %i -J node_assignment -We %i -o %s -e %s \'%s\'' % (
            n_threads, self.time_estimate, log_file, err_file, command)
        if self.run_local:
            subprocess.call([command], shell=True)
        else:
            subprocess.call([bsub_command], shell=True)

        # wait till all jobs are finished
        if not self.run_local:
            util.wait_for_jobs('papec')

        # check for correct execution
        out_path = self.output().path
        success = os.path.exists(out_path)
        if not success:
            raise RuntimeError("Compute node assignment failed")

    def output(self):
        return luigi.LocalTarget(
            os.path.join(self.tmp_folder, 'component_assignments.n5',
                         'assignments'))
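
A side note on the submission above: passing a single interpolated string through shell=True is fragile if any path contains spaces or quotes. A sketch of the same bsub call built as an argument list instead (a possible alternative, not the original code):

import subprocess


def submit_bsub(command, n_threads, time_estimate, log_file, err_file):
    """Submit `command` (a shell command string) through bsub using an argument list."""
    bsub_args = [
        "bsub",
        "-n", str(n_threads),
        "-J", "node_assignment",
        "-We", str(time_estimate),
        "-o", log_file,
        "-e", err_file,
        command,
    ]
    # check=True raises CalledProcessError if bsub itself fails
    subprocess.run(bsub_args, check=True)
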
Code Example #21
class HellingerDistance(luigi.Task):
    d = luigi.Parameter()
    tr_bias = luigi.FloatParameter()
    te_bias = luigi.FloatParameter()
    foldidx = luigi.IntParameter()
    outdir = luigi.Parameter()
    tr_frac = luigi.FloatParameter(default=.5)
    random_seed = luigi.IntParameter(default=1234)
    size = luigi.IntParameter(default=1000)

    n_topics = luigi.IntParameter()
    max_gpu = luigi.FloatParameter()
    model_params = luigi.Parameter(default="{}")
    log_params = luigi.Parameter(default="{}")
    fit_params = luigi.Parameter(default="{}")
    fit_on_both = luigi.BoolParameter(default=False)

    # one of TD_DNN, TD_VAE, TD_sProdLDA, TD_LDA
    topic_distrib_task = luigi.Parameter()
    
    def requires(self):
        return [
            DatasetPairs(
                d=self.d,
                tr_bias=self.tr_bias,
                te_bias=self.te_bias,
                foldidx=self.foldidx,
                outdir=self.outdir,
                tr_frac=self.tr_frac,
                random_seed=self.random_seed,
                size=self.size,
            ),
            getattr(td, self.topic_distrib_task)(
                d=self.d,
                tr_bias=self.tr_bias,
                te_bias=self.te_bias,
                foldidx=self.foldidx,
                outdir=self.outdir,
                tr_frac=self.tr_frac,
                random_seed=self.random_seed,
                size=self.size,
                n_topics=self.n_topics,
                max_gpu=self.max_gpu,
                model_params=self.model_params,
                log_params=self.log_params,
                fit_params=self.fit_params,
                fit_on_both=self.fit_on_both,
            ),
        ]

    def run(self):
        (tr_path, te_path), (tr_topic_path, te_topic_path) = self.input()

        # load required data
        req_data = []
        for path in [tr_path, te_path, tr_topic_path, te_topic_path]:
            with path.open("r") as fd:
                req_data.append(pickle.load(fd))
        d_tr, d_te, topic_tr, topic_te = req_data
        
        pxgz_delta = np.zeros(self.n_topics)
        topic_tr, topic_te = [_.argmax(axis=1) for _ in [topic_tr, topic_te]]
        
        for zi in range(self.n_topics):
            topic_tri = (topic_tr == zi).astype(int)
            topic_tei = (topic_te == zi).astype(int)

            if not np.any(topic_tri) and not np.any(topic_tei):
                pxgz_delta[zi] = np.nan
            else:
                pxgz_delta[zi] = pxgz_diff_hd(
                    d_tr.X, d_tr.y, topic_tri,
                    d_te.X, d_te.y, topic_tei,
                    2
                )
            
        r = {}
        r["pxgz_delta"] = pxgz_delta.tolist()
        r["corr_tr"] = d_tr.pearsonr[0]
        r["corr_te"] = d_te.pearsonr[0]
        r["bias_tr"] = d_tr.get_bias()
        r["bias_te"] = d_te.get_bias()
        
        with self.output().open("w") as fd:
            fd.write(json.dumps(r))
        
    def output(self):
        fname = "trbias={:.3f}_tebias={:.3f}_size={}_foldidx={}_trfrac={:.3f}.json".format(
            self.tr_bias, self.te_bias, self.size, self.foldidx, self.tr_frac
        )
        fpath = os.path.join(
            self.outdir, "pxgz_diff", self.topic_distrib_task, fname
        )
        return luigi.LocalTarget(fpath)
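
pxgz_diff_hd is defined elsewhere in the project; for orientation, a plain Hellinger distance between two empirical distributions (for example, normalized histograms over a shared binning) can be computed like this:

import numpy as np


def hellinger_distance(p, q):
    """Hellinger distance between two discrete distributions p and q.

    Both inputs are non-negative arrays over the same support; they are
    normalized to sum to one before the distance is computed.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p = p / p.sum()
    q = q / q.sum()
    return np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2))
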
Code Example #22
class TestContainer(FlavorTask):
    release_types = luigi.ListParameter(["Release"])
    generic_language_tests = luigi.ListParameter([])
    test_folders = luigi.ListParameter([])
    test_files = luigi.ListParameter([])
    test_restrictions = luigi.ListParameter([])
    languages = luigi.ListParameter([None])
    test_environment_vars = luigi.DictParameter({"TRAVIS": ""},
                                                significant=False)

    test_log_level = luigi.Parameter("critical", significant=False)
    reuse_database = luigi.BoolParameter(False, significant=False)
    reuse_uploaded_container = luigi.BoolParameter(False, significant=False)
    reuse_database_setup = luigi.BoolParameter(False, significant=False)
    reuse_test_container = luigi.BoolParameter(False, significant=False)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._prepare_outputs()
        stoppable_task = StoppableTask()
        if stoppable_task.failed_target.exists():
            print("removed failed target")
            stoppable_task.failed_target.remove()
        self.actual_release_types = [
            ReleaseType[release_type] for release_type in self.release_types
        ]

    def requires_tasks(self):
        return [
            self.generate_tasks_for_flavor(flavor_path, release_type)
            for flavor_path in self.actual_flavor_paths
            for release_type in self.actual_release_types
        ]

    def generate_tasks_for_flavor(self, flavor_path,
                                  release_type: ReleaseType):
        args = dict(flavor_path=flavor_path,
                    reuse_database=self.reuse_database,
                    reuse_uploaded_container=self.reuse_uploaded_container,
                    reuse_database_setup=self.reuse_database_setup,
                    reuse_test_container=self.reuse_test_container,
                    generic_language_tests=self.generic_language_tests,
                    test_folders=self.test_folders,
                    test_restrictions=self.test_restrictions,
                    log_level=self.test_log_level,
                    test_environment_vars=self.test_environment_vars,
                    languages=self.languages,
                    test_files=self.test_files,
                    release_type=release_type.name)
        return TestRunnerDBTestTask(**args)

    def _prepare_outputs(self):
        self._target = luigi.LocalTarget(
            "%s/logs/test-runner/db-test/tests/current" %
            (build_config().output_directory))
        if self._target.exists():
            self._target.remove()

    def output(self):
        return self._target

    def run(self):
        with self.output().open("w") as out_file:
            for release in self.input():
                # for in_target in releases:
                with release.open("r") as in_file:
                    out_file.write(in_file.read())
                    out_file.write("\n")
                    out_file.write(
                        "=================================================")
                    out_file.write("\n")
Code Example #23
class _FitModel(AutoLocalOutputMixin(base_path=LUIGI_COMPLETED_TARGETS_DIR),
                LoadInputDictMixin, DeleteDepsRecursively, luigi.Task, ABC):
    dataset_settings = luigi.DictParameter()
    samples_per_class = luigi.IntParameter()
    random_seed = luigi.Parameter()

    in_memory = luigi.BoolParameter(default=False)

    @abstractmethod
    def fit_model(self, negative_samples, positive_samples):
        """
        Given positive and negative samples return a fitted model
        Parameters
        """
        pass

    @classmethod
    def compute_classification_accuracy(cls, model, *samples):
        """
        Parameters
        ----------
        model : dict
            All the data that represents a fitted model
        samples : list of {'X':array, 'y':array}
            The samples on which we should compute statistical distance
        Returns
        -------
        float : the statistical distance

        """
        raise NotImplementedError()

    def requires(self):
        req = {}
        req['samples_positive'] = self.gen_samples_type(
            dataset_settings=self.dataset_settings,
            num_samples=self.samples_per_class,
            random_seed='training_positive_size{}_seed{}'.format(
                self.samples_per_class, self.random_seed),
            generate_positive_samples=True,
        )
        req['samples_negative'] = self.gen_samples_type(
            dataset_settings=self.dataset_settings,
            num_samples=self.samples_per_class,
            random_seed='training_negative_size{}_seed{}'.format(
                self.samples_per_class, self.random_seed),
            generate_positive_samples=False,
        )
        self.reqs_ = req
        if not self.in_memory:
            return req
        return {}

    def run(self):
        _input = self.compute_or_load_requirements()

        # We set the random seed since the model's fitter may use
        # numbers from the random stream.
        GenSample.set_simple_random_seed(sample_number=-1,
                                         random_seed=self.random_seed)
        model = self.fit_model(_input['samples_negative'],
                               _input['samples_positive'])
        self.output_ = model
        with self.output().open('wb') as f:
            dill.dump(model, f, 2)
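
A hypothetical concrete subclass (not part of the original code) illustrating the expected shape of fit_model, assuming each sample set is a dict with an 'X' array and that scikit-learn is available; the gen_samples_type attribute used in requires() is omitted here:

import numpy as np
from sklearn.linear_model import LogisticRegression


class _FitLogisticRegression(_FitModel):
    """Toy model: logistic regression separating positive from negative samples."""

    def fit_model(self, negative_samples, positive_samples):
        X = np.vstack([negative_samples['X'], positive_samples['X']])
        y = np.concatenate([
            np.zeros(len(negative_samples['X'])),
            np.ones(len(positive_samples['X'])),
        ])
        clf = LogisticRegression().fit(X, y)
        # the framework treats a fitted model as a plain dict of fitted objects
        return {'classifier': clf}
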
Code Example #24
class SegmentationWorkflowBase(WorkflowBase):
    input_path = luigi.Parameter()
    input_key = luigi.Parameter()
    # where to save the watersheds
    ws_path = luigi.Parameter()
    ws_key = luigi.Parameter()
    # where to save the problem (graph, edge_features etc)
    problem_path = luigi.Parameter()
    # where to save the node labels
    node_labels_key = luigi.Parameter()
    # where to save the resulting segmentation
    output_path = luigi.Parameter()
    output_key = luigi.Parameter()
    # optional path to mask
    mask_path = luigi.Parameter(default='')
    mask_key = luigi.Parameter(default='')

    # optional path for random forest used for cost computation
    rf_path = luigi.Parameter(default='')
    # node label dict: dictionary for additional node labels used in costs
    node_label_dict = luigi.DictParameter(default={})

    # number of jobs used for merge tasks
    max_jobs_merge = luigi.IntParameter(default=1)
    # skip watershed (the watershed volume must already be present)
    skip_ws = luigi.BoolParameter(default=False)
    # run agglomeration immediately after watersheds
    agglomerate_ws = luigi.BoolParameter(default=False)
    # run two-pass watershed
    two_pass_ws = luigi.BoolParameter(default=False)
    # run some sanity checks for intermediate results
    sanity_checks = luigi.BoolParameter(default=False)

    # hard-coded keys
    graph_key = 's0/graph'
    features_key = 'features'
    costs_key = 's0/costs'

    def _watershed_tasks(self):
        if self.skip_ws:
            assert os.path.exists(os.path.join(
                self.ws_path,
                self.ws_key)), "%s:%s" % (self.ws_path, self.ws_key)
            return self.dependency
        else:
            dep = WatershedWorkflow(tmp_folder=self.tmp_folder,
                                    max_jobs=self.max_jobs,
                                    config_dir=self.config_dir,
                                    target=self.target,
                                    dependency=self.dependency,
                                    input_path=self.input_path,
                                    input_key=self.input_key,
                                    output_path=self.ws_path,
                                    output_key=self.ws_key,
                                    mask_path=self.mask_path,
                                    mask_key=self.mask_key,
                                    two_pass=self.two_pass_ws,
                                    agglomeration=self.agglomerate_ws)
            return dep

    def _problem_tasks(self, dep, compute_costs):
        dep = ProblemWorkflow(tmp_folder=self.tmp_folder,
                              config_dir=self.config_dir,
                              max_jobs=self.max_jobs,
                              target=self.target,
                              dependency=dep,
                              input_path=self.input_path,
                              input_key=self.input_key,
                              ws_path=self.ws_path,
                              ws_key=self.ws_key,
                              problem_path=self.problem_path,
                              rf_path=self.rf_path,
                              node_label_dict=self.node_label_dict,
                              max_jobs_merge=self.max_jobs_merge,
                              compute_costs=compute_costs,
                              sanity_checks=self.sanity_checks)
        return dep

    def _write_tasks(self, dep, identifier):
        if self.output_key == '':
            return dep
        write_task = getattr(write_tasks, self._get_task_name('Write'))
        dep = write_task(tmp_folder=self.tmp_folder,
                         max_jobs=self.max_jobs,
                         config_dir=self.config_dir,
                         dependency=dep,
                         input_path=self.ws_path,
                         input_key=self.ws_key,
                         output_path=self.output_path,
                         output_key=self.output_key,
                         assignment_path=self.output_path,
                         assignment_key=self.node_labels_key,
                         identifier=identifier)
        return dep

    @staticmethod
    def get_config():
        config = {
            **WatershedWorkflow.get_config(),
            **ProblemWorkflow.get_config()
        }
        return config
Code Example #25
class GenerateSharesTask(tasks.PuppetTask, manifest_tasks.ManifestMixen):
    puppet_account_id = luigi.Parameter()
    manifest_file_path = luigi.Parameter()
    should_use_sns = luigi.BoolParameter()
    section = luigi.Parameter()
    cache_invalidator = luigi.Parameter()

    def params_for_results_display(self):
        return {
            "puppet_account_id": self.puppet_account_id,
            "manifest_file_path": self.manifest_file_path,
            "section": self.section,
            "cache_invalidator": self.cache_invalidator,
        }

    def requires(self):
        portfolios = dict()
        requirements = dict(
            deletes=list(),
            ensure_event_buses=list(),
            generate_policies=list(),
            portfolios=portfolios,
        )
        for region_name, accounts in self.manifest.get_accounts_by_region(
        ).items():
            requirements["deletes"].append(
                general_tasks.DeleteCloudFormationStackTask(
                    account_id=self.puppet_account_id,
                    region=region_name,
                    stack_name="servicecatalog-puppet-shares",
                ))

        for (
                region_name,
                sharing_policies,
        ) in self.manifest.get_sharing_policies_by_region().items():
            requirements["ensure_event_buses"].append(
                EnsureEventBridgeEventBusTask(
                    puppet_account_id=self.puppet_account_id,
                    region=region_name,
                ))

            requirements["generate_policies"].append(
                GeneratePolicies(
                    puppet_account_id=self.puppet_account_id,
                    manifest_file_path=self.manifest_file_path,
                    region=region_name,
                    sharing_policies=sharing_policies,
                    should_use_sns=self.should_use_sns,
                    cache_invalidator=self.cache_invalidator,
                ))

        for (
                region_name,
                shares_by_portfolio_account,
        ) in self.manifest.get_shares_by_region_portfolio_account(
                self.puppet_account_id, self.section).items():
            for (
                    portfolio_name,
                    shares_by_account,
            ) in shares_by_portfolio_account.items():
                for account_id, share in shares_by_account.items():
                    i = "_".join([
                        str(self.puppet_account_id),
                        portfolio_name,
                        str(account_id),
                        region_name,
                    ])
                    portfolios[
                        i] = portfoliomanagement_tasks.GetPortfolioByPortfolioName(
                            manifest_file_path=self.manifest_file_path,
                            puppet_account_id=self.puppet_account_id,
                            portfolio=portfolio_name,
                            account_id=self.puppet_account_id,
                            region=region_name,
                            cache_invalidator=self.cache_invalidator,
                        )
        return requirements

    def run(self):
        tasks = list()
        for (
                region_name,
                shares_by_portfolio_account,
        ) in self.manifest.get_shares_by_region_portfolio_account(
                self.puppet_account_id, self.section).items():
            for (
                    portfolio_name,
                    shares_by_account,
            ) in shares_by_portfolio_account.items():
                for account_id, share in shares_by_account.items():
                    i = "_".join([
                        str(self.puppet_account_id),
                        portfolio_name,
                        str(account_id),
                        region_name,
                    ])
                    portfolio_input = self.input().get("portfolios").get(i)

                    if portfolio_input is None:
                        raise Exception(
                            f"failed to get portfolios details for {i} in {self.input().get('portfolios')}"
                        )

                    portfolio = json.loads(portfolio_input.open("r").read())

                    tasks.append(
                        portfoliomanagement_tasks.
                        CreateShareForAccountLaunchRegion(
                            manifest_file_path=self.manifest_file_path,
                            puppet_account_id=self.puppet_account_id,
                            account_id=account_id,
                            region=region_name,
                            portfolio=portfolio_name,
                            portfolio_id=portfolio.get("portfolio_id"),
                            sharing_mode=share.get(self.section).get(
                                "sharing_mode",
                                config.get_global_sharing_mode_default(
                                    self.puppet_account_id),
                            ),
                        ))

        yield tasks
        self.write_output(self.params_for_results_display())
Code Example #26
class ProblemWorkflow(WorkflowBase):
    input_path = luigi.Parameter()
    input_key = luigi.Parameter()
    ws_path = luigi.Parameter()
    ws_key = luigi.Parameter()
    problem_path = luigi.Parameter()

    # optional params for costs
    rf_path = luigi.Parameter(default='')
    node_label_dict = luigi.DictParameter(default={})

    max_jobs_merge = luigi.IntParameter(default=1)
    # do we compute costs?
    compute_costs = luigi.BoolParameter(default=True)
    # do we run sanity checks ?
    sanity_checks = luigi.BoolParameter(default=False)

    # hard-coded keys
    graph_key = 's0/graph'
    features_key = 'features'
    costs_key = 's0/costs'

    def requires(self):
        dep = GraphWorkflow(tmp_folder=self.tmp_folder,
                            max_jobs=self.max_jobs,
                            config_dir=self.config_dir,
                            target=self.target,
                            dependency=self.dependency,
                            input_path=self.ws_path,
                            input_key=self.ws_key,
                            graph_path=self.problem_path,
                            output_key=self.graph_key,
                            n_scales=1)
        # sanity check the subgraph
        if self.sanity_checks:
            graph_block_prefix = os.path.join(self.problem_path, 's0',
                                              'sub_graphs', 'block_')
            dep = CheckSubGraphsWorkflow(tmp_folder=self.tmp_folder,
                                         max_jobs=self.max_jobs,
                                         config_dir=self.config_dir,
                                         target=self.target,
                                         ws_path=self.ws_path,
                                         ws_key=self.ws_key,
                                         graph_block_prefix=graph_block_prefix,
                                         dependency=dep)
        dep = EdgeFeaturesWorkflow(tmp_folder=self.tmp_folder,
                                   max_jobs=self.max_jobs,
                                   config_dir=self.config_dir,
                                   target=self.target,
                                   dependency=dep,
                                   input_path=self.input_path,
                                   input_key=self.input_key,
                                   labels_path=self.ws_path,
                                   labels_key=self.ws_key,
                                   graph_path=self.problem_path,
                                   graph_key=self.graph_key,
                                   output_path=self.problem_path,
                                   output_key=self.features_key,
                                   max_jobs_merge=self.max_jobs_merge)
        if self.compute_costs:
            dep = EdgeCostsWorkflow(tmp_folder=self.tmp_folder,
                                    max_jobs=self.max_jobs,
                                    config_dir=self.config_dir,
                                    target=self.target,
                                    dependency=dep,
                                    features_path=self.problem_path,
                                    features_key=self.features_key,
                                    output_path=self.problem_path,
                                    output_key=self.costs_key,
                                    node_label_dict=self.node_label_dict,
                                    seg_path=self.ws_path,
                                    seg_key=self.ws_key,
                                    rf_path=self.rf_path)
        return dep

    @staticmethod
    def get_config():
        config = {
            **GraphWorkflow.get_config(),
            **EdgeFeaturesWorkflow.get_config(),
            **EdgeCostsWorkflow.get_config()
        }
        return config
Code Example #27
class WithDefaultTrue(luigi.Task):
    x = luigi.BoolParameter(default=True)
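
For context on BoolParameter defaults like the one above: with luigi's default implicit parsing, supplying --x on the command line sets the value to True and omitting it keeps the default, so a parameter that defaults to True cannot be switched off from the CLI. Recent luigi versions offer explicit parsing for this case (a sketch, assuming such a version):

import luigi


class WithExplicitBool(luigi.Task):
    # accepts "--x true" and "--x false" on the command line
    x = luigi.BoolParameter(
        default=True,
        parsing=luigi.BoolParameter.EXPLICIT_PARSING,
    )

    def run(self):
        print("x is", self.x)
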
Code Example #28
class ConversionWorkflow(WorkflowBase):
    path = luigi.Parameter()
    raw_key = luigi.Parameter()
    label_in_key = luigi.Parameter()
    label_out_key = luigi.Parameter()
    label_scale = luigi.IntParameter()
    assignment_path = luigi.Parameter(default='')
    assignment_key = luigi.Parameter(default='')
    use_label_multiset = luigi.BoolParameter(default=False)
    offset = luigi.ListParameter(default=[0, 0, 0])
    resolution = luigi.ListParameter(default=[1, 1, 1])

    #####################################
    # Step 1 Implementations: make_labels
    #####################################

    def _link_labels(self, data_path, dependency):
        norm_path = os.path.abspath(os.path.realpath(self.path))
        src = os.path.join(norm_path, self.label_in_key)
        dst = os.path.join(data_path, 's0')
        # self._write_log("linking label dataset from %s to %s" % (src, dst))
        os.symlink(src, dst)
        return dependency

    # TODO implement
    def _make_label_multiset(self, dependency):
        raise NotImplementedError("Label multi-set not implemented yet")

    def _make_labels(self, dependency):

        # check if we have output labels already
        dst_key = os.path.join(self.label_out_key, 'data', 's0')
        with z5py.File(self.path) as f:
            assert self.label_in_key in f, "key %s not in input file" % self.label_in_key
            if dst_key in f:
                return dependency

        # we make the label output group
        with z5py.File(self.path) as f:
            g = f.require_group(self.label_out_key)
            dgroup = g.require_group('data')
            # resolve relative paths and links
            data_path = os.path.abspath(os.path.realpath(dgroup.path))

        # if we use label-multisets, create the label multiset for this scale;
        # otherwise, just symlink the input dataset to the output dataset
        return self._make_label_multiset(dependency) if self.use_label_multiset\
            else self._link_labels(data_path, dependency)

    ######################################
    # Step 2 Implementations: align scales
    ######################################

    # TODO implement for label-multi-set
    def _downsample_labels(self, downsample_scales, scale_factors, dependency):
        task = getattr(sampling_tasks, self._get_task_name('Downscaling'))

        # run downsampling
        in_key = os.path.join(self.label_out_key, 'data', 's0')
        dep = dependency

        effective_scale = [1, 1, 1]
        label_scales = range(1, len(downsample_scales) + 1)
        for scale, out_scale in zip(label_scales, downsample_scales):
            out_key = os.path.join(self.label_out_key, 'data', 's%i' % scale)
            scale_factor = scale_factors[out_scale]
            effective_scale = [
                eff * scf for eff, scf in zip(effective_scale, scale_factor)
            ]
            dep = task(tmp_folder=self.tmp_folder,
                       max_jobs=self.max_jobs,
                       config_dir=self.config_dir,
                       input_path=self.path,
                       input_key=in_key,
                       output_path=self.path,
                       output_key=out_key,
                       scale_factor=scale_factor,
                       scale_prefix='s%i' % scale,
                       effective_scale_factor=effective_scale,
                       dependency=dep)

            in_key = out_key
        return dep

    def _align_scales(self, dependency):
        # check which scales we have in the raw data
        raw_dir = os.path.join(self.path, self.raw_key)
        raw_scales = os.listdir(raw_dir)
        raw_scales = [
            rscale for rscale in raw_scales
            if os.path.isdir(os.path.join(raw_dir, rscale))
        ]

        def isint(inp):
            try:
                int(inp)
                return True
            except ValueError:
                return False

        raw_scales = np.array(
            [int(rscale[1:]) for rscale in raw_scales if isint(rscale[1:])])
        raw_scales = np.sort(raw_scales)

        # match the label scale and determine which scales we have to compute
        # via downsampling
        downsample_scales = raw_scales[self.label_scale + 1:]

        # load the scale factors from the raw dataset
        scale_factors = []
        relative_scale_factors = []
        with z5py.File(self.path) as f:
            for scale in raw_scales:
                scale_key = os.path.join(self.raw_key, 's%i' % scale)
                # we need to reverse the scale factors because paintera has axis order
                # XYZ and we have axis order ZYX
                if scale == 0:
                    scale_factors.append([1, 1, 1])
                    relative_scale_factors.append([1, 1, 1])
                else:
                    scale_factor = f[scale_key].attrs[
                        'downsamplingFactors'][::-1]
                    # find the relative scale factor
                    rel_scale = [
                        int(sf_out // sf_in) for sf_out, sf_in in zip(
                            scale_factor, scale_factors[-1])
                    ]

                    scale_factors.append(scale_factor)
                    relative_scale_factors.append(rel_scale)

        # downsample segmentations
        t_down = self._downsample_labels(downsample_scales,
                                         relative_scale_factors, dependency)
        return t_down, relative_scale_factors

    ############################################
    # Step 3 Implementations: make block uniques
    ############################################

    def _uniques_in_blocks(self, dependency, scale_factors):
        task = getattr(unique_tasks, self._get_task_name('UniqueBlockLabels'))
        # require the unique-labels group
        with z5py.File(self.path) as f:
            f.require_group(os.path.join(self.label_out_key, 'unique-labels'))
        dep = dependency

        effective_scale = [1, 1, 1]
        for scale, factor in enumerate(scale_factors):
            in_key = os.path.join(self.label_out_key, 'data', 's%i' % scale)
            out_key = os.path.join(self.label_out_key, 'unique-labels',
                                   's%i' % scale)
            effective_scale = [
                eff * sf for eff, sf in zip(effective_scale, factor)
            ]
            dep = task(tmp_folder=self.tmp_folder,
                       max_jobs=self.max_jobs,
                       config_dir=self.config_dir,
                       input_path=self.path,
                       output_path=self.path,
                       input_key=in_key,
                       output_key=out_key,
                       effective_scale_factor=effective_scale,
                       dependency=dep,
                       prefix='s%i' % scale)
        return dep

    ##############################################
    # Step 4 Implementations: invert block uniques
    ##############################################

    def _label_block_mapping(self, dependency, scale_factors):
        task = getattr(labels_to_block_tasks,
                       self._get_task_name('LabelBlockMapping'))
        # require the labels-to-blocks group
        with z5py.File(self.path) as f:
            f.require_group(
                os.path.join(self.label_out_key, 'label-to-block-mapping'))
        # get the fragment max id
        with z5py.File(self.path) as f:
            max_id = f[self.label_in_key].attrs['maxId']

        # compute the label to block mapping for all scales
        n_scales = len(scale_factors)
        dep = dependency
        for scale in range(n_scales):
            in_key = os.path.join(self.label_out_key, 'unique-labels',
                                  's%i' % scale)
            out_key = os.path.join(self.label_out_key,
                                   'label-to-block-mapping', 's%i' % scale)
            dep = task(tmp_folder=self.tmp_folder,
                       max_jobs=self.max_jobs,
                       config_dir=self.config_dir,
                       input_path=self.path,
                       output_path=self.path,
                       input_key=in_key,
                       output_key=out_key,
                       number_of_labels=max_id + 1,
                       dependency=dep,
                       prefix='s%i' % scale)
        return dep

    #####################################################
    # Step 5 Implementations: fragment segment assignment
    #####################################################

    def _fragment_segment_assignment(self, dependency):
        if self.assignment_path == '':
            # get the fragment max id
            with z5py.File(self.path) as f:
                max_id = f[self.label_in_key].attrs['maxId']
            return dependency, max_id
        else:
            assert self.assignment_key != ''
            assert os.path.exists(self.assignment_path), self.assignment_path
            # TODO should make this a task
            with z5py.File(self.assignment_path) as f, z5py.File(
                    self.path) as f_out:
                assignments = f[self.assignment_key][:]
                n_fragments = len(assignments)

                # find the fragments which have non-trivial assignment
                segment_ids, counts = np.unique(assignments,
                                                return_counts=True)
                seg_ids_to_counts = {
                    seg_id: count
                    for seg_id, count in zip(segment_ids, counts)
                }
                fragment_ids_to_counts = nt.takeDict(seg_ids_to_counts,
                                                     assignments)
                fragment_ids = np.arange(n_fragments, dtype='uint64')

                non_triv_fragments = fragment_ids[fragment_ids_to_counts > 1]
                non_triv_segments = assignments[non_triv_fragments]
                non_triv_segments += n_fragments

                # determine the overall max id
                max_id = int(non_triv_segments.max())

                # TODO do we need to assign a special value to ignore label (0) ?
                frag_to_seg = np.vstack(
                    (non_triv_fragments, non_triv_segments))

                # fragment_ids = np.arange(n_fragments, dtype='uint64')
                # assignments += n_fragments
                # frag_to_seg = np.vstack((fragment_ids, assignments))

                # max_id = int(frag_to_seg.max())

                out_key = os.path.join(self.label_out_key,
                                       'fragment-segment-assignment')
                chunks = (1, frag_to_seg.shape[1])
                f_out.require_dataset(out_key,
                                      data=frag_to_seg,
                                      shape=frag_to_seg.shape,
                                      compression='gzip',
                                      chunks=chunks)
            return dependency, max_id

    def requires(self):
        # first, we make the labels at label_out_key
        # (as label-multi-set if specified)
        dep = self._make_labels(self.dependency)
        # next, align the scales of labels and raw data
        dep, scale_factors = self._align_scales(dep)
        downsampling_factors = [[1, 1, 1]] + scale_factors[self.label_scale + 1:]

        # next, compute the mapping of unique labels to blocks
        dep = self._uniques_in_blocks(dep, downsampling_factors)
        # next, compute the inverse mapping
        dep = self._label_block_mapping(dep, downsampling_factors)
        # next, compute the fragment-segment-assignment
        dep, max_id = self._fragment_segment_assignment(dep)

        # finally, write metadata
        dep = WritePainteraMetadata(tmp_folder=self.tmp_folder,
                                    path=self.path,
                                    raw_key=self.raw_key,
                                    label_group=self.label_out_key,
                                    scale_factors=scale_factors,
                                    label_scale=self.label_scale,
                                    is_label_multiset=self.use_label_multiset,
                                    resolution=self.resolution,
                                    offset=self.offset,
                                    max_id=max_id,
                                    dependency=dep)
        return dep

    @staticmethod
    def get_config():
        configs = super(ConversionWorkflow, ConversionWorkflow).get_config()
        configs.update({
            'unique_block_labels':
            unique_tasks.UniqueBlockLabelsLocal.default_task_config(),
            'label_block_mapping':
            labels_to_block_tasks.LabelBlockMappingLocal.default_task_config(),
            'downscaling':
            sampling_tasks.DownscalingLocal.default_task_config()
        })
        return configs
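
As a small worked example of the relative scale factors derived in _align_scales: absolute downsamplingFactors of [1, 1, 1], [2, 2, 2] and [4, 4, 4] yield relative factors of [1, 1, 1], [2, 2, 2] and [2, 2, 2] between consecutive scales:

absolute = [[1, 1, 1], [2, 2, 2], [4, 4, 4]]
relative = [absolute[0]] + [
    [int(c_i // p_i) for c_i, p_i in zip(curr, prev)]
    for curr, prev in zip(absolute[1:], absolute[:-1])
]
print(relative)  # [[1, 1, 1], [2, 2, 2], [2, 2, 2]]
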
Code Example #29
    def testBool(self):
        p = luigi.BoolParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(True, _value(p))


class MergeSubgraphScalesTask(luigi.Task):
    """
    Merge subgraphs on scale level
    """

    path = luigi.Parameter()
    ws_key = luigi.Parameter()
    out_path = luigi.Parameter()
    scale = luigi.IntParameter()
    max_jobs = luigi.IntParameter()
    config_path = luigi.Parameter()
    tmp_folder = luigi.Parameter()
    dependency = luigi.TaskParameter()
    # FIXME default does not work; this still needs to be specified
    time_estimate = luigi.IntParameter(default=10)
    run_local = luigi.BoolParameter(default=False)

    def requires(self):
        return self.dependency

    def _prepare_jobs(self, n_jobs, block_list, block_shape):
        for job_id in range(n_jobs):
            block_jobs = block_list[job_id::n_jobs]
            job_config = {'block_shape': block_shape, 'block_list': block_jobs}
            config_path = os.path.join(
                self.tmp_folder,
                'graph_scale%i_config_job%i.json' % (self.scale, job_id))
            with open(config_path, 'w') as f:
                json.dump(job_config, f)

    def _submit_job(self, job_id):
        script_path = os.path.join(self.tmp_folder, 'merge_graph_scales.py')
        config_path = os.path.join(
            self.tmp_folder,
            'graph_scale%i_config_job%i.json' % (self.scale, job_id))
        command = '%s %s %i %i %s %s' % (script_path, self.out_path,
                                         self.scale, job_id, config_path,
                                         self.tmp_folder)
        log_file = os.path.join(self.tmp_folder, 'logs',
                                'log_graph%i_scale_%i' % (self.scale, job_id))
        err_file = os.path.join(
            self.tmp_folder, 'error_logs',
            'err_graph_scale%i_%i.err' % (self.scale, job_id))
        bsub_command = 'bsub -J graph_scale_%i -We %i -o %s -e %s \'%s\'' % (
            job_id, self.time_estimate, log_file, err_file, command)
        if self.run_local:
            subprocess.call([command], shell=True)
        else:
            subprocess.call([bsub_command], shell=True)

    def _collect_outputs(self, block_list):
        times = []
        processed_blocks = []
        for block_id in block_list:
            res_file = os.path.join(
                self.tmp_folder,
                'graph_scale%i_block%i.json' % (self.scale, block_id))
            try:
                with open(res_file) as f:
                    res = json.load(f)
                times.append(res['t'])
                processed_blocks.append(block_id)
                os.remove(res_file)
            except Exception:
                continue
        return processed_blocks, times

    def run(self):
        from production import util

        # copy the script to the temp folder and replace the shebang
        file_dir = os.path.dirname(os.path.abspath(__file__))
        util.copy_and_replace(
            os.path.join(file_dir, 'merge_graph_scales.py'),
            os.path.join(self.tmp_folder, 'merge_graph_scales.py'))

        with open(self.config_path) as f:
            config = json.load(f)
            init_block_shape = config['block_shape']
            roi = config.get('roi', None)

        block_shape = [bs * 2**self.scale for bs in init_block_shape]

        # get the shape and blocking
        ws = z5py.File(self.path)[self.ws_key]
        shape = ws.shape
        f_graph = z5py.File(self.out_path, use_zarr_format=False)
        f_graph.attrs['shape'] = shape

        blocking = nifty.tools.blocking([0, 0, 0], shape, block_shape)
        # check if we have a ROI and adapt the block list if we do
        if roi is None:
            n_blocks = blocking.numberOfBlocks
            block_list = list(range(n_blocks))
        else:
            block_list = blocking.getBlockIdsOverlappingBoundingBox(
                roi[0], roi[1], [0, 0, 0]).tolist()
            n_blocks = len(block_list)

        # find the actual number of jobs and prepare job configs
        n_jobs = min(n_blocks, self.max_jobs)
        self._prepare_jobs(n_jobs, block_list, init_block_shape)

        # submit the jobs
        if self.run_local:
            # this only works in python 3 ?!
            with futures.ProcessPoolExecutor(n_jobs) as tp:
                tasks = [
                    tp.submit(self._submit_job, job_id)
                    for job_id in range(n_jobs)
                ]
                [t.result() for t in tasks]
        else:
            for job_id in range(n_jobs):
                self._submit_job(job_id)

        # wait till all jobs are finished
        if not self.run_local:
            util.wait_for_jobs('papec')

        # check the job outputs
        processed_blocks, times = self._collect_outputs(block_list)
        assert len(processed_blocks) == len(times)
        success = len(processed_blocks) == n_blocks

        # write output file if we succeed, otherwise write partial
        # success to different file and raise exception
        if success:
            out = self.output()
            # luigi targets support the context-manager protocol
            with out.open('w') as fres:
                json.dump({'times': times}, fres)
        else:
            log_path = os.path.join(
                self.tmp_folder,
                'merge_graph_scale%i_partial.json' % self.scale)
            with open(log_path, 'w') as out:
                json.dump(
                    {
                        'times': times,
                        'processed_blocks': processed_blocks
                    }, out)
            raise RuntimeError(
                "MergeGraphScalesTask failed, %i / %i blocks processed, " %
                (len(processed_blocks), n_blocks) +
                "serialized partial results to %s" % log_path)

    def output(self):
        return luigi.LocalTarget(
            os.path.join(self.tmp_folder,
                         'merge_graph_scale%i.log' % self.scale))
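
The block_list[job_id::n_jobs] slice in _prepare_jobs distributes blocks round-robin over the jobs; a tiny illustration of the striding:

block_list = list(range(10))
n_jobs = 3
for job_id in range(n_jobs):
    print(job_id, block_list[job_id::n_jobs])
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7]
# 2 [2, 5, 8]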