Exemplo n.º 1
0
 def initialize():
     """
     Reset the context manager so it holds only a freshly built local default context.

     All previously stored contexts are dropped: ``FlyteContextManager._OBJS`` is replaced by a
     one-element list containing the new context. That context uses the default local file access
     provider and an execution state whose user-space params describe a "local" execution.
     """
     # Placeholder id so tasks relying on Flyte-provided params still work when run locally.
     local_exec_id = _identifier.WorkflowExecutionIdentifier(
         project="local", domain="local", name="local")
     # The SDK identifier wrapper is used only to render the id in the familiar
     # ex:project:domain:name string form.
     rendered_exec_id = str(
         _SdkWorkflowExecutionIdentifier.promote_from_model(local_exec_id))
     local_params = ExecutionParameters(
         execution_id=rendered_exec_id,
         execution_date=_datetime.datetime.utcnow(),
         stats=_mock_stats.MockStats(),
         logging=_logging,
         tmp_dir=os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space"),
     )

     base_ctx = FlyteContext(
         file_access=_data_proxy.default_local_file_access_provider)
     local_state = base_ctx.new_execution_state().with_params(
         user_space_params=local_params)
     fresh_ctx = base_ctx.with_execution_state(local_state).build()
     # Attach the manager's origin stack frame to the new context.
     fresh_ctx.set_stackframe(s=FlyteContextManager.get_origin_stackframe())
     FlyteContextManager._OBJS = [fresh_ctx]
Exemplo n.º 2
0
    def pre_execute(self,
                    user_params: ExecutionParameters) -> ExecutionParameters:
        """
        Create (or reuse) a SparkSession and attach it to the execution parameters.

        The session's app name is ``"FlyteSpark: <execution_id>"``. When the current context is not
        in TASK_EXECUTION mode (e.g. a local or notebook run), the task's ``spark_conf`` is applied
        to the session builder. In that case PYTHONPATH is also forwarded to executors if it is set.

        :param user_params: execution parameters for the current run.
        :return: ``user_params`` rebuilt with a ``SPARK_SESSION`` attribute holding the session.
            The session is also stored on ``self.sess``.
        """
        import pyspark as _pyspark

        ctx = FlyteContextManager.current_context()
        sess_builder = _pyspark.sql.SparkSession.builder.appName(
            f"FlyteSpark: {user_params.execution_id}")
        # Compare the state's ``mode`` value. ``execution_state.Mode`` is the enum class itself,
        # which never equals one of its members.
        in_task_execution = (ctx.execution_state and ctx.execution_state.mode
                             == ExecutionState.Mode.TASK_EXECUTION)
        if not in_task_execution:
            # Local/notebook execution: apply the task's spark-conf ourselves.
            spark_conf = _pyspark.SparkConf()
            for k, v in self.task_config.spark_conf.items():
                spark_conf.set(k, v)
            # In local execution, propagate PYTHONPATH to executors too. This makes the spark
            # execution hermetic to the execution environment. For example, it allows running
            # Spark applications using Bazel, without major changes.
            if "PYTHONPATH" in os.environ:
                spark_conf.setExecutorEnv("PYTHONPATH",
                                          os.environ["PYTHONPATH"])
            sess_builder = sess_builder.config(conf=spark_conf)

        self.sess = sess_builder.getOrCreate()
        return user_params.builder().add_attr("SPARK_SESSION",
                                              self.sess).build()
Exemplo n.º 3
0
def test_spark_task():
    """Check a Spark task's config handling, serialization and local session creation."""

    @task(task_config=Spark(spark_conf={"spark": "1"}))
    def my_spark(a: str) -> int:
        # The session handed to the task must carry the execution-id based app name.
        spark_session = flytekit.current_context().spark_session
        assert spark_session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
        return 10

    # The task config is retained verbatim on the task object.
    assert my_spark.task_config is not None
    assert my_spark.task_config.spark_conf == {"spark": "1"}

    # get_custom exposes the spark conf under the "sparkConf" key.
    image = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=image, images=[image])
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=image_config,
    )
    custom = my_spark.get_custom(serialization_settings)
    assert custom["sparkConf"] == {"spark": "1"}

    # pre_execute attaches a SPARK_SESSION built from the task's spark_conf.
    params_builder = ExecutionParameters.new_builder()
    params_builder.working_dir = "/tmp"
    params_builder.execution_id = "ex:local:local:local"
    spark_params = my_spark.pre_execute(params_builder.build())
    assert spark_params is not None
    assert spark_params.has_attr("SPARK_SESSION")

    assert my_spark.sess is not None
    session_conf = my_spark.sess.sparkContext.getConf().getAll()
    assert ("spark", "1") in session_conf
    assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in session_conf
