def main(
    git_branch,
    codebuild_id,
    pipeline_name,
    model_name,
    deploy_role,
    sagemaker_role,
    sagemaker_bucket,
    data_dir,
    output_dir,
    ecr_dir,
    kms_key_id,
    workflow_role_arn,
    notification_arn,
    sagemaker_project_id,
    tags,
):
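    """Create the Step Functions training workflow for the model and write the
    workflow graph, CloudFormation template, execution inputs, and dev/prod
    deployment configs to the output directory."""
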
    # Define the function names
    create_experiment_function_name = "mlops-create-experiment"
    query_training_function_name = "mlops-query-training"

    # Get the region
    region = boto3.Session().region_name
    print("region: {}".format(region))

    if ecr_dir:
        # Load the image uri from the ECR build output
        with open(os.path.join(ecr_dir, "imageDetail.json"), "r") as f:
            image_uri = json.load(f)["ImageURI"]
    else:
        # Get the managed image uri for the current region
        image_uri = get_training_image(region)
    print("image uri: {}".format(image_uri))

    with open(os.path.join(data_dir, "inputData.json"), "r") as f:
        input_data = json.load(f)
        print("training uri: {}".format(input_data["TrainingUri"]))
        print("validation uri: {}".format(input_data["ValidationUri"]))
        print("baseline uri: {}".format(input_data["BaselineUri"]))

    # Get the job id and source revisions
    job_id = get_pipeline_execution_id(pipeline_name, codebuild_id)
    revisions = get_pipeline_revisions(pipeline_name, job_id)
    git_commit_id = revisions["ModelSourceOutput"]
    data_version_id = revisions["DataSourceOutput"]
    print("job id: {}".format(job_id))
    print("git commit: {}".format(git_commit_id))
    print("data version: {}".format(data_verison_id))

    # Set the output data locations
    output_data = {
        "ModelOutputUri": f"s3://{sagemaker_bucket}/{model_name}",
        "BaselineOutputUri": f"s3://{sagemaker_bucket}/{model_name}/monitoring/baseline/{model_name}-pbl-{job_id}",
    }
    print("model output uri: {}".format(output_data["ModelOutputUri"]))

    # Pass these into the training method
    hyperparameters = {}
    if os.path.exists(os.path.join(data_dir, "hyperparameters.json")):
        with open(os.path.join(data_dir, "hyperparameters.json"), "r") as f:
            hyperparameters = json.load(f)
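            # SageMaker expects hyperparameter values as strings, so cast each value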
            for i in hyperparameters:
                hyperparameters[i] = str(hyperparameters[i])

    # Define the step functions execution input schema
    execution_input = ExecutionInput(
        schema={
            "GitBranch": str,
            "GitCommitHash": str,
            "DataVersionId": str,
            "ExperimentName": str,
            "TrialName": str,
            "BaselineJobName": str,
            "BaselineOutputUri": str,
            "TrainingJobName": str,
        })

    # Create the experiment, baseline, and training steps
    experiment_step = create_experiment_step(create_experiment_function_name)
    baseline_step = create_baseline_step(input_data, execution_input, region,
                                         sagemaker_role)
    training_step = create_training_step(
        image_uri,
        hyperparameters,
        input_data,
        output_data,
        execution_input,
        query_training_function_name,
        region,
        sagemaker_role,
    )
    workflow_definition = create_graph(experiment_step, baseline_step,
                                       training_step)

    # Create the workflow, named after the model
    workflow = Workflow(model_name, workflow_definition, workflow_role_arn)
    print("Creating workflow: {0}-{1}".format(model_name,
                                              sagemaker_project_id))

    # Create output directory
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # Write the workflow graph to json
    with open(os.path.join(output_dir, "workflow-graph.json"), "w") as f:
        f.write(workflow.definition.to_json(pretty=True))

    # Write the workflow CloudFormation template to yml
    with open(os.path.join(output_dir, "workflow-graph.yml"), "w") as f:
        f.write(workflow.get_cloudformation_template())

    # Write the workflow inputs to file
    with open(os.path.join(output_dir, "workflow-input.json"), "w") as f:
        workflow_inputs = {
            "ExperimentName": "{}".format(model_name),
            "TrialName": "{}-{}".format(model_name, job_id),
            "GitBranch": git_branch,
            "GitCommitHash": git_commit_id,
            "DataVersionId": data_verison_id,
            "BaselineJobName": "{}-pbl-{}".format(model_name, job_id),
            "BaselineOutputUri": output_data["BaselineOutputUri"],
            "TrainingJobName": "{}-{}".format(model_name, job_id),
        }
        json.dump(workflow_inputs, f)

    # Write the dev & prod params for CFN
    with open(os.path.join(output_dir, "deploy-model-dev.json"), "w") as f:
        config = get_dev_config(model_name, job_id, deploy_role, image_uri,
                                kms_key_id, sagemaker_project_id)
        json.dump(config, f)
    with open(os.path.join(output_dir, "deploy-model-prd.json"), "w") as f:
        config = get_prd_config(
            model_name,
            job_id,
            deploy_role,
            image_uri,
            kms_key_id,
            notification_arn,
            sagemaker_project_id,
        )
        json.dump(config, f)
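

# Note: a minimal, illustrative entrypoint sketch (not part of the original
# module) showing one way main() could be wired to command-line arguments,
# e.g. when this script is invoked from a CodeBuild buildspec. The flag names
# simply mirror the parameters of main(); adjust them to the real build setup.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Build the training workflow")
    for flag in (
        "git-branch", "codebuild-id", "pipeline-name", "model-name",
        "deploy-role", "sagemaker-role", "sagemaker-bucket", "data-dir",
        "output-dir", "ecr-dir", "kms-key-id", "workflow-role-arn",
        "notification-arn", "sagemaker-project-id", "tags",
    ):
        parser.add_argument("--" + flag, required=False)
    args = parser.parse_args()
    # argparse converts the dashes in flag names to underscores, matching main()
    main(**vars(args))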


class StepfunctionsWorkflow(DataJobBase):
    """Class that defines the methods to create and execute an orchestration using the step functions sdk.

    example:

        with StepfunctionsWorkflow("techskills-parser") as tech_skills_parser_orchestration:

            some-glue-job-1 >> [some-glue-job-2,some-glue-job-3] >> some-glue-job-4

        tech_skills_parser_orchestration.execute()

    """
    def __init__(
        self,
        datajob_stack: core.Construct,
        name: str,
        role: iam.Role = None,
        region: str = None,
        **kwargs,
    ):
        super().__init__(datajob_stack, name, **kwargs)
        self.chain_of_tasks = []
        self.workflow = None
        self.role = (self.get_role(unique_name=self.unique_name,
                                   service_principal="states.amazonaws.com")
                     if role is None else role)
        self.region = (region if region is not None else
                       os.environ.get("AWS_DEFAULT_REGION"))

    def add_task(self, task_other):
        """add a task to the workflow we would like to orchestrate."""
        job_name = task_other.unique_name
        logger.debug(f"adding task with name {job_name}")
        task = StepfunctionsWorkflow._create_glue_start_job_run_step(
            job_name=job_name)
        self.chain_of_tasks.append(task)

    def add_parallel_tasks(self, task_others):
        """add tasks in parallel (wrapped in a list) to the workflow we would like to orchestrate."""
        deploy_pipelines = Parallel(state_id=uuid.uuid4().hex)
        for one_other_task in task_others:
            task_unique_name = one_other_task.unique_name
            logger.debug(f"adding parallel task with name {task_unique_name}")
            deploy_pipelines.add_branch(
                StepfunctionsWorkflow._create_glue_start_job_run_step(
                    job_name=task_unique_name))
        self.chain_of_tasks.append(deploy_pipelines)

    @staticmethod
    def _create_glue_start_job_run_step(job_name):
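        """Return a step that starts the given Glue job run and waits for it to complete."""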
        logger.debug("creating a step for a glue job.")
        return GlueStartJobRunStep(job_name,
                                   wait_for_completion=True,
                                   parameters={"JobName": job_name})

    def _build_workflow(self):
        """create a step functions workflow from the chain_of_tasks."""
        logger.debug(
            f"creating a chain from all the different steps. \n {self.chain_of_tasks}"
        )
        workflow_definition = steps.Chain(self.chain_of_tasks)
        logger.debug(f"creating a workflow with name {self.unique_name}")
        self.client = boto3.client("stepfunctions")
        self.workflow = Workflow(
            name=self.unique_name,
            definition=workflow_definition,
            role=self.role.role_arn,
            client=self.client,
        )

    def create(self):
        """create sfn stack"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            sfn_cf_file_path = str(Path(tmp_dir, self.unique_name))
            with open(sfn_cf_file_path, "w") as text_file:
                text_file.write(self.workflow.get_cloudformation_template())
            cfn_inc.CfnInclude(self,
                               self.unique_name,
                               template_file=sfn_cf_file_path)

    def __enter__(self):
        """first steps we have to do when entering the context manager."""
        logger.info(f"creating step functions workflow for {self.unique_name}")
        _set_workflow(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """steps we have to do when exiting the context manager."""
        self._build_workflow()
        _set_workflow(None)
        logger.info(f"step functions workflow {self.unique_name} created")