Example #1
    def __init__(self,
                 code_dir,
                 out_dir,
                 version_id,
                 cause_ids,
                 gbd_round_id=GBD.GBD_ROUND_ID):
        self.code_dir = code_dir
        self.out_dir = out_dir
        self.version_id = version_id
        self.cause_ids = cause_ids
        self.gbd_round_id = gbd_round_id

        self.task_dag = TaskDag(
            name='imported_cases_v{}'.format(self.version_id))
        self.stderr = 'FILEPATH'
        self.stdout = 'FILEPATH'
Example #2
class ImportedCasesJobSwarm(object):
    """This class creates and submits the imported cases task dag."""

    ADDITIONAL_RESTRICTIONS = {562: 'mental_drug_opioids'}
    def __init__(self,
                 code_dir,
                 out_dir,
                 version_id,
                 cause_ids,
                 gbd_round_id=GBD.GBD_ROUND_ID):
        self.code_dir = code_dir
        self.out_dir = out_dir
        self.version_id = version_id
        self.cause_ids = cause_ids
        self.gbd_round_id = gbd_round_id

        self.task_dag = TaskDag(
            name='imported_cases_v{}'.format(self.version_id))
        self.stderr = 'FILEPATH'
        self.stdout = 'FILEPATH'

    def create_imported_cases_jobs(self):
        """Generates the tasks and adds them to the task_dag."""
        slots = 38
        memory = slots * 2
        for cause in self.cause_ids:
            task = PythonTask(
                script=os.path.join(self.code_dir, 'imported_cases.py'),
                args=[
                    self.version_id, '--cause_id', cause, '--gbd_round_id',
                    self.gbd_round_id, '--output_dir', self.out_dir
                ],
                name='imported_cases_{}_{}'.format(self.version_id, cause),
                slots=slots,
                mem_free=memory,
                max_attempts=3,
                tag='imported_cases')
            self.task_dag.add_task(task)

    def run(self):
        wf = Workflow(self.task_dag,
                      'imported_cases_v{}'.format(self.version_id),
                      stderr=self.stderr,
                      stdout=self.stdout,
                      project='proj_codcorrect')
        success = wf.run()
        return success
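
A minimal driver sketch for the swarm above, assuming the class is importable and that imported_cases.py sits inside code_dir; the version and cause ids are placeholder values, and 'FILEPATH' stands in for real directories exactly as in the snippet itself:

# Hypothetical driver for ImportedCasesJobSwarm; all values are illustrative.
swarm = ImportedCasesJobSwarm(
    code_dir='FILEPATH',   # directory containing imported_cases.py
    out_dir='FILEPATH',    # directory the tasks write results to
    version_id=100,        # placeholder cod output version id
    cause_ids=[562])       # e.g. the opioid cause listed in ADDITIONAL_RESTRICTIONS
swarm.create_imported_cases_jobs()  # build one PythonTask per cause and add it to the dag
success = swarm.run()               # submit the Workflow under proj_codcorrect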
Example #3
    def __init__(self,
                 code_dir,
                 version_id,
                 year_ids,
                 start_years,
                 location_set_ids,
                 databases,
                 db_env='dev'):
        """
        Arguments:
            code_dir (str): The directory containing CoDCorrect's code base.
            version_id (int): cod.output_version.output_version_id
            year_ids (int[]): set of years to run.
            start_years (int[]):
            location_set_ids (int[]):
            databases (str[]):
            db_env (str):
        """
        self.code_dir = code_dir
        self.version_id = version_id
        self.year_ids = year_ids
        self.start_years = start_years
        self.location_set_ids = location_set_ids
        self.databases = databases
        self.db_env = db_env

        self.most_detailed_locations = self.get_most_detailed_location_ids()
        self.all_locations = self.get_all_location_ids()

        self.measure_ids = [1, 4]
        self.sex_ids = [1, 2]
        self.pct_change = [True, False]

        self.task_dag = TaskDag(
            name=('CoDCorrect_v{}'.format(self.version_id)))
        self.shock_jobs_by_command = {}
        self.correct_jobs_by_command = {}
        self.agg_cause_jobs_by_command = {}
        self.ylls_jobs_by_command = {}
        self.agg_loc_jobs_by_command = {}
        self.append_shock_jobs_by_command = {}
        self.append_diag_jobs_by_command = {}
        self.summarize_jobs_by_command = {}
        self.upload_jobs_by_command = {}
Example #4
    def __init__(self, como_version, redis_server=None):
        self.como_version = como_version

        # instantiate our factories
        self._task_registry = {}
        self._incidence_task_fac = IncidenceTaskFactory(
            self.como_version, self._task_registry)
        self._simulation_input_task_fac = SimulationInputTaskFactory(
            self.como_version, self._task_registry)
        self._simulation_task_fac = SimulationTaskFactory(
            self.como_version, self._task_registry)
        self._location_agg_task_fac = LocationAggTaskFactory(
            self.como_version, self._task_registry)
        self._summarize_task_fac = SummarizeTaskFactory(
            self.como_version, self._task_registry,
            self._location_agg_task_fac.agg_loc_set_map)
        self._pct_change_task_fac = PctChangeTaskFactory(
            self.como_version, self._task_registry,
            self._location_agg_task_fac.agg_loc_set_map)

        self.dag = TaskDag(
            name="COMO {}".format(self.como_version.como_version_id))
        self.redis_server = redis_server
Example #5
                          project=project,
                          runfile=runfile,
                          args=args,
                          process_timeout=process_timeout,
                          path_to_python_binary=path_to_python_binary,
                          upstream_tasks=upstream_tasks)


# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Start up DAG
d = datetime.datetime.now()
dag_name = "mort_u5_{}_{}".format(version_id, d.strftime("%Y%m%d%H%M"))
dag = TaskDag(name=dag_name)

# Get locations
location_hierarchy = get_location_metadata(location_set_id=21, gbd_round_id=5)
location_hierarchy = location_hierarchy.loc[
    (location_hierarchy['level'] >= 3)
    & (location_hierarchy['location_id'] != 6)]
ihme_loc_dict = make_ihme_loc_id_dict(location_hierarchy)

all_files = glob.glob((draws_dir + "*").format(version_id))

# Create tasks
u5_tasks = {}
for location_id in location_hierarchy['location_id'].tolist():
    output_file = (draws_dir + "{}.csv").format(version_id, location_id)
    if output_file not in all_files:
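
The snippet above is cut off inside the `if`. Based on the task-creation pattern used throughout these examples, the missing body most likely builds one task per location whose draws file is missing and adds it to the dag. A hedged sketch of that pattern follows, restating the loop for context; the script name 'u5_draws.py' and its argument list are assumptions, not taken from the source:

# Assumed continuation of the loop above (illustrative only).
for location_id in location_hierarchy['location_id'].tolist():
    output_file = (draws_dir + "{}.csv").format(version_id, location_id)
    if output_file not in all_files:
        # 'u5_draws.py' and its arguments are hypothetical placeholders.
        u5_tasks[location_id] = PythonTask(
            script='u5_draws.py',
            args=['--version_id', version_id,
                  '--location_id', location_id,
                  '--ihme_loc_id', ihme_loc_dict[location_id]],
            name='u5_{}_{}'.format(version_id, location_id),
            slots=5,
            mem_free=10,
            max_attempts=3,
            tag='u5')
        dag.add_task(u5_tasks[location_id])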
Example #6
    def generate_dag(self, dag_name):
        # Start up DAG
        dag = TaskDag(name=dag_name)

        # Get input
        task_get_input = self.generate_get_input_task()
        dag.add_task(task_get_input)

        # Assess VR bias
        task_assess_vr_bias = self.generate_assess_bias_task([task_get_input])
        dag.add_task(task_assess_vr_bias)

        # Format data for prediction
        task_format_data = self.generate_format_data_task(
            [task_assess_vr_bias])
        dag.add_task(task_format_data)

        # Generate hyperparameters
        task_generate_hyperparameters = self.generate_generate_hyperparameters_task(
            [task_format_data])
        dag.add_task(task_generate_hyperparameters)

        # Get location data for next step
        location_data = get_location_metadata(location_set_id=82,
                                              gbd_round_id=self.gbd_round_id)
        ihme_loc_dict = make_ihme_loc_id_dict(location_data)

        # Run submodels
        task_submodel_a = {}
        task_submodel_b = {}
        task_submodel_c = {}
        task_variance = {}
        task_gpr = {}
        task_compile_gpr = {}
        output_files = []
        for s in self.config.submodels:
            # Submodel A: First stage regression
            task_submodel_a[s.submodel_id] = self.generate_submodel_a_task(
                s.submodel_id, s.input_location_ids, s.output_location_ids,
                [task_generate_hyperparameters])
            dag.add_task(task_submodel_a[s.submodel_id])

            # Submodel B: Bias correction
            task_submodel_b[s.submodel_id] = self.generate_submodel_b_task(
                s.submodel_id, s.input_location_ids, s.output_location_ids,
                [task_submodel_a[s.submodel_id]])
            dag.add_task(task_submodel_b[s.submodel_id])

            # Submodel C: Space-time smoothing
            task_submodel_c[s.submodel_id] = self.generate_submodel_c_task(
                s.submodel_id, s.input_location_ids, s.output_location_ids,
                [task_submodel_b[s.submodel_id]])
            dag.add_task(task_submodel_c[s.submodel_id])

            # Calculate variance
            task_variance[
                s.submodel_id] = self.generate_submodel_variance_task(
                    s.submodel_id, [task_submodel_c[s.submodel_id]])
            dag.add_task(task_variance[s.submodel_id])

            # GPR
            task_gpr[s.submodel_id] = {}
            for location_id in s.output_location_ids:
                ihme_loc_id = ihme_loc_dict[location_id]
                task_gpr[s.submodel_id][
                    location_id] = self.generate_submodel_gpr_task(
                        s.submodel_id, location_id, ihme_loc_id,
                        [task_variance[s.submodel_id]])
                dag.add_task(task_gpr[s.submodel_id][location_id])

            # Submodel compile
            task_compile_gpr[
                s.submodel_id] = self.generate_submodel_gpr_compile_task(
                    s.submodel_id, s.output_location_ids,
                    [v for k, v in task_gpr[s.submodel_id].items()])
            dag.add_task(task_compile_gpr[s.submodel_id])

        # Rake data
        task_rake = {}
        for r in self.config.rakings:
            # Submit raking jobs
            task_rake[r.raking_id] = self.generate_raking_task(
                r.raking_id, [s.submodel_id for s in self.config.submodels],
                r.parent_id, r.child_ids, r.direction,
                [v for k, v in task_compile_gpr.items()])
            dag.add_task(task_rake[r.raking_id])

        # Save draws and summary files in one spot
        task_save_draws = {}
        for s in self.config.submodels:
            for location_id in s.output_location_ids:
                task_save_draws[location_id] = self.generate_save_draws_task(
                    location_id, [v for k, v in task_rake.items()])
                dag.add_task(task_save_draws[location_id])

        # Upload
        task_upload = self.generate_upload_prep_task(
            [v for k, v in task_save_draws.items()])
        dag.add_task(task_upload)

        # Compile submodels together to be able to generate graphs
        task_compile_submodels = self.generate_compile_submodels_task(
            [v for k, v in task_save_draws.items()])
        dag.add_task(task_compile_submodels)

        # Generate comparisons of current location inputs to previous GBD
        task_generate_comparison = self.generate_comparison_task(
            [v for k, v in task_save_draws.items()])
        dag.add_task(task_generate_comparison)

        return dag
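
generate_dag only assembles the dag; a hedged sketch of how it might be driven, assuming an already constructed estimator object that exposes this method and reusing the Workflow pattern from the other examples (the variable names, workflow arguments, and project are illustrative, not from the source):

# Hypothetical driver; 'estimator' is assumed to be an instance of the class
# that defines generate_dag above.
dag = estimator.generate_dag(dag_name='model_run_v100')  # dag name is a placeholder
wf = Workflow(dag,
              'model_run_v100',   # workflow args, placeholder
              project='PROJECT')  # project name is not shown in the source
success = wf.run()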
Example #7
class CoDCorrectJobSwarm(object):

    CONN_DEF_MAP = {
        'cod': {
            'dev': 'cod-test',
            'prod': 'cod'
        },
        'gbd': {
            'dev': 'gbd-test',
            'prod': 'gbd'
        }
    }
    STD_ERR = 'FILEPATH'
    STD_OUT = 'FILEPATH'

    def __init__(self,
                 code_dir,
                 version_id,
                 year_ids,
                 start_years,
                 location_set_ids,
                 databases,
                 db_env='dev'):
        """
        Arguments:
            code_dir (str): The directory containing CoDCorrect's code base.
            version_id (int): cod.output_version.output_version_id
            year_ids (int[]): set of years to run.
            start_years (int[]):
            location_set_ids (int[]):
            databases (str[]):
            db_env (str):
        """
        self.code_dir = code_dir
        self.version_id = version_id
        self.year_ids = year_ids
        self.start_years = start_years
        self.location_set_ids = location_set_ids
        self.databases = databases
        self.db_env = db_env

        self.most_detailed_locations = self.get_most_detailed_location_ids()
        self.all_locations = self.get_all_location_ids()

        self.measure_ids = [1, 4]
        self.sex_ids = [1, 2]
        self.pct_change = [True, False]

        self.task_dag = TaskDag(
            name=('CoDCorrect_v{}'.format(self.version_id)))
        self.shock_jobs_by_command = {}
        self.correct_jobs_by_command = {}
        self.agg_cause_jobs_by_command = {}
        self.ylls_jobs_by_command = {}
        self.agg_loc_jobs_by_command = {}
        self.append_shock_jobs_by_command = {}
        self.append_diag_jobs_by_command = {}
        self.summarize_jobs_by_command = {}
        self.upload_jobs_by_command = {}

    def calculate_slots_and_memory(self, years, base):
        """Determine the number of slots and memory to request.

        Arguments:
            years (int[]): list of years in the run.
            base (int): base number of slots.

        Returns:
            Tuple representing the number of slots to request and the amount of
            memory.
        """
        slots = base * 6 if len(years) > 8 else base
        return slots, slots * 2

    def get_most_detailed_location_ids(self):
        locs = get_location_metadata(gbd_round_id=5, location_set_id=35)
        return locs.location_id.tolist()

    def get_all_location_ids(self):
        locs = []
        for loc_set in self.location_set_ids:
            locs.append(
                get_location_metadata(gbd_round_id=5, location_set_id=loc_set))
        locs = pd.concat(locs)
        all_locs = locs.location_id.unique().tolist()

        return all_locs

    def create_shock_and_correct_jobs(self):
        """First set of tasks, no upstream tasks."""
        slots, mem = self.calculate_slots_and_memory(self.year_ids, 3)
        for loc in self.most_detailed_locations:
            for sex in self.sex_ids:
                shock_task = PythonTask(
                    script=os.path.join(self.code_dir, 'shocks.py'),
                    args=[
                        '--output_version_id', self.version_id,
                        '--location_id', loc, '--sex_id', sex
                    ],
                    name='shocks_{version}_{loc}_{sex}'.format(
                        version=self.version_id, loc=loc, sex=sex),
                    slots=slots,
                    mem_free=mem,
                    max_attempts=3,
                    tag='shock')
                self.task_dag.add_task(shock_task)
                self.shock_jobs_by_command[shock_task.name] = shock_task

                correct_task = PythonTask(
                    script=os.path.join(self.code_dir, 'correct.py'),
                    args=[
                        '--output_version_id', self.version_id,
                        '--location_id', loc, '--sex_id', sex
                    ],
                    name='correct_{version}_{loc}_{sex}'.format(
                        version=self.version_id, loc=loc, sex=sex),
                    slots=slots,
                    mem_free=mem,
                    max_attempts=3,
                    tag='correct')
                self.task_dag.add_task(correct_task)
                self.correct_jobs_by_command[correct_task.name] = correct_task

    def create_agg_cause_jobs(self):
        slots, mem = self.calculate_slots_and_memory(self.year_ids, 7)
        for loc in self.most_detailed_locations:
            task = PythonTask(script=os.path.join(self.code_dir,
                                                  'aggregate_causes.py'),
                              args=[
                                  '--output_version_id', self.version_id,
                                  '--location_id', loc
                              ],
                              name='agg_cause_{version}_{loc}'.format(
                                  version=self.version_id, loc=loc),
                              slots=slots,
                              mem_free=mem,
                              max_attempts=3,
                              tag='agg_cause')
            # add shock/correct upstream dependencies
            for sex in self.sex_ids:
                task.add_upstream(self.shock_jobs_by_command[
                    'shocks_{version}_{loc}_{sex}'.format(
                        version=self.version_id, loc=loc, sex=sex)])
                task.add_upstream(self.correct_jobs_by_command[
                    'correct_{version}_{loc}_{sex}'.format(
                        version=self.version_id, loc=loc, sex=sex)])
            self.task_dag.add_task(task)
            self.agg_cause_jobs_by_command[task.name] = task

    def create_yll_jobs(self):
        slots, mem = self.calculate_slots_and_memory(self.year_ids, 3)
        for loc in self.most_detailed_locations:
            task = PythonTask(script=os.path.join(self.code_dir, 'ylls.py'),
                              args=[
                                  '--output_version_id', self.version_id,
                                  '--location_id', loc
                              ],
                              name='ylls_{version}_{loc}'.format(
                                  version=self.version_id, loc=loc),
                              slots=slots,
                              mem_free=mem,
                              max_attempts=3,
                              tag='ylls')
            # add cause_agg upstream dependencies
            task.add_upstream(self.agg_cause_jobs_by_command[
                'agg_cause_{version}_{loc}'.format(version=self.version_id,
                                                   loc=loc)])
            self.task_dag.add_task(task)
            self.ylls_jobs_by_command[task.name] = task

    def create_agg_location_jobs(self):
        slots, mem = 10, 100
        for loc_set in self.location_set_ids:
            for measure in self.measure_ids:
                for data_type in ['shocks', 'unscaled', 'rescaled']:
                    if data_type == 'unscaled' and measure == 4:
                        continue
                    for year_id in self.year_ids:
                        task = PythonTask(
                            script=os.path.join(self.code_dir,
                                                'aggregate_locations.py'),
                            args=[
                                '--output_version_id', self.version_id,
                                '--df_type', data_type, '--measure_id',
                                measure, '--location_set_id', loc_set,
                                '--year_id', year_id
                            ],
                            name=('agg_locations_{}_{}_{}_{}_{}'.format(
                                self.version_id, data_type, measure, loc_set,
                                year_id)),
                            slots=slots,
                            mem_free=mem,
                            max_attempts=5,
                            tag='agg_location')
                        for loc in self.most_detailed_locations:
                            if measure == 4:
                                task.add_upstream(self.ylls_jobs_by_command[
                                    'ylls_{}_{}'.format(self.version_id, loc)])
                            else:
                                task.add_upstream(
                                    self.agg_cause_jobs_by_command[
                                        'agg_cause_{}_{}'.format(
                                            self.version_id, loc)])
                        # Some of our special locations for final round
                        # estimates treat otherwise aggregated locations as
                        # most-detailed locations. This will throw an
                        # AssertionError in the aggregator if it cannot find
                        # the aggregate location's file. This if block ensures
                        # that the primary estimation location set (35) is run
                        # first before these special location aggregation jobs
                        # are run. This will slow down CoDCorrect overall.
                        if loc_set in SPECIAL_LOCATIONS:
                            task.add_upstream(self.agg_loc_jobs_by_command[
                                'agg_locations_{}_{}_{}_{}_{}'.format(
                                    self.version_id, data_type, measure, 35,
                                    year_id)])
                        self.task_dag.add_task(task)
                        self.agg_loc_jobs_by_command[task.name] = task

    def create_append_shock_jobs(self):
        slots, mem = self.calculate_slots_and_memory(self.year_ids, 7)
        for loc in self.all_locations:
            task = PythonTask(script=os.path.join(self.code_dir,
                                                  'append_shocks.py'),
                              args=[
                                  '--output_version_id', self.version_id,
                                  '--location_id', loc
                              ],
                              name='append_shocks_{version}_{loc}'.format(
                                  version=self.version_id, loc=loc),
                              slots=slots,
                              mem_free=mem,
                              max_attempts=3,
                              tag='append_shock')
            # for job in self.agg_loc_jobs_by_command.values():
            #     task.add_upstream(job)
            self.task_dag.add_task(task)
            self.append_shock_jobs_by_command[task.name] = task

    def create_summary_jobs(self):
        for loc in self.all_locations:
            for db in ['gbd', 'cod']:
                slots, mem = (15, 30) if db == 'cod' else (26, 52)
                task = PythonTask(script=os.path.join(self.code_dir,
                                                      'summary.py'),
                                  args=[
                                      '--output_version_id', self.version_id,
                                      '--location_id', loc, '--db', db
                                  ],
                                  name='summary_{version}_{loc}_{db}'.format(
                                      version=self.version_id, loc=loc, db=db),
                                  slots=slots,
                                  mem_free=mem,
                                  max_attempts=3,
                                  tag='summary')
                task.add_upstream(self.append_shock_jobs_by_command[
                    'append_shocks_{version}_{loc}'.format(
                        version=self.version_id, loc=loc)])
                self.task_dag.add_task(task)
                self.summarize_jobs_by_command[task.name] = task

    def create_append_diagnostic_jobs(self):
        slots, mem = (18, 36)
        task = PythonTask(script=os.path.join(self.code_dir,
                                              'append_diagnostics.py'),
                          args=['--output_version_id', self.version_id],
                          name='append_diagnostics_{version}'.format(
                              version=self.version_id),
                          slots=slots,
                          mem_free=mem,
                          max_attempts=3,
                          tag='append_diag')
        for job in self.append_shock_jobs_by_command.values():
            task.add_upstream(job)
        self.task_dag.add_task(task)
        self.append_diag_jobs_by_command[task.name] = task

    def create_upload_jobs(self):
        slots, mem = (10, 20)
        for measure in self.measure_ids:
            for db in self.databases:
                # cod and codcorrect databases only upload measure 1: deaths.
                if measure == 4 and db in ['cod', 'codcorrect']:
                    continue
                # The cod and gbd databases have separate test and production
                # servers to choose from; the codcorrect database doesn't have
                # a test server.
                if db in ['cod', 'gbd']:
                    conn_def = self.CONN_DEF_MAP[db][self.db_env]
                else:
                    conn_def = 'codcorrect'
                for change in self.pct_change:
                    # The codcorrect and cod databases do not upload percent-change results.
                    if change and db in ['codcorrect', 'cod']:
                        continue
                    task = PythonTask(
                        script=os.path.join(self.code_dir, 'upload.py'),
                        args=[
                            '--output_version_id', self.version_id, '--db', db,
                            '--measure_id', measure, '--conn_def', conn_def,
                            '{}'.format('--change' if change else '')
                        ],
                        name='upload_{version}_{db}_{meas}_{change}'.format(
                            version=self.version_id,
                            db=db,
                            meas=measure,
                            change=change),
                        slots=slots,
                        mem_free=mem,
                        max_attempts=3,
                        tag='upload')
                    if db in ['cod', 'gbd']:
                        for loc in self.all_locations:
                            task.add_upstream(self.summarize_jobs_by_command[
                                'summary_{version}_{loc}_{db}'.format(
                                    version=self.version_id, loc=loc, db=db)])
                    else:
                        for job in self.append_diag_jobs_by_command.values():
                            task.add_upstream(job)
                    self.task_dag.add_task(task)
                    # register so create_post_scriptum_upload can add these as upstreams
                    self.upload_jobs_by_command[task.name] = task

    def create_post_scriptum_upload(self):
        slots, mem = (1, 2)
        for db in self.databases:
            if db in ['cod', 'gbd']:
                task = PythonTask(
                    script=os.path.join(self.code_dir,
                                        'post_scriptum_upload.py'),
                    args=[
                        '--output_version_id', self.version_id, '--db', db,
                        '{}'.format('--test' if self.db_env == 'dev' else '')
                    ],
                    name=('post_scriptum_upload_{version}_{db}'.format(
                        version=self.version_id, db=db)),
                    slots=slots,
                    mem_free=mem,
                    max_attempts=1,
                    tag='post_scriptum_upload')
                upload_jobs = list(self.upload_jobs_by_command.values())
                for job in upload_jobs:
                    task.add_upstream(job)
                self.task_dag.add_task(task)

    def run(self):
        wf = Workflow(self.task_dag,
                      'codcorrect_v{}'.format(self.version_id),
                      stderr=self.STD_ERR,
                      stdout=self.STD_OUT,
                      project='proj_codcorrect')
        success = wf.run()
        return success

    def visualize(self):
        TaskDagViz(self.task_dag, graph_outdir='FILEPATH',
                   output_format='svg').render()
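
A hedged end-to-end sketch of driving the swarm above; all constructor values are placeholders, and the call order follows the upstream dependencies the create_* methods wire up:

# Hypothetical driver for CoDCorrectJobSwarm; argument values are illustrative.
swarm = CoDCorrectJobSwarm(
    code_dir='FILEPATH',
    version_id=100,
    year_ids=list(range(1990, 2018)),
    start_years=[1990, 2007],          # placeholder percent-change start years
    location_set_ids=[35],             # 35 is the primary estimation set referenced above
    databases=['cod', 'gbd', 'codcorrect'],
    db_env='dev')

# Build the dag in dependency order: shocks/correct -> cause agg -> YLLs ->
# location agg -> append shocks -> summaries/diagnostics -> uploads.
swarm.create_shock_and_correct_jobs()
swarm.create_agg_cause_jobs()
swarm.create_yll_jobs()
swarm.create_agg_location_jobs()
swarm.create_append_shock_jobs()
swarm.create_summary_jobs()
swarm.create_append_diagnostic_jobs()
swarm.create_upload_jobs()
swarm.create_post_scriptum_upload()

success = swarm.run()
if success:
    swarm.visualize()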
Example #8
class ComoWorkFlow(object):

    def __init__(self, como_version, redis_server=None):
        self.como_version = como_version

        # instantiate our factories
        self._task_registry = {}
        self._incidence_task_fac = IncidenceTaskFactory(
            self.como_version, self._task_registry)
        self._simulation_input_task_fac = SimulationInputTaskFactory(
            self.como_version, self._task_registry)
        self._simulation_task_fac = SimulationTaskFactory(
            self.como_version, self._task_registry)
        self._location_agg_task_fac = LocationAggTaskFactory(
            self.como_version, self._task_registry)
        self._summarize_task_fac = SummarizeTaskFactory(
            self.como_version, self._task_registry,
            self._location_agg_task_fac.agg_loc_set_map)
        self._pct_change_task_fac = PctChangeTaskFactory(
            self.como_version, self._task_registry,
            self._location_agg_task_fac.agg_loc_set_map)

        self.dag = TaskDag(
            name="COMO {}".format(self.como_version.como_version_id))
        self.redis_server = redis_server

    def _add_incidence_tasks(self):
        parallelism = ["location_id", "sex_id"]
        d = self.como_version.nonfatal_dimensions.get_simulation_dimensions(
            self.como_version.measure_id)
        for slices in d.index_slices(parallelism):
            incidence_task = self._incidence_task_fac.get_task(
                location_id=slices[0],
                sex_id=slices[1],
                n_processes=20)
            self.dag.add_task(incidence_task)

    def _add_simulation_input_tasks(self):
        parallelism = ["location_id", "sex_id"]
        d = self.como_version.nonfatal_dimensions.get_simulation_dimensions(
            self.como_version.measure_id)
        for slices in d.index_slices(parallelism):
            simulation_input_task = self._simulation_input_task_fac.get_task(
                location_id=slices[0],
                sex_id=slices[1],
                n_processes=23)
            self.dag.add_task(simulation_input_task)

    def _add_simulation_tasks(self, n_simulants):
        parallelism = ["location_id", "sex_id", "year_id"]
        d = self.como_version.nonfatal_dimensions.get_simulation_dimensions(
            self.como_version.measure_id)
        for slices in d.index_slices(parallelism):
            sim_task = self._simulation_task_fac.get_task(
                location_id=slices[0],
                sex_id=slices[1],
                year_id=slices[2],
                n_simulants=n_simulants,
                n_processes=23)
            self.dag.add_task(sim_task)

    def _add_loc_aggregation_tasks(self, agg_loc_set_versions):
        if self.redis_server is None:
            self.redis_server = RedisServer()
            self.redis_server.launch_redis_server()

        parallelism = ["year_id", "sex_id", "measure_id"]
        d = self.como_version.nonfatal_dimensions.get_simulation_dimensions(
            self.como_version.measure_id)
        for slices in d.index_slices(parallelism):
            for component in self.como_version.components:
                if not (component == "impairment" and slices[2] == 6):
                    for location_set_version_id in agg_loc_set_versions:
                        agg_task = self._location_agg_task_fac.get_task(
                            component=component,
                            year_id=slices[0],
                            sex_id=slices[1],
                            measure_id=slices[2],
                            location_set_version_id=location_set_version_id,
                            redis_host=self.redis_server.hostname)
                        self.dag.add_task(agg_task)

    def _add_summarization_tasks(self, agg_loc_set_versions):
        all_locs = []
        for location_set_version_id in agg_loc_set_versions:
            loc_tree = loctree(location_set_version_id=location_set_version_id)
            all_locs.extend(loc_tree.node_ids)
        all_locs = list(set(all_locs))

        parallelism = ["measure_id", "year_id"]
        d = self.como_version.nonfatal_dimensions.get_simulation_dimensions(
            self.como_version.measure_id)
        for slices in d.index_slices(parallelism):
            for location_id in all_locs:
                summ_task = self._summarize_task_fac.get_task(
                    measure_id=slices[0],
                    year_id=slices[1],
                    location_id=location_id)
                self.dag.add_task(summ_task)

    def _add_pct_change_tasks(self, agg_loc_set_versions):
        all_locs = []
        for location_set_version_id in agg_loc_set_versions:
            loc_tree = loctree(location_set_version_id=location_set_version_id)
            all_locs.extend(loc_tree.node_ids)
        all_locs = list(set(all_locs))

        for measure_id in self.como_version.measure_id:
            for location_id in all_locs:
                pct_change_task = self._pct_change_task_fac.get_task(
                    measure_id=measure_id,
                    location_id=location_id)
                self.dag.add_task(pct_change_task)

    def add_tasks_to_dag(self, n_simulants=20000, agg_loc_sets=[35, 83]):
        location_set_version_list = []
        for location_set_id in agg_loc_sets:
            location_set_version_list.append(
                active_location_set_version(
                    set_id=location_set_id,
                    gbd_round_id=self.como_version.gbd_round_id))

        if 6 in self.como_version.measure_id:
            self._add_incidence_tasks()
        self._add_simulation_input_tasks()
        self._add_simulation_tasks(n_simulants)
        self._add_loc_aggregation_tasks(location_set_version_list)
        self._add_summarization_tasks(location_set_version_list)
        if self.como_version.change_years:
            self._add_pct_change_tasks(location_set_version_list)

    def run_workflow(self, project="proj_como"):
        wf = Workflow(
            self.dag,
            workflow_args=self.como_version.como_dir,
            project=project)
        success = wf.run()
        self.redis_server.kill_redis_server()
        return success
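
A hedged sketch of how ComoWorkFlow might be driven, assuming an existing ComoVersion object with the attributes the class reads (measure_id, change_years, gbd_round_id, como_dir, components, nonfatal_dimensions); constructing that object is not shown in the source:

# Hypothetical driver; 'como_version' is assumed to be a ComoVersion instance.
workflow = ComoWorkFlow(como_version)
workflow.add_tasks_to_dag(n_simulants=20000, agg_loc_sets=[35, 83])  # defaults from the source
success = workflow.run_workflow(project="proj_como")
if not success:
    raise RuntimeError("COMO workflow failed")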