Exemplo n.º 4
0
    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
        """
        Create (or reuse) a SparkSession and attach it to the execution parameters.

        The session's app name is ``"FlyteSpark: <execution_id>"``. When the current context is not
        in TASK_EXECUTION mode (local/notebook execution), the task's ``spark_conf`` is applied to
        the session builder. In that case ``spark.master`` defaults to ``"local"`` unless the task
        config already sets it.

        :param user_params: execution parameters for the current run.
        :return: ``user_params`` rebuilt with a ``SPARK_SESSION`` attribute holding the session.
        """
        import pyspark as _pyspark

        ctx = FlyteContext.current_context()
        sess_builder = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {user_params.execution_id}")
        # Compare the state's ``mode`` value; ``execution_state.Mode`` is the enum class itself and
        # never equals one of its members.
        if not (ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION):
            # Local/notebook execution: apply the task's spark-conf ourselves.
            spark_conf = _pyspark.SparkConf()
            for k, v in self.task_config.spark_conf.items():
                spark_conf.set(k, v)
            # Default to an in-process local master, but let an explicit task setting win
            # deterministically.
            spark_conf.setIfMissing("spark.master", "local")
            # The conf must be handed to the builder. A SparkConf that is only constructed is not
            # picked up by the session.
            sess_builder = sess_builder.config(conf=spark_conf)

        sess = sess_builder.getOrCreate()
        return user_params.builder().add_attr("SPARK_SESSION", sess).build()
Exemplo n.º 5
0
    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
        """
        Add a distributed-training context to the execution params for multi-instance jobs only.

        If ``self._is_distributed()`` is false, this is a single-node execution and ``user_params``
        is returned unchanged. Otherwise a DistributedTrainingContext is created and stored under the
        ``DISTRIBUTED_TRAINING_CONTEXT`` attribute:
        - in TASK_EXECUTION mode it is read from the environment (``from_env``);
        - in any other mode it comes from ``local_execute()``.
        """
        if not self._is_distributed():
            return user_params

        logging.info("Distributed context detected!")
        state = FlyteContext.current_context().execution_state
        # TASK_EXECUTION mode means we are really running in a remote environment (inside
        # SageMaker here), so the context is read from the environment.
        dist_ctx = (
            DistributedTrainingContext.from_env()
            if state and state.mode == ExecutionState.Mode.TASK_EXECUTION
            else DistributedTrainingContext.local_execute()
        )
        return user_params.builder().add_attr("DISTRIBUTED_TRAINING_CONTEXT", dist_ctx).build()
Exemplo n.º 6
0
def test_distributed_custom_training():
    """Check a custom SageMaker trainer configured for distributed (multi-instance) training."""
    setup_envars_for_testing()

    # instance_count > 1 is what marks the job as distributed.
    resource_config = TrainingJobResourceConfig(
        instance_type="ml-xlarge",
        volume_size_in_gb=1,
        instance_count=2,
        distributed_protocol=DistributedProtocol.MPI,
    )
    trainer_config = SagemakerTrainingJobConfig(
        training_job_resource_config=resource_config,
        algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.CUSTOM),
    )

    @task(task_config=trainer_config)
    def my_custom_trainer(x: int) -> int:
        ctx = flytekit.current_context()
        assert ctx.distributed_training_context is not None
        return x

    interface = my_custom_trainer.python_interface
    assert interface.inputs == {"x": int}
    assert interface.outputs == {"o0": int}

    # Calling the task locally exercises the injected distributed context.
    assert my_custom_trainer(x=10) == 10
    assert my_custom_trainer._is_distributed() is True

    params_builder = ExecutionParameters.new_builder()
    params_builder.working_dir = "/tmp"
    distributed_params = my_custom_trainer.pre_execute(params_builder.build())
    assert distributed_params is not None
    assert distributed_params.has_attr("distributed_training_context")

    expected_custom = {
        "algorithmSpecification": {},
        "trainingJobResourceConfig": {
            "distributedProtocol": "MPI",
            "instanceCount": "2",
            "instanceType": "ml-xlarge",
            "volumeSizeInGb": "1",
        },
    }
    assert my_custom_trainer.get_custom(_get_reg_settings()) == expected_custom
Exemplo n.º 7
0
        elif self._parent is not None:
            return self._parent.flyte_client
        else:
            raise Exception("No flyte_client initialized")


class FlyteEntities:
    """
    Holder for a class-level ``entities`` list.

    HACK: a shared, mutable class attribute is a stopgap; we'll think of something better in the
    future.
    """

    # Defined once on the class, so every reference to FlyteEntities.entities sees the same list.
    entities = []


# ---------------------------------------------------------------------------
# Import-time default context.
# The statements below run once, when this module is imported. They build
# local-execution parameters and append a FlyteContext carrying them to
# FlyteContext.OBJS.
# ---------------------------------------------------------------------------

# Placeholder execution id ("local" project/domain/name). It is supplied so that tasks relying on
# Flyte-provided parameter functionality do not fail when run locally.
default_execution_id = _identifier.WorkflowExecutionIdentifier(project="local",
                                                               domain="local",
                                                               name="local")
# Note: the SdkWorkflowExecutionIdentifier object is used purely to format the id into the
# ex:project:domain:name string form users are already acquainted with.
default_user_space_params = ExecutionParameters(
    execution_id=str(
        _SdkWorkflowExecutionIdentifier.promote_from_model(
            default_execution_id)),
    # NOTE: evaluated at import time, so this is the (naive, UTC) module-load timestamp,
    # not the time at which any particular task runs.
    execution_date=_datetime.datetime.utcnow(),
    # Mock stats client; presumably so local runs emit no real metrics — confirm against MockStats.
    stats=_mock_stats.MockStats(),
    logging=_logging,
    # Scratch directory: "user_space" under the configured local sandbox.
    tmp_dir=os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space"),
)
# Default context: the params above plus the default local file access provider.
default_context = FlyteContext(
    user_space_params=default_user_space_params,
    file_access=_data_proxy.default_local_file_access_provider)
# Register it on the class-level OBJS list (presumably the context stack — verify in FlyteContext).
FlyteContext.OBJS.append(default_context)
Exemplo n.º 8
0
def _build_task_execution_parameters(user_workspace_dir: str) -> "ExecutionParameters":
    """
    Build the ExecutionParameters handed to user code for a platform task execution.

    The execution id and the stats tags come from the EXECUTION_* internal config values.
    The stats metric prefix is ``"<project>.<domain>.<name>.user_stats"``; each part prefers the
    TASK_* config value and falls back to the generic PROJECT/DOMAIN/NAME value.

    :param user_workspace_dir: directory exposed to user code as ``tmp_dir``.
    :return: the populated ExecutionParameters.
    """
    from flytekit import __version__ as _api_version

    return ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=_internal_config.EXECUTION_PROJECT.get(),
            domain=_internal_config.EXECUTION_DOMAIN.get(),
            name=_internal_config.EXECUTION_NAME.get(),
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            "{}.{}.{}.user_stats".format(
                _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
            ),
            tags={
                "exec_project": _internal_config.EXECUTION_PROJECT.get(),
                "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
                "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
                "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
                "api_version": _api_version,
            },
        ),
        logging=_logging,
        tmp_dir=user_workspace_dir,
    )


def _file_access_for_cloud_provider(cloud_provider, raw_output_data_prefix: str) -> "_data_proxy.FileAccessProvider":
    """
    Build the FileAccessProvider matching ``cloud_provider``.

    - AWS and GCP get a remote proxy (S3 or GCS) constructed with ``raw_output_data_prefix``.
    - LOCAL gets no explicit remote proxy; a fake remote on the local disk is created automatically.

    All variants use the configured local sandbox directory.

    :param cloud_provider: value of the platform CLOUD_PROVIDER config.
    :param raw_output_data_prefix: prefix handed to the remote proxy for AWS/GCP.
    :raises Exception: if ``cloud_provider`` is not AWS, GCP or LOCAL.
    """
    if cloud_provider == _constants.CloudProvider.AWS:
        return _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix),
        )
    if cloud_provider == _constants.CloudProvider.GCP:
        return _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix),
        )
    if cloud_provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        return _data_proxy.FileAccessProvider(local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get())
    raise Exception(f"Bad cloud provider {cloud_provider}")


def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
    """
    Entrypoint for all PythonTask extensions.

    Steps:
    1. Set the root logger level.
    2. Create a fresh user workspace directory.
    3. Build the execution parameters and a cloud-specific file access provider.
    4. Dispatch ``task_def`` inside nested file-access, serialization-settings and
       TASK_EXECUTION contexts.

    :param task_def: the task to execute.
    :param inputs: passed through unchanged to ``_dispatch_execute``.
    :param output_prefix: passed through unchanged to ``_dispatch_execute``.
    :param raw_output_data_prefix: prefix given to the AWS/GCP remote proxy.
    :raises Exception: if the configured cloud provider is not AWS, GCP or LOCAL.
    """
    _click.echo("Running native-typed task")
    cloud_provider = _platform_config.CLOUD_PROVIDER.get()
    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
    _logging.getLogger().setLevel(log_level)

    ctx = FlyteContext.current_context()

    # Create directories
    user_workspace_dir = ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)

    execution_parameters = _build_task_execution_parameters(user_workspace_dir)
    file_access = _file_access_for_cloud_provider(cloud_provider, raw_output_data_prefix)

    with ctx.new_file_access_context(file_access_provider=file_access) as ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        env = {
            _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(),
        }

        serialization_settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env,
        )

        # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin,
        # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task
        # identifier for t1.
        with ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters
            ) as ctx:
                _dispatch_execute(ctx, task_def, inputs, output_prefix)