Ejemplo n.º 1
0
def test_get_node_execution_outputs(mock_client_factory, execution_data_locations):
    """
    With the client mocked to return a NodeExecutionGetDataResponse built from the first two entries of
    ``execution_data_locations``, ``get_outputs`` should yield a single literal ``b`` equal to 2 and query
    the client exactly once with the node execution's id.
    """

    def _node_execution_id():
        # A fresh (but equal) identifier per call keeps the final assertion an equality check,
        # not an identity check.
        workflow_execution_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
        return identifier.NodeExecutionIdentifier("node-a", workflow_execution_id)

    data_response = _execution_models.NodeExecutionGetDataResponse(
        execution_data_locations[0],
        execution_data_locations[1],
    )
    client = MagicMock()
    client.get_node_execution_data = MagicMock(return_value=data_response)
    mock_client_factory.return_value = client

    # The id is exposed through a PropertyMock on the mock's own type so attribute access returns it.
    sdk_node_execution = MagicMock()
    type(sdk_node_execution).id = PropertyMock(return_value=_node_execution_id())

    outputs = engine.FlyteNodeExecution(sdk_node_execution).get_outputs()

    assert len(outputs.literals) == 1
    assert outputs.literals['b'].scalar.primitive.integer == 2
    client.get_node_execution_data.assert_called_once_with(_node_execution_id())
Ejemplo n.º 2
0
def test_get_task_execution_inputs(mock_client_factory,
                                   execution_data_locations):
    """
    With the client mocked to return a TaskExecutionGetDataResponse built from the first two entries of
    ``execution_data_locations``, ``get_inputs`` should yield a single literal ``a`` equal to 1 and query
    the client exactly once with the task execution's id.
    """

    def _task_execution_id():
        # Rebuilt on every call so the expected and actual identifiers are distinct, equal objects.
        task_id = identifier.Identifier(identifier.ResourceType.TASK, 'project',
                                        'domain', 'task-name', 'version')
        node_execution_id = identifier.NodeExecutionIdentifier(
            "node-a",
            identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
        )
        return identifier.TaskExecutionIdentifier(task_id, node_execution_id, 0)

    client = MagicMock()
    client.get_task_execution_data = MagicMock(
        return_value=_execution_models.TaskExecutionGetDataResponse(
            execution_data_locations[0],
            execution_data_locations[1],
        )
    )
    mock_client_factory.return_value = client

    sdk_task_execution = MagicMock()
    type(sdk_task_execution).id = PropertyMock(return_value=_task_execution_id())

    inputs = engine.FlyteTaskExecution(sdk_task_execution).get_inputs()

    assert len(inputs.literals) == 1
    assert inputs.literals['a'].scalar.primitive.integer == 1
    client.get_task_execution_data.assert_called_once_with(_task_execution_id())
Ejemplo n.º 3
0
def test_get_full_node_execution_outputs(mock_client_factory):
    """
    When the mocked client's response carries ``_INPUT_MAP`` / ``_OUTPUT_MAP`` directly (with ``None`` for
    the first two arguments), ``get_outputs`` should yield a single literal ``b`` equal to 2 and query the
    client exactly once with the node execution's id.
    """

    def _node_execution_id():
        # Built anew for each use so the call assertion compares by equality.
        return identifier.NodeExecutionIdentifier(
            "node-a",
            identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
        )

    full_response = _execution_models.NodeExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP)
    client = MagicMock()
    client.get_node_execution_data = MagicMock(return_value=full_response)
    mock_client_factory.return_value = client

    sdk_node_execution = MagicMock()
    type(sdk_node_execution).id = PropertyMock(return_value=_node_execution_id())

    outputs = engine.FlyteNodeExecution(sdk_node_execution).get_outputs()

    assert len(outputs.literals) == 1
    assert outputs.literals["b"].scalar.primitive.integer == 2
    client.get_node_execution_data.assert_called_once_with(_node_execution_id())
Ejemplo n.º 4
0
def test_get_full_task_execution_inputs(mock_client_factory):
    """
    When the mocked client's response carries ``_INPUT_MAP`` / ``_OUTPUT_MAP`` directly (with ``None`` for
    the first two arguments), ``get_inputs`` should yield a single literal ``a`` equal to 1 and query the
    client exactly once with the task execution's id.
    """

    def _task_execution_id():
        # Built anew for each use so the call assertion compares by equality.
        task_id = identifier.Identifier(
            identifier.ResourceType.TASK, "project", "domain", "task-name", "version"
        )
        node_execution_id = identifier.NodeExecutionIdentifier(
            "node-a",
            identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
        )
        return identifier.TaskExecutionIdentifier(task_id, node_execution_id, 0)

    full_response = _execution_models.TaskExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP)
    client = MagicMock()
    client.get_task_execution_data = MagicMock(return_value=full_response)
    mock_client_factory.return_value = client

    sdk_task_execution = MagicMock()
    type(sdk_task_execution).id = PropertyMock(return_value=_task_execution_id())

    inputs = engine.FlyteTaskExecution(sdk_task_execution).get_inputs()

    assert len(inputs.literals) == 1
    assert inputs.literals["a"].scalar.primitive.integer == 1
    client.get_task_execution_data.assert_called_once_with(_task_execution_id())
Ejemplo n.º 5
0
def test_execution_notification_soft_overrides(mock_client_factory):
    """
    Executing a launch plan with a non-empty ``notification_overrides`` list should call
    ``create_execution`` once with a spec whose ``notifications`` is a NotificationList of that list.
    """

    def _launch_plan_id():
        # Fresh, equal identifier for both the mocked property and the expected spec.
        return identifier.Identifier(
            identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"
        )

    client = MagicMock()
    client.create_execution = MagicMock(
        return_value=identifier.WorkflowExecutionIdentifier('xp', 'xd', 'xn')
    )
    mock_client_factory.return_value = client

    sdk_launch_plan = MagicMock()
    type(sdk_launch_plan).id = PropertyMock(return_value=_launch_plan_id())

    notification = _common_models.Notification(
        [0, 1, 2],
        email=_common_models.EmailNotification(["*****@*****.**"]),
    )

    engine.FlyteLaunchPlan(sdk_launch_plan).execute(
        'xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[notification]
    )

    expected_spec = _execution_models.ExecutionSpec(
        _launch_plan_id(),
        literals.LiteralMap({}),
        _execution_models.ExecutionMetadata(
            _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, 'sdk', 0
        ),
        notifications=_execution_models.NotificationList([notification]),
    )
    client.create_execution.assert_called_once_with('xp', 'xd', 'xn', expected_spec)
Ejemplo n.º 6
0
 def initialize():
     """
     Rebuilds the default context and makes it the only entry in ``FlyteContextManager._OBJS``,
     discarding whatever contexts were stored there before.
     """
     # A "local" execution id is supplied so that tasks relying on Flyte-provided params do not fail
     # when run locally.
     local_execution_id = _identifier.WorkflowExecutionIdentifier(
         project="local", domain="local", name="local")
     # The Sdk identifier is used purely for its string form (the ex:project:domain:name format users
     # are already acquainted with).
     formatted_execution_id = str(
         _SdkWorkflowExecutionIdentifier.promote_from_model(local_execution_id))
     user_space_params = ExecutionParameters(
         execution_id=formatted_execution_id,
         execution_date=_datetime.datetime.utcnow(),
         stats=_mock_stats.MockStats(),
         logging=_logging,
         tmp_dir=os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space"),
     )

     base_context = FlyteContext(
         file_access=_data_proxy.default_local_file_access_provider)
     execution_state = base_context.new_execution_state().with_params(
         user_space_params=user_space_params)
     rebuilt_context = base_context.with_execution_state(execution_state).build()
     rebuilt_context.set_stackframe(
         s=FlyteContextManager.get_origin_stackframe())

     # Replace the whole stack with just the rebuilt default context.
     FlyteContextManager._OBJS = [rebuilt_context]
Ejemplo n.º 7
0
    def launch(self, project, domain, name=None, inputs=None, notification_overrides=None, label_overrides=None,
               annotation_overrides=None, auth_role=None):
        """
        Creates a single-task execution through the Flyte client and returns the result of fetching that
        execution back from the client.

        If the client reports that the execution already exists, the execution identified by
        (project, domain, name) is fetched instead.

        :param Text project:
        :param Text domain:
        :param Text name:
        :param flytekit.models.literals.LiteralMap inputs: The inputs to pass
        :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the
            notifications. An explicitly empty list sets ``disable_all`` on the execution spec instead.
        :param flytekit.models.common.Labels label_overrides:
        :param flytekit.models.common.Annotations annotation_overrides:
        :param flytekit.models.common.AuthRole auth_role: When falsy, a role is assembled from config.
        :rtype: flytekit.models.execution.Execution
        """
        # Only an explicitly empty list disables all notifications; any other value (including None) is
        # wrapped in a NotificationList and disable_all is reset to None.
        disable_all = (notification_overrides == [])
        if not disable_all:
            notification_overrides = _execution_models.NotificationList(notification_overrides or [])
            disable_all = None
        else:
            notification_overrides = None

        if not auth_role:
            iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get()
            service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get()

            # Neither auth setting is configured: fall back to the deprecated `role` setting.
            if not iam_role and not service_account:
                _logging.warning("Using deprecated `role` from config. "
                                 "Please update your config to use `assumable_iam_role` instead")
                iam_role = _sdk_config.ROLE.get()
            auth_role = _common_models.AuthRole(
                assumable_iam_role=iam_role,
                kubernetes_service_account=service_account,
            )

        try:
            # TODO(katrogan): Add handling to register the underlying task if it's not already.
            client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client
            spec = _execution_models.ExecutionSpec(
                self.sdk_task.id,
                _execution_models.ExecutionMetadata(
                    _execution_models.ExecutionMetadata.ExecutionMode.MANUAL,
                    'sdk',  # TODO: get principle
                    0  # TODO: Detect nesting
                ),
                notifications=notification_overrides,
                disable_all=disable_all,
                labels=label_overrides,
                annotations=annotation_overrides,
                auth_role=auth_role,
            )
            exec_id = client.create_execution(project, domain, name, spec, inputs)
        except _user_exceptions.FlyteEntityAlreadyExistsException:
            # The execution is already there; address it by the requested coordinates.
            exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name)
        return client.get_execution(exec_id)
Ejemplo n.º 8
0
    def from_python_std(cls, string):
        """
        Parses a string in the correct format into an identifier.

        The string must consist of exactly ten colon-separated segments:
        ``te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry``.

        :param Text string:
        :rtype: TaskExecutionIdentifier
        :raises flytekit.common.exceptions.user.FlyteValueException: If the string does not split into ten
            segments, or if the first segment is not ``te``.
        :raises ValueError: Propagated from ``int()`` if the retry segment is not an integer literal.
        """
        segments = string.split(":")
        if len(segments) != 10:
            raise _user_exceptions.FlyteValueException(
                string,
                "The provided string was not in a parseable format. The string for an identifier must be in the format"
                " te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry.",
            )

        resource_type, ep, ed, en, node_id, tp, td, tn, tv, retry = segments

        if resource_type != "te":
            # The message must name the prefix this parser actually requires ('te'), not the
            # workflow-execution prefix ('ex').
            raise _user_exceptions.FlyteValueException(
                resource_type,
                "The provided string could not be parsed. The first element of a task execution identifier must be"
                " 'te'.",
            )

        return cls(
            task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn,
                               tv),
            node_execution_id=_core_identifier.NodeExecutionIdentifier(
                node_id=node_id,
                execution_id=_core_identifier.WorkflowExecutionIdentifier(
                    ep, ed, en),
            ),
            retry_attempt=int(retry),
        )
Ejemplo n.º 9
0
def test_task_node_metadata():
    """
    TaskNodeMetadata should expose the ``cache_status`` and ``catalog_key`` it was built with and compare
    equal to itself after a ``to_flyte_idl`` / ``from_flyte_idl`` round trip.
    """
    source_task_execution = identifier.TaskExecutionIdentifier(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        identifier.NodeExecutionIdentifier(
            "node_id",
            identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
        ),
        3,
    )
    catalog_metadata = catalog.CatalogMetadata(
        dataset_id=identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "t1", "abcdef"),
        artifact_tag=catalog.CatalogArtifactTag("my-artifact-id", "some name"),
        source_task_execution=source_task_execution,
    )

    metadata = node_execution_models.TaskNodeMetadata(cache_status=0, catalog_key=catalog_metadata)
    assert metadata.cache_status == 0
    assert metadata.catalog_key == catalog_metadata

    round_tripped = node_execution_models.TaskNodeMetadata.from_flyte_idl(metadata.to_flyte_idl())
    assert round_tripped == metadata
Ejemplo n.º 10
0
def test_execution_notification_overrides(mock_client_factory):
    """
    Executing a launch plan with an empty ``notification_overrides`` list should call
    ``create_execution`` once with a spec that has ``disable_all=True``.
    """

    def _launch_plan_id():
        # Fresh, equal identifier for both the mocked property and the expected spec.
        return identifier.Identifier(
            identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"
        )

    client = MagicMock()
    client.create_execution = MagicMock(
        return_value=identifier.WorkflowExecutionIdentifier('xp', 'xd', 'xn')
    )
    mock_client_factory.return_value = client

    sdk_launch_plan = MagicMock()
    type(sdk_launch_plan).id = PropertyMock(return_value=_launch_plan_id())

    engine.FlyteLaunchPlan(sdk_launch_plan).execute(
        'xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[]
    )

    expected_spec = _execution_models.ExecutionSpec(
        _launch_plan_id(),
        literals.LiteralMap({}),
        _execution_models.ExecutionMetadata(
            _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, 'sdk', 0
        ),
        disable_all=True,
    )
    client.create_execution.assert_called_once_with('xp', 'xd', 'xn', expected_spec)
Ejemplo n.º 11
0
    def launch(
        self,
        project,
        domain,
        name,
        inputs,
        notification_overrides=None,
        label_overrides=None,
        annotation_overrides=None,
    ):
        """
        Creates a workflow execution for this launch plan through the Flyte client and returns the result
        of fetching that execution back from the client.

        If the client reports that the execution already exists, the execution identified by
        (project, domain, name) is fetched instead.

        :param Text project:
        :param Text domain:
        :param Text name:
        :param flytekit.models.literals.LiteralMap inputs:
        :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the
            notifications. An explicitly empty list sets ``disable_all`` on the execution spec instead.
        :param flytekit.models.common.Labels label_overrides:
        :param flytekit.models.common.Annotations annotation_overrides:
        :rtype: flytekit.models.execution.Execution
        """
        # Only an explicitly empty list disables all notifications; any other value (including None) is
        # wrapped in a NotificationList and disable_all is reset to None.
        disable_all = notification_overrides == []
        if not disable_all:
            notification_overrides = _execution_models.NotificationList(
                notification_overrides or [])
            disable_all = None
        else:
            notification_overrides = None

        try:
            client = _FlyteClientManager(
                _platform_config.URL.get(),
                insecure=_platform_config.INSECURE.get()).client
            spec = _execution_models.ExecutionSpec(
                self.sdk_launch_plan.id,
                _execution_models.ExecutionMetadata(
                    _execution_models.ExecutionMetadata.ExecutionMode.MANUAL,
                    "sdk",  # TODO: get principle
                    0,  # TODO: Detect nesting
                ),
                notifications=notification_overrides,
                disable_all=disable_all,
                labels=label_overrides,
                annotations=annotation_overrides,
            )
            exec_id = client.create_execution(project, domain, name, spec, inputs)
        except _user_exceptions.FlyteEntityAlreadyExistsException:
            # The execution is already there; address it by the requested coordinates.
            exec_id = _identifier.WorkflowExecutionIdentifier(
                project, domain, name)
        return client.get_execution(exec_id)
Ejemplo n.º 12
0
def test_workflow_execution_identifier():
    """
    A WorkflowExecutionIdentifier should equal the one parsed from its urn, equal the one promoted from
    the core model, and render back to the same urn string.
    """
    expected_urn = "ex:project:domain:name"
    wf_exec_id = _identifier.WorkflowExecutionIdentifier("project", "domain", "name")

    parsed = _identifier.WorkflowExecutionIdentifier.from_urn(expected_urn)
    assert wf_exec_id == parsed

    core_model = _core_identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    promoted = _identifier.WorkflowExecutionIdentifier.promote_from_model(core_model)
    assert wf_exec_id == promoted

    assert wf_exec_id.__str__() == expected_urn
Ejemplo n.º 13
0
def test_node_execution_identifier():
    """
    NodeExecutionIdentifier should expose its constructor arguments, both directly and after a
    ``to_flyte_idl`` / ``from_flyte_idl`` round trip.
    """
    wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")

    def _assert_fields(candidate):
        assert candidate.node_id == "node_id"
        assert candidate.execution_id == wf_exec_id

    original = identifier.NodeExecutionIdentifier("node_id", wf_exec_id)
    _assert_fields(original)

    round_tripped = identifier.NodeExecutionIdentifier.from_flyte_idl(original.to_flyte_idl())
    assert round_tripped == original
    _assert_fields(round_tripped)
Ejemplo n.º 14
0
def test_workflow_node_metadata():
    """
    WorkflowNodeMetadata should hold on to the very execution id object it was given and compare equal
    to its own ``to_flyte_idl`` / ``from_flyte_idl`` round trip.
    """
    wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    metadata = node_execution_models.WorkflowNodeMetadata(execution_id=wf_exec_id)

    # Identity, not just equality: the same object must be returned.
    assert metadata.execution_id is wf_exec_id

    serialized = metadata.to_flyte_idl()
    round_tripped = node_execution_models.WorkflowNodeMetadata.from_flyte_idl(serialized)
    assert metadata == round_tripped
Ejemplo n.º 15
0
def test_workflow_execution_identifier():
    """
    WorkflowExecutionIdentifier should expose project, domain and name, both directly and after a
    ``to_flyte_idl`` / ``from_flyte_idl`` round trip.
    """

    def _assert_fields(candidate):
        assert candidate.project == "project"
        assert candidate.domain == "domain"
        assert candidate.name == "name"

    original = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    _assert_fields(original)

    round_tripped = identifier.WorkflowExecutionIdentifier.from_flyte_idl(original.to_flyte_idl())
    assert round_tripped == original
    _assert_fields(round_tripped)
Ejemplo n.º 16
0
def test_get_execution_inputs(mock_client_factory, execution_data_locations):
    """
    With the client mocked to return a WorkflowExecutionGetDataResponse built from the first two entries
    of ``execution_data_locations`` plus two ``_EMPTY_LITERAL_MAP`` values, ``get_inputs`` should yield a
    single literal ``a`` equal to 1 and query the client exactly once with the workflow execution's id.
    """

    def _workflow_execution_id():
        # Fresh, equal identifier for both the mocked property and the call assertion.
        return identifier.WorkflowExecutionIdentifier("project", "domain", "name")

    data_response = _execution_models.WorkflowExecutionGetDataResponse(
        execution_data_locations[0],
        execution_data_locations[1],
        _EMPTY_LITERAL_MAP,
        _EMPTY_LITERAL_MAP,
    )
    client = MagicMock()
    client.get_execution_data = MagicMock(return_value=data_response)
    mock_client_factory.return_value = client

    sdk_workflow_execution = MagicMock()
    type(sdk_workflow_execution).id = PropertyMock(return_value=_workflow_execution_id())

    inputs = engine.FlyteWorkflowExecution(sdk_workflow_execution).get_inputs()

    assert len(inputs.literals) == 1
    assert inputs.literals["a"].scalar.primitive.integer == 1
    client.get_execution_data.assert_called_once_with(_workflow_execution_id())
Ejemplo n.º 17
0
def test_exec_params():
    """ExecutionParameters should expose the ``task_id`` it was constructed with."""
    constructor_kwargs = dict(
        execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"),
        task_id=id_models.Identifier(id_models.ResourceType.TASK, "local", "local", "local", "local"),
        execution_date=datetime.utcnow(),
        stats=mock_stats.MockStats(),
        logging=None,
        tmp_dir="/tmp",
        raw_output_prefix="",
        decks=[],
    )
    params = ExecutionParameters(**constructor_kwargs)

    assert params.task_id.name == "local"
Ejemplo n.º 18
0
def test_task_execution_identifier():
    """
    TaskExecutionIdentifier should expose its constructor arguments, both directly and after a
    ``to_flyte_idl`` / ``from_flyte_idl`` round trip.
    """
    task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    node_exec_id = identifier.NodeExecutionIdentifier(
        "node_id",
        identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
    )

    def _assert_fields(candidate):
        assert candidate.retry_attempt == 3
        assert candidate.task_id == task_id
        assert candidate.node_execution_id == node_exec_id

    original = identifier.TaskExecutionIdentifier(task_id, node_exec_id, 3)
    _assert_fields(original)

    round_tripped = identifier.TaskExecutionIdentifier.from_flyte_idl(original.to_flyte_idl())
    assert round_tripped == original
    _assert_fields(round_tripped)
Ejemplo n.º 19
0
def test_task_execution_identifier():
    """
    A TaskExecutionIdentifier should equal the one parsed from its urn, equal the one promoted from the
    core model, and render back to the same urn string.
    """
    expected_urn = "te:project:domain:name:n0:project:domain:name:version:0"

    task_id = _identifier.Identifier(_core_identifier.ResourceType.TASK,
                                     "project", "domain", "name", "version")
    workflow_execution_id = _core_identifier.WorkflowExecutionIdentifier(
        "project", "domain", "name")
    node_execution_id = _core_identifier.NodeExecutionIdentifier(
        node_id="n0", execution_id=workflow_execution_id)
    task_exec_id = _identifier.TaskExecutionIdentifier(
        task_id=task_id,
        node_execution_id=node_execution_id,
        retry_attempt=0,
    )

    parsed = _identifier.TaskExecutionIdentifier.from_urn(expected_urn)
    assert task_exec_id == parsed

    core_model = _core_identifier.TaskExecutionIdentifier(task_id, node_execution_id, 0)
    promoted = _identifier.TaskExecutionIdentifier.promote_from_model(core_model)
    assert task_exec_id == promoted

    assert task_exec_id.__str__() == expected_urn
Ejemplo n.º 20
0
def test_execution_annotation_overrides(mock_client_factory):
    """
    Launching a launch plan with annotation overrides (and an empty notification override list) should
    call ``create_execution`` once with a spec carrying ``disable_all=True`` and those annotations,
    followed by the inputs.
    """

    def _launch_plan_id():
        # Fresh, equal identifier for both the mocked property and the expected spec.
        return identifier.Identifier(
            identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"
        )

    client = MagicMock()
    client.create_execution = MagicMock(
        return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")
    )
    mock_client_factory.return_value = client

    sdk_launch_plan = MagicMock()
    type(sdk_launch_plan).id = PropertyMock(return_value=_launch_plan_id())

    annotations = _common_models.Annotations({"my": "annotation"})
    engine.FlyteLaunchPlan(sdk_launch_plan).launch(
        "xp", "xd", "xn", literals.LiteralMap({}),
        notification_overrides=[],
        annotation_overrides=annotations,
    )

    expected_spec = _execution_models.ExecutionSpec(
        _launch_plan_id(),
        _execution_models.ExecutionMetadata(
            _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0
        ),
        disable_all=True,
        annotations=annotations,
    )
    client.create_execution.assert_called_once_with(
        "xp", "xd", "xn", expected_spec, literals.LiteralMap({})
    )
Ejemplo n.º 21
0
    def initialize():
        """
        Rebuilds the default context and stores it as the only entry in ``flyte_context_Var``,
        discarding whatever contexts were stored there before.
        """
        # A "local" execution id is supplied so that tasks relying on Flyte-provided params do not fail
        # when run locally.
        local_execution_id = _identifier.WorkflowExecutionIdentifier(
            project="local", domain="local", name="local")

        # Make sure the user-space directory under the configured local sandbox exists.
        config = Config.auto()
        user_space_dir = os.path.join(config.local_sandbox_path, "user_space")
        pathlib.Path(user_space_dir).mkdir(parents=True, exist_ok=True)

        base_context = FlyteContext(
            file_access=default_local_file_access_provider)
        local_task_id = None
        user_space_params = ExecutionParameters(
            execution_id=WorkflowExecutionIdentifier.promote_from_model(
                local_execution_id),
            task_id=_identifier.Identifier(_identifier.ResourceType.TASK,
                                           "local", "local", "local", "local"),
            execution_date=_datetime.datetime.utcnow(),
            stats=mock_stats.MockStats(),
            logging=user_space_logger,
            tmp_dir=user_space_dir,
            raw_output_prefix=base_context.file_access._raw_output_prefix,
            decks=[],
        )
        del local_task_id

        execution_state = base_context.new_execution_state().with_params(
            user_space_params=user_space_params)
        rebuilt_context = base_context.with_execution_state(execution_state).build()
        rebuilt_context.set_stackframe(
            s=FlyteContextManager.get_origin_stackframe())

        # Replace the whole stack with just the rebuilt default context.
        flyte_context_Var.set([rebuilt_context])
Ejemplo n.º 22
0
def setup_execution(
    raw_output_data_prefix: str,
    checkpoint_path: Optional[str] = None,
    prev_checkpoint: Optional[str] = None,
    dynamic_addl_distro: Optional[str] = None,
    dynamic_dest_dir: Optional[str] = None,
):
    """
    Build a task-execution context from execution/task settings and yield it.

    The function looks up execution and task coordinates through ``get_one_of``, creates a user
    workspace directory, assembles ``ExecutionParameters`` and a ``FileAccessProvider``, and yields the
    context entered via ``FlyteContextManager.with_context``.

    NOTE(review): this is a generator function (it yields exactly once); it is presumably wrapped by a
    context-manager decorator outside the visible block - confirm at the definition site.

    :param raw_output_data_prefix: Passed as ``raw_output_prefix`` to both the execution parameters and
      the file access provider.
    :param checkpoint_path: When not None, a ``SyncCheckpoint`` is created with this as its destination.
    :param prev_checkpoint: Used as the source of that ``SyncCheckpoint``.
    :param dynamic_addl_distro: Works in concert with the other dynamic arg. If present, indicates that if a dynamic
      task were to run, it should set fast serialize to true and use these values in FastSerializationSettings
    :param dynamic_dest_dir: See above.
    :return: Yields the context produced by ``FlyteContextManager.with_context``.
    :raises TypeError: Re-raised (after logging) when constructing ``FileAccessProvider`` fails with it.
    """
    # Execution-level coordinates. NOTE(review): get_one_of presumably returns the value of the first of
    # the given environment variable names that is set - confirm against its definition.
    exe_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ")
    exe_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM")
    exe_name = get_one_of("FLYTE_INTERNAL_EXECUTION_ID", "_F_NM")
    exe_wf = get_one_of("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF")
    exe_lp = get_one_of("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP")

    # Task-level coordinates, looked up the same way.
    tk_project = get_one_of("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ")
    tk_domain = get_one_of("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM")
    tk_name = get_one_of("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM")
    tk_version = get_one_of("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V")

    # Empty string when the variable is absent, which skips the serialization-settings branch below.
    compressed_serialization_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

    ctx = FlyteContextManager.current_context()
    # Create directories
    user_workspace_dir = ctx.file_access.get_random_local_directory()
    logger.info(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    # A checkpointer is only created when a destination path was supplied.
    checkpointer = None
    if checkpoint_path is not None:
        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=exe_project,
            domain=exe_domain,
            name=exe_name,
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            cfg=StatsConfig.auto(),
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            prefix=f"{tk_project}.{tk_domain}.{tk_name}.user_stats",
            tags={
                "exec_project": exe_project,
                "exec_domain": exe_domain,
                "exec_workflow": exe_wf,
                "exec_launchplan": exe_lp,
                "api_version": _api_version,
            },
        ),
        logging=user_space_logger,
        tmp_dir=user_workspace_dir,
        raw_output_prefix=raw_output_data_prefix,
        checkpoint=checkpointer,
        task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
    )

    # Log which prefix was at fault before letting the TypeError propagate unchanged.
    try:
        file_access = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
            raw_output_prefix=raw_output_data_prefix,
        )
    except TypeError:  # would be thrown from DataPersistencePlugins.find_plugin
        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
        raise

    # Execution state in TASK_EXECUTION mode carrying the parameters assembled above.
    es = ctx.new_execution_state().with_params(
        mode=ExecutionState.Mode.TASK_EXECUTION,
        user_space_params=execution_parameters,
    )
    cb = ctx.new_builder().with_file_access(file_access).with_execution_state(es)

    if compressed_serialization_settings:
        # Rebuild the transported settings, overriding project/domain with the execution's values and
        # version with the task's version.
        ss = SerializationSettings.from_transport(compressed_serialization_settings)
        ssb = ss.new_builder()
        ssb.project = exe_project
        ssb.domain = exe_domain
        ssb.version = tk_version
        if dynamic_addl_distro:
            ssb.fast_serialization_settings = FastSerializationSettings(
                enabled=True,
                destination_dir=dynamic_dest_dir,
                distribution_location=dynamic_addl_distro,
            )
        cb = cb.with_serialization_settings(ssb.build())

    # The yielded context is the one entered here, not the one read at the top of the function.
    with FlyteContextManager.with_context(cb) as ctx:
        yield ctx
Ejemplo n.º 23
0
    def execute(self, inputs, context=None):
        """
        Just execute the task and write the outputs to where they belong.

        ``self.sdk_task.execute`` is run inside two auto-deleting temp directories. Whatever ends up in
        the output dictionary (task outputs, or an error document if an exception was captured) is
        written into the engine temp directory and then uploaded to ``context['output_prefix']``.
        Exceptions raised by the task are captured into an error document rather than propagated.

        NOTE(review): this method contains no ``return`` statement, so it returns None.
        NOTE(review): ``context`` defaults to None but is subscripted unconditionally in the ``finally``
        clause, so callers must always supply a mapping with an ``output_prefix`` key.

        :param flytekit.models.literals.LiteralMap inputs:
        :param dict[Text, Text] context: Must contain the key ``output_prefix``.
        """

        with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir:
            with _common_utils.AutoDeletingTempDir("task_dir") as task_dir:
                with _data_proxy.LocalWorkingDirectoryContext(task_dir):
                    with _data_proxy.RemoteDataContext():
                        # Start empty so the finally clause has something to iterate even if the task
                        # call fails before assigning a result.
                        output_file_dict = dict()

                        # This sets the logging level for user code and is the only place an sdk setting gets
                        # used at runtime.  Optionally, Propeller can set an internal config setting which
                        # takes precedence.
                        log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
                        _logging.getLogger().setLevel(log_level)

                        try:
                            output_file_dict = self.sdk_task.execute(
                                _common_engine.EngineContext(
                                    execution_id=_identifier.WorkflowExecutionIdentifier(
                                        project=_internal_config.EXECUTION_PROJECT.get(),
                                        domain=_internal_config.EXECUTION_DOMAIN.get(),
                                        name=_internal_config.EXECUTION_NAME.get()
                                    ),
                                    execution_date=_datetime.utcnow(),
                                    stats=_get_stats(
                                        # Stats metric path will be:
                                        # registration_project.registration_domain.app.module.task_name.user_stats
                                        # and it will be tagged with execution-level values for project/domain/wf/lp
                                        "{}.{}.{}.user_stats".format(
                                            _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                                            _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                                            _internal_config.TASK_NAME.get() or _internal_config.NAME.get()
                                        ),
                                        tags={
                                            'exec_project': _internal_config.EXECUTION_PROJECT.get(),
                                            'exec_domain': _internal_config.EXECUTION_DOMAIN.get(),
                                            'exec_workflow': _internal_config.EXECUTION_WORKFLOW.get(),
                                            'exec_launchplan': _internal_config.EXECUTION_LAUNCHPLAN.get(),
                                            'api_version': _api_version
                                        }
                                    ),
                                    logging=_logging,
                                    tmp_dir=task_dir
                                ),
                                inputs
                            )
                        except _exception_scopes.FlyteScopedException as e:
                            # Scoped exceptions carry their own code, message and kind; record them
                            # in the error document as-is.
                            _logging.error("!!! Begin Error Captured by Flyte !!!")
                            output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                                _error_models.ContainerError(
                                    e.error_code,
                                    e.verbose_message,
                                    e.kind
                                )
                            )
                            _logging.error(e.verbose_message)
                            _logging.error("!!! End Error Captured by Flyte !!!")
                        except Exception:
                            # Anything else is recorded as an unknown system error with the formatted
                            # traceback as its message, and marked RECOVERABLE.
                            _logging.error("!!! Begin Unknown System Error Captured by Flyte !!!")
                            exc_str = _traceback.format_exc()
                            output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                                _error_models.ContainerError(
                                    "SYSTEM:Unknown",
                                    exc_str,
                                    _error_models.ContainerError.Kind.RECOVERABLE
                                )
                            )
                            _logging.error(exc_str)
                            _logging.error("!!! End Error Captured by Flyte !!!")
                        finally:
                            # Runs on success and failure alike: serialise every entry into the engine
                            # temp directory, then upload that directory to the output prefix.
                            for k, v in _six.iteritems(output_file_dict):
                                _common_utils.write_proto_to_file(
                                    v.to_flyte_idl(),
                                    _os.path.join(temp_dir.name, k)
                                )
                            _data_proxy.Data.put_data(temp_dir.name, context['output_prefix'], is_multipart=True)
Ejemplo n.º 24
0
        if self._flyte_client is not None:
            return self._flyte_client
        elif self._parent is not None:
            return self._parent.flyte_client
        else:
            raise Exception("No flyte_client initialized")


# Hack... we'll think of something better in the future
class FlyteEntities(object):
    """
    Holder for a single class-level ``entities`` list.

    The list is created once on the class itself, so every access through ``FlyteEntities.entities`` (or through
    any instance) refers to the same list object.
    """

    entities = list()


# Fallback execution identifier, provided so that tasks relying on Flyte-supplied parameters do not fail when they
# are run locally.
default_execution_id = _identifier.WorkflowExecutionIdentifier(
    project="local",
    domain="local",
    name="local",
)
# The SdkWorkflowExecutionIdentifier wrapper is used here only for string formatting: it renders the identifier in
# the ex:project:domain:name form users are already acquainted with.
default_user_space_params = ExecutionParameters(
    execution_id=str(_SdkWorkflowExecutionIdentifier.promote_from_model(default_execution_id)),
    execution_date=_datetime.datetime.utcnow(),
    stats=_mock_stats.MockStats(),
    logging=_logging,
    tmp_dir=os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space"),
)
# Default context built from the local user-space parameters and the default local file access provider.
default_context = FlyteContext(
    user_space_params=default_user_space_params,
    file_access=_data_proxy.default_local_file_access_provider,
)
Ejemplo n.º 25
0
def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
    """
    Entrypoint for all PythonTask extensions.

    Sets the root logger level, creates a user workspace directory, assembles the ExecutionParameters, selects a
    file access provider for the configured cloud provider, and finally calls ``_dispatch_execute`` inside nested
    file-access, serialization-settings and execution contexts.

    :param task_def: The task object handed to ``_dispatch_execute``.
    :param inputs: Forwarded unchanged to ``_dispatch_execute``.
    :param output_prefix: Forwarded unchanged to ``_dispatch_execute``.
    :param raw_output_data_prefix: Passed to the S3/GCS proxy constructor; unused for the local provider.
    :raises Exception: If the configured cloud provider is none of AWS, GCP or LOCAL.
    """
    _click.echo("Running native-typed task")
    provider = _platform_config.CLOUD_PROVIDER.get()
    # The internal logging level takes precedence; the SDK-level setting is only the fallback.
    _logging.getLogger().setLevel(_internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get())

    base_ctx = FlyteContext.current_context()

    # Make sure the per-run user workspace exists on local disk
    workspace = base_ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {workspace}")
    pathlib.Path(workspace).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    # The pieces of ExecutionParameters are built one by one, in the order they are passed below.
    exec_id = _identifier.WorkflowExecutionIdentifier(
        project=_internal_config.EXECUTION_PROJECT.get(),
        domain=_internal_config.EXECUTION_DOMAIN.get(),
        name=_internal_config.EXECUTION_NAME.get(),
    )
    exec_date = _datetime.datetime.utcnow()
    # Stats metric path will be:
    # registration_project.registration_domain.app.module.task_name.user_stats
    # and it will be tagged with execution-level values for project/domain/wf/lp
    stats_path = "{}.{}.{}.user_stats".format(
        _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
        _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
        _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
    )
    stats_tags = {
        "exec_project": _internal_config.EXECUTION_PROJECT.get(),
        "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
        "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
        "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
        "api_version": _api_version,
    }
    exec_params = ExecutionParameters(
        execution_id=exec_id,
        execution_date=exec_date,
        stats=_get_stats(stats_path, tags=stats_tags),
        logging=_logging,
        tmp_dir=workspace,
    )

    # Pick the file access provider that matches the configured cloud provider.
    if provider == _constants.CloudProvider.AWS:
        access_provider = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix),
        )
    elif provider == _constants.CloudProvider.GCP:
        access_provider = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix),
        )
    elif provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        access_provider = _data_proxy.FileAccessProvider(local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get())
    else:
        raise Exception(f"Bad cloud provider {provider}")

    with base_ctx.new_file_access_context(file_access_provider=access_provider) as access_ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        env_vars = {
            _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(),
        }

        settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env_vars,
        )

        # Serialization settings are needed because of dynamic tasks. Even if compilation moves entirely to Admin,
        # a dynamic task that calls some task, t1, has to write the correct task identifier for t1 to the DJ Spec.
        with access_ctx.new_serialization_settings(serialization_settings=settings) as settings_ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with settings_ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=exec_params
            ) as exec_ctx:
                _dispatch_execute(exec_ctx, task_def, inputs, output_prefix)