Example #1
    def outputs(self) -> Dict[str, Any]:
        """
        Returns the outputs of the execution in the standard Python format as dictated by the type engine.

        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please wait until the node execution has completed before requesting the outputs."
            )
        if self.error:
            raise _user_exceptions.FlyteAssertion(
                "Outputs could not be found because the execution ended in failure."
            )

        if self._outputs is None:
            client = _flyte_engine.get_client()
            execution_data = client.get_execution_data(self.id)
            # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
            output_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
            if bool(execution_data.full_outputs.literals):
                output_map = execution_data.full_outputs
            elif execution_data.outputs.bytes > 0:
                with _common_utils.AutoDeletingTempDir() as tmp_dir:
                    tmp_name = _os.path.join(tmp_dir.name, "outputs.pb")
                    _data_proxy.Data.get_data(execution_data.outputs.url,
                                              tmp_name)
                    output_map = _literal_models.LiteralMap.from_flyte_idl(
                        _common_utils.load_proto_from_file(
                            _literals_pb2.LiteralMap, tmp_name))

            lp_id = self.spec.launch_plan
            workflow = _workflow.FlyteWorkflow.fetch(lp_id.project,
                                                     lp_id.domain, lp_id.name,
                                                     lp_id.version)
            self._outputs = TypeEngine.literal_map_to_kwargs(
                ctx=FlyteContextManager.current_context(),
                lm=output_map,
                python_types=TypeEngine.guess_python_types(
                    workflow.interface.outputs),
            )
        return self._outputs
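
Since accessing outputs on an in-progress execution raises ``FlyteAssertion``, callers typically wait first. A minimal usage sketch, assuming a fetched ``execution`` object like the ones in the integration tests further below (the exception's import path reflects flytekit 0.x and may differ in other versions):

from flytekit.common.exceptions.user import FlyteAssertion

try:
    outs = execution.outputs
except FlyteAssertion:
    execution.wait_for_completion()  # block until the execution finishes
    outs = execution.outputs
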
Example #2
def test_old_style_role():
    identifier_model = identifier.Identifier(identifier.ResourceType.TASK,
                                             "project", "domain", "name",
                                             "version")

    s = schedule.Schedule("asdf", "1 3 4 5 6 7")
    launch_plan_metadata_model = launch_plan.LaunchPlanMetadata(
        schedule=s, notifications=[])

    v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN),
                           "asdf asdf asdf")
    p = interface.Parameter(var=v)
    parameter_map = interface.ParameterMap({"ppp": p})

    fixed_inputs = literals.LiteralMap({
        "a":
        literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(
            integer=1)))
    })

    labels_model = common.Labels({})
    annotations_model = common.Annotations({"my": "annotation"})

    raw_data_output_config = common.RawOutputDataConfig("s3://bucket")

    old_role = _launch_plan_idl.Auth(
        kubernetes_service_account="my:service:account")

    old_style_spec = _launch_plan_idl.LaunchPlanSpec(
        workflow_id=identifier_model.to_flyte_idl(),
        entity_metadata=launch_plan_metadata_model.to_flyte_idl(),
        default_inputs=parameter_map.to_flyte_idl(),
        fixed_inputs=fixed_inputs.to_flyte_idl(),
        labels=labels_model.to_flyte_idl(),
        annotations=annotations_model.to_flyte_idl(),
        raw_output_data_config=raw_data_output_config.to_flyte_idl(),
        auth=old_role,
    )

    lp_spec = launch_plan.LaunchPlanSpec.from_flyte_idl(old_style_spec)

    assert lp_spec.auth_role.assumable_iam_role == "my:service:account"
Example #3
    def compile_into_workflow(
        self, ctx: FlyteContext, task_function: Callable, **kwargs
    ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
        with ctx.new_compilation_context(prefix="dynamic"):
            # TODO: Resolve circular import
            from flytekit.common.translator import get_serializable

            workflow_metadata = WorkflowMetadata(
                on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
            defaults = WorkflowMetadataDefaults(interruptible=False)

            self._wf = Workflow(task_function,
                                metadata=workflow_metadata,
                                default_metadata=defaults)
            self._wf.compile(**kwargs)

            wf = self._wf
            sdk_workflow = get_serializable(ctx.serialization_settings, wf)

            # If no nodes were produced, let's just return the strict outputs
            if len(sdk_workflow.nodes) == 0:
                return _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in sdk_workflow._outputs
                    })

            # Gather underlying tasks/workflows that get referenced. Launch plans are handled by propeller.
            tasks = set()
            sub_workflows = set()
            for n in sdk_workflow.nodes:
                self.aggregate(tasks, sub_workflows, n)

            dj_spec = _dynamic_job.DynamicJobSpec(
                min_successes=len(sdk_workflow.nodes),
                tasks=list(tasks),
                nodes=sdk_workflow.nodes,
                outputs=sdk_workflow._outputs,
                subworkflows=list(sub_workflows),
            )

            return dj_spec
Example #4
    def _execute_user_code(self, inputs):
        """
        :param flytekit.models.literals.LiteralMap inputs:
        :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
        """
        results = super(DynamicTask, self)._execute_user_code(inputs)
        if _sdk_constants.FUTURES_FILE_NAME in results:
            futures = results[_sdk_constants.FUTURES_FILE_NAME]
            sub_task_outputs = {}
            tasks_map = {task.id: task for task in futures.tasks}

            for future_node in futures.nodes:
                task = tasks_map[future_node.task_node.reference_id]
                if task.type == _sdk_constants.SdkTaskType.CONTAINER_ARRAY_TASK:
                    sub_task_output = DynamicTask.execute_array_task(
                        future_node.id, task, results)
                elif task.type == _sdk_constants.SdkTaskType.HIVE_JOB:
                    # TODO: futures.outputs should have the Schema instances.
                    # After schema is implemented, fill out random data into the random locations
                    # then check output in test function
                    # Even though we recommend people use typed schemas, they might not always do so...
                    # in which case it'll be impossible to predict the actual schema, we should support a
                    # way for unit test authors to provide fake data regardless
                    sub_task_output = None
                else:
                    inputs_path = _os.path.join(future_node.id,
                                                _sdk_constants.INPUT_FILE_NAME)
                    if inputs_path not in results:
                        raise _system_exception.FlyteSystemAssertion(
                            "dynamic task hasn't generated expected inputs document [{}] found {}"
                            .format(future_node.id, list(results.keys())))
                    sub_task_output = UnitTestEngineFactory().get_task(
                        task).execute(results[inputs_path])
                sub_task_outputs[future_node.id] = sub_task_output

            results[_sdk_constants.OUTPUT_FILE_NAME] = _literals.LiteralMap(
                literals={
                    binding.var: DynamicTask.fulfil_bindings(
                        binding.binding, sub_task_outputs)
                    for binding in futures.outputs
                })
        return results
Example #5
    def execute(self, context, inputs):
        """
        :param flytekit.engines.common.EngineContext context:
        :param flytekit.models.literals.LiteralMap inputs:
        :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
        :returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities.  These
            entities will be used by the engine to pass data from node to node, populate metadata, etc.  Each
            engine will have different behavior.  For instance, the Flyte engine will upload the entities to a remote
            working directory (with the names provided), which will in turn allow Flyte Propeller to push along the
            workflow, whereas the local engine will merely feed the outputs directly into the next node.
        """
        inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(
            inputs,
            {
                k: _type_helpers.get_sdk_type_from_literal_type(v.type)
                for k, v in _six.iteritems(self.interface.inputs)
            },
        )
        outputs_dict = {
            name: _task_output.OutputReference(
                _type_helpers.get_sdk_type_from_literal_type(variable.type))
            for name, variable in _six.iteritems(self.interface.outputs)
        }

        inputs_dict.update(outputs_dict)

        with GlobalSparkContext():
            _exception_scopes.user_entry_point(self.task_function)(
                _sdk_runnable.ExecutionParameters(
                    execution_date=context.execution_date,
                    tmp_dir=context.working_directory,
                    stats=context.stats,
                    execution_id=context.execution_id,
                    logging=context.logging,
                ), GlobalSparkContext.get_spark_context(), **inputs_dict)
        return {
            _constants.OUTPUT_FILE_NAME:
            _literal_models.LiteralMap(literals={
                k: v.sdk_value
                for k, v in _six.iteritems(outputs_dict)
            })
        }
Example #6
    def setUp(self):
        with _utils.AutoDeletingTempDir("input_dir") as input_dir:
            self._task_input = _literals.LiteralMap({
                "input_1":
                _literals.Literal(scalar=_literals.Scalar(
                    primitive=_literals.Primitive(integer=1)))
            })

            self._context = _common_engine.EngineContext(
                execution_id=WorkflowExecutionIdentifier(project="unit_test",
                                                         domain="unit_test",
                                                         name="unit_test"),
                execution_date=_datetime.datetime.utcnow(),
                stats=MockStats(),
                logging=None,
                tmp_dir=input_dir.name,
            )

            # Defining the distributed training task without specifying an output-persist
            # predicate (so it will use the default)
            @inputs(input_1=Types.Integer)
            @outputs(model=Types.Blob)
            @custom_training_job_task(
                training_job_resource_config=TrainingJobResourceConfig(
                    instance_type="ml.m4.xlarge",
                    instance_count=2,
                    volume_size_in_gb=25,
                ),
                algorithm_specification=AlgorithmSpecification(
                    input_mode=InputMode.FILE,
                    input_content_type=InputContentType.TEXT_CSV,
                    metric_definitions=[
                        MetricDefinition(name="Validation error",
                                         regex="validation:error")
                    ],
                ),
            )
            def my_distributed_task(wf_params, input_1, model):
                pass

            self._my_distributed_task = my_distributed_task
            assert type(self._my_distributed_task) == CustomTrainingJobTask
Example #7
def test_launch_workflow_with_args(flyteclient, flyte_workflows_register):
    execution = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.basic_workflow.my_wf",
        f"v{VERSION}").launch_with_literals(
            PROJECT,
            "development",
            literals.LiteralMap({
                "a":
                literals.Literal(
                    literals.Scalar(literals.Primitive(integer=10))),
                "b":
                literals.Literal(
                    literals.Scalar(
                        literals.Primitive(string_value="foobar"))),
            }),
        )
    execution.wait_for_completion()
    assert execution.outputs.literals["o0"].scalar.primitive.integer == 12
    assert execution.outputs.literals[
        "o1"].scalar.primitive.string_value == "foobarworld"
Example #8
def test_monitor_workflow(flyteclient, flyte_workflows_register):
    execution = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.hello_world.my_wf",
        f"v{VERSION}").launch_with_literals(PROJECT, "development",
                                            literals.LiteralMap({}))

    poll_interval = datetime.timedelta(seconds=1)
    time_to_give_up = datetime.datetime.utcnow() + datetime.timedelta(
        seconds=60)

    execution.sync()
    while datetime.datetime.utcnow() < time_to_give_up:

        if execution.is_complete:
            execution.sync()
            break

        with pytest.raises(
                FlyteAssertion,
                match=
                "Please wait until the node execution has completed before requesting the outputs"
        ):
            execution.outputs

        time.sleep(poll_interval.total_seconds())
        execution.sync()

        if execution.node_executions:
            assert execution.node_executions[
                "start-node"].closure.phase == 3  # SUCCEEDED

    for key in execution.node_executions:
        assert execution.node_executions[key].closure.phase == 3

    assert execution.node_executions["n0"].inputs == {}
    assert execution.node_executions["n0"].outputs["o0"] == "hello world"
    assert execution.node_executions["n0"].task_executions[0].inputs == {}
    assert execution.node_executions["n0"].task_executions[0].outputs[
        "o0"] == "hello world"
    assert execution.inputs == {}
    assert execution.outputs["o0"] == "hello world"
Example #9
def test_dispatch_execute_ignore(mock_write_to_file, mock_upload_dir,
                                 mock_get_data, mock_load_proto):
    # Just leave these here, mock them out so nothing happens
    mock_get_data.return_value = True
    mock_upload_dir.return_value = True
    ctx = context_manager.FlyteContext.current_context()

    # IgnoreOutputs
    with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=context_manager.ExecutionState.Mode.TASK_EXECUTION))
    ) as ctx:
        python_task = mock.MagicMock()
        python_task.dispatch_execute.side_effect = IgnoreOutputs()

        empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl()
        mock_load_proto.return_value = empty_literal_map

        _dispatch_execute(ctx, python_task, "inputs path", "outputs prefix")
        assert mock_write_to_file.call_count == 0
Example #10
    def _local_execute(
            self, ctx: FlyteContext,
            **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
        """
        This code is used only in the case when we want to dispatch_execute with outputs from a previous node
        For regular execution, dispatch_execute is invoked directly.
        """
        # Unwrap the kwargs values. After this, we essentially have a LiteralMap.
        # We need to do this because the inputs during local execution can be of two types:
        #  - Promises or native constants
        #  Promises are essentially inputs from previous task executions.
        #  Native constants are just bound to this specific task (default values for a task input).
        #  Also, along with promises and constants, there could be dictionaries or lists of promises or constants.
        kwargs = translate_inputs_to_literals(
            ctx,
            incoming_values=kwargs,
            flyte_interface_types=self.interface.inputs,
            native_types=self.get_input_types(),
        )
        input_literal_map = _literal_models.LiteralMap(literals=kwargs)

        outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
        outputs_literals = outputs_literal_map.literals

        # TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
        #    location, otherwise we don't really need to, right? The higher level execute could just handle the LiteralMap
        # After running, we again have to wrap the outputs, if any, back into Promise objects
        output_names = list(self.interface.outputs.keys())
        if len(output_names) != len(outputs_literals):
            # Length check, clean up exception
            raise AssertionError(
                f"Length difference {len(output_names)} {len(outputs_literals)}"
            )

        # Tasks that don't return anything still return a VoidPromise
        if len(output_names) == 0:
            return VoidPromise(self.name)

        vals = [Promise(var, outputs_literals[var]) for var in output_names]
        return create_task_output(vals, self.python_interface)
Example #11
    def unwrap_literal_map_and_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[VoidPromise, _literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
        """
        Please see the implementation of the dispatch_execute function in the real task.
        """

        # Invoked before the task is executed
        # Translate the input literals to Python native
        native_inputs = TypeEngine.literal_map_to_kwargs(ctx, input_literal_map, self.python_interface.inputs)

        logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
        try:
            native_outputs = self.execute(**native_inputs)
        except Exception as e:
            logger.exception(f"Exception when executing {e}")
            raise e
        logger.info(f"Task executed successfully in user level, outputs: {native_outputs}")

        expected_output_names = list(self.python_interface.outputs.keys())
        if len(expected_output_names) == 1:
            native_outputs_as_map = {expected_output_names[0]: native_outputs}
        elif len(expected_output_names) == 0:
            native_outputs_as_map = {}
        else:
            native_outputs_as_map = {expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs)}

        # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
        # built into the IDL that all the values of a literal map are of the same type.
        literals = {}
        for k, v in native_outputs_as_map.items():
            literal_type = self.interface.outputs[k].type
            py_type = self.python_interface.outputs[k]
            if isinstance(v, tuple):
                raise AssertionError(f"Output({k}) in task{self.name} received a tuple {v}, instead of {py_type}")
            literals[k] = TypeEngine.to_literal(ctx, v, py_type, literal_type)
        outputs_literal_map = _literal_models.LiteralMap(literals=literals)
        # After the execute has been successfully completed
        return outputs_literal_map
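
The output-name mapping above follows a simple convention: a single declared output takes the raw return value, zero outputs map to an empty dict, and multiple outputs are matched positionally. A pure-Python restatement (the function name is illustrative):

def outputs_as_map(expected_output_names, native_outputs):
    # One declared output: the raw return value is that output.
    if len(expected_output_names) == 1:
        return {expected_output_names[0]: native_outputs}
    # No declared outputs: nothing to map.
    if len(expected_output_names) == 0:
        return {}
    # Multiple outputs: positional mapping from the returned tuple.
    return {name: native_outputs[i] for i, name in enumerate(expected_output_names)}

assert outputs_as_map(["o0"], 42) == {"o0": 42}
assert outputs_as_map(["o0", "o1"], (1, "a")) == {"o0": 1, "o1": "a"}
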
Example #12
File: engine.py, Project: vglocus/flytekit
    def get_outputs(self):
        """
        :rtype: flytekit.models.literals.LiteralMap
        """
        client = _FlyteClientManager(
            _platform_config.URL.get(),
            insecure=_platform_config.INSECURE.get()).client
        execution_data = client.get_task_execution_data(
            self.sdk_task_execution.id)

        # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
        if bool(execution_data.full_outputs.literals):
            return execution_data.full_outputs

        if execution_data.outputs.bytes > 0:
            with _common_utils.AutoDeletingTempDir() as t:
                tmp_name = _os.path.join(t.name, "outputs.pb")
                _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name)
                return _literals.LiteralMap.from_flyte_idl(
                    _common_utils.load_proto_from_file(
                        _literals_pb2.LiteralMap, tmp_name))
        return _literals.LiteralMap({})
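
This "inline unless offloaded" read recurs throughout these examples; below is a hedged sketch of it factored into a standalone helper. The aliased import paths are the flytekit 0.x internals used above and may differ across versions.

import os as _os

from flyteidl.core import literals_pb2 as _literals_pb2
from flytekit.common import utils as _common_utils
from flytekit.interfaces.data import data_proxy as _data_proxy
from flytekit.models import literals as _literals


def read_output_literal_map(execution_data) -> _literals.LiteralMap:
    # Small payloads come back inline on the response.
    if execution_data.full_outputs.literals:
        return execution_data.full_outputs
    # Larger payloads are offloaded; fetch the serialized protobuf and decode it.
    if execution_data.outputs.bytes > 0:
        with _common_utils.AutoDeletingTempDir() as t:
            tmp_name = _os.path.join(t.name, "outputs.pb")
            _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name)
            return _literals.LiteralMap.from_flyte_idl(
                _common_utils.load_proto_from_file(_literals_pb2.LiteralMap,
                                                   tmp_name))
    return _literals.LiteralMap({})
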
Example #13
def test_hive_task_dynamic_job_spec_generation():
    with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
        context = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project='unit_test',
                domain='unit_test',
                name='unit_test'
            ),
            execution_date=_datetime.utcnow(),
            stats=None,  # TODO: A mock stats object that we can read later.
            logging=_logging,  # TODO: A mock logging object that we can read later.
            tmp_dir=user_working_directory
        )
        dj_spec = two_queries._produce_dynamic_job_spec(context, _literals.LiteralMap(literals={}))

        # Bindings
        assert len(dj_spec.outputs[0].binding.collection.bindings) == 2
        assert isinstance(dj_spec.outputs[0].binding.collection.bindings[0].scalar.schema, Schema)
        assert isinstance(dj_spec.outputs[0].binding.collection.bindings[1].scalar.schema, Schema)

        # Custom field is filled in
        assert len(dj_spec.tasks[0].custom) > 0
Example #14
def get_promise(binding_data: _literal_models.BindingData,
                outputs_cache: Dict[Node, Dict[str, Promise]]) -> Promise:
    """
    This is a helper function that will turn a binding into a Promise object, using a lookup map. Please see
    get_promise_map for the rest of the details.
    """
    if binding_data.promise is not None:
        if not isinstance(binding_data.promise, NodeOutput):
            raise FlyteValidationException(
                f"Binding data Promises have to be of the NodeOutput type; {type(binding_data.promise)} found"
            )
        # b.var is the name of the input to the task
        # binding_data.promise.var is the name of the upstream node's output we want
        return outputs_cache[binding_data.promise.node][
            binding_data.promise.var]
    elif binding_data.scalar is not None:
        return Promise(var="placeholder",
                       val=_literal_models.Literal(scalar=binding_data.scalar))
    elif binding_data.collection is not None:
        literals = []
        for bd in binding_data.collection.bindings:
            p = get_promise(bd, outputs_cache)
            literals.append(p.val)
        return Promise(
            var="placeholder",
            val=_literal_models.Literal(
                collection=_literal_models.LiteralCollection(
                    literals=literals)),
        )
    elif binding_data.map is not None:
        literals = {}
        for k, bd in binding_data.map.bindings.items():
            p = get_promise(bd, outputs_cache)
            literals[k] = p.val
        return Promise(var="placeholder",
                       val=_literal_models.Literal(
                           map=_literal_models.LiteralMap(literals=literals)))

    raise FlyteValidationException("Binding type unrecognized.")
Example #15
def test_dispatch_execute_exception(mock_write_to_file, mock_upload_dir,
                                    mock_get_data, mock_load_proto):
    # Just leave these here, mock them out so nothing happens
    mock_get_data.return_value = True
    mock_upload_dir.return_value = True

    ctx = context_manager.FlyteContext.current_context()
    with ctx.new_execution_context(
            mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) as ctx:

        python_task = mock.MagicMock()
        python_task.dispatch_execute.side_effect = Exception("random")

        empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl()
        mock_load_proto.return_value = empty_literal_map

        def verify_output(*args, **kwargs):
            assert isinstance(args[0], ErrorDocument)

        mock_write_to_file.side_effect = verify_output
        _dispatch_execute(ctx, python_task, "inputs path", "outputs prefix")
        assert mock_write_to_file.call_count == 1
Example #16
    def fulfil_bindings(binding_data, fulfilled_promises):
        """
        Substitutes promise values in binding_data with model Literal values built from python std values in
        fulfilled_promises

        :param _interface.BindingData binding_data:
        :param dict[Text,T] fulfilled_promises:
        :rtype:
        """
        if binding_data.scalar:
            return _literals.Literal(scalar=binding_data.scalar)
        elif binding_data.collection:
            return _literals.Literal(collection=_literals.LiteralCollection([
                DynamicTask.fulfil_bindings(sub_binding_data,
                                            fulfilled_promises)
                for sub_binding_data in binding_data.collection.bindings
            ]))
        elif binding_data.promise:
            if binding_data.promise.node_id not in fulfilled_promises:
                raise _system_exception.FlyteSystemAssertion(
                    "Expecting output of node [{}] but that hasn't been produced."
                    .format(binding_data.promise.node_id))
            node_output = fulfilled_promises[binding_data.promise.node_id]
            if binding_data.promise.var not in node_output:
                raise _system_exception.FlyteSystemAssertion(
                    "Expecting output [{}] of node [{}] but that hasn't been produced."
                    .format(binding_data.promise.var,
                            binding_data.promise.node_id))

            return binding_data.promise.sdk_type.from_python_std(
                node_output[binding_data.promise.var])
        elif binding_data.map:
            return _literals.Literal(map=_literals.LiteralMap({
                k: DynamicTask.fulfil_bindings(sub_binding_data,
                                               fulfilled_promises)
                for k, sub_binding_data in _six.iteritems(
                    binding_data.map.bindings)
            }))
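
For instance, the scalar fast-path simply re-wraps the scalar in a Literal. A hedged sketch, constructing a model-level BindingData by hand and calling fulfil_bindings as the static method it is used as in example #4 above:

from flytekit.models import literals as _literals

bd = _literals.BindingData(
    scalar=_literals.Scalar(primitive=_literals.Primitive(integer=7)))
assert DynamicTask.fulfil_bindings(bd, {}).scalar.primitive.integer == 7
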
Example #17
    def _local_execute(
            self, ctx: FlyteContext,
            **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
        """
        Please see the _local_execute comments in the main task.
        """
        # Unwrap the kwargs values. After this, we essentially have a LiteralMap.
        # We need to do this because the inputs during local execution can be of two types:
        #  - Promises or native constants
        #  Promises are essentially inputs from previous task executions.
        #  Native constants are just bound to this specific task (default values for a task input).
        #  Also, along with promises and constants, there could be dictionaries or lists of promises or constants.
        kwargs = translate_inputs_to_literals(
            ctx,
            incoming_values=kwargs,
            flyte_interface_types=self.interface.inputs,
            native_types=self.python_interface.inputs,
        )
        input_literal_map = _literal_models.LiteralMap(literals=kwargs)

        outputs_literal_map = self.unwrap_literal_map_and_execute(
            ctx, input_literal_map)

        # After running, we again have to wrap the outputs, if any, back into Promise objects
        outputs_literals = outputs_literal_map.literals
        output_names = list(self.python_interface.outputs.keys())
        if len(output_names) != len(outputs_literals):
            # Length check, clean up exception
            raise AssertionError(
                f"Length difference {len(output_names)} {len(outputs_literals)}"
            )

        # Tasks that don't return anything still return a VoidPromise
        if len(output_names) == 0:
            return VoidPromise(self.name)

        vals = [Promise(var, outputs_literals[var]) for var in output_names]
        return create_task_output(vals, self.python_interface)
Example #18
def test_dispatch_execute_ignore(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
    # Just leave these here, mock them out so nothing happens
    mock_get_data.return_value = True
    mock_upload_dir.return_value = True
    ctx = context_manager.FlyteContext.current_context()

    # IgnoreOutputs
    with context_manager.FlyteContextManager.with_context(
        ctx.with_execution_state(
            ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
        )
    ) as ctx:
        python_task = mock.MagicMock()
        python_task.dispatch_execute.side_effect = IgnoreOutputs()

        empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl()
        mock_load_proto.return_value = empty_literal_map

        # The system_entry_point decorator does different thing based on whether or not it's the
        # first time it's called. Using it here to mimic the fact that _dispatch_execute is
        # called by _execute_task, which also has a system_entry_point
        system_entry_point(_dispatch_execute)(ctx, python_task, "inputs path", "outputs prefix")
        assert mock_write_to_file.call_count == 0
Example #19
    def outputs(self):
        """
        Returns the outputs of the execution in the standard Python format as dictated by the type engine.  If the
        execution ended in error or the execution is in progress, an exception will be raised.
        :rtype: dict[Text, T]
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please what until the node execution has completed before requesting the outputs."
            )
        if self.error:
            raise _user_exceptions.FlyteAssertion(
                "Outputs could not be found because the execution ended in failure."
            )

        if self._outputs is None:
            client = _flyte_engine.get_client()
            execution_data = client.get_node_execution_data(self.id)

            # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
            if bool(execution_data.full_outputs.literals):
                output_map = execution_data.full_outputs

            elif execution_data.outputs.bytes > 0:
                with _common_utils.AutoDeletingTempDir() as t:
                    tmp_name = _os.path.join(t.name, "outputs.pb")
                    _data_proxy.Data.get_data(execution_data.outputs.url,
                                              tmp_name)
                    output_map = _literal_models.LiteralMap.from_flyte_idl(
                        _common_utils.load_proto_from_file(
                            _literals_pb2.LiteralMap, tmp_name))
            else:
                output_map = _literal_models.LiteralMap({})

            self._outputs = _type_helpers.unpack_literal_map_to_sdk_python_std(
                output_map)
        return self._outputs
Example #20
    def outputs(self) -> Dict[str, Any]:
        """
        Returns the outputs of the task execution, if available, in the standard Python format that is produced by
        the type engine.

        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
        """
        if not self.is_complete:
            raise _user_exceptions.FlyteAssertion(
                "Please what until the task execution has completed before requesting the outputs."
            )
        if self.error:
            raise _user_exceptions.FlyteAssertion(
                "Outputs could not be found because the execution ended in failure."
            )

        if self._outputs is None:
            client = _flyte_engine.get_client()
            execution_data = client.get_task_execution_data(self.id)

            # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
            output_map = _literal_models.LiteralMap({})
            if bool(execution_data.full_outputs.literals):
                output_map = execution_data.full_outputs
            elif execution_data.outputs.bytes > 0:
                with _common_utils.AutoDeletingTempDir() as t:
                    tmp_name = _os.path.join(t.name, "outputs.pb")
                    _data_proxy.Data.get_data(execution_data.outputs.url,
                                              tmp_name)
                    output_map = _literal_models.LiteralMap.from_flyte_idl(
                        _common_utils.load_proto_from_file(
                            _literals_pb2.LiteralMap, tmp_name))

            # TODO: need to convert flyte literals to python types. For now just use literals
            # self._outputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=output_map)
            self._outputs = output_map
        return self._outputs
Example #21
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs to the execution in the standard Python format as dictated by the type engine.
        """
        if self._inputs is None:
            client = _flyte_engine.get_client()
            execution_data = client.get_execution_data(self.id)

            # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
            input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
            if bool(execution_data.full_inputs.literals):
                input_map = execution_data.full_inputs
            elif execution_data.inputs.bytes > 0:
                with _common_utils.AutoDeletingTempDir() as tmp_dir:
                    tmp_name = _os.path.join(tmp_dir.name, "inputs.pb")
                    _data_proxy.Data.get_data(execution_data.inputs.url,
                                              tmp_name)
                    input_map = _literal_models.LiteralMap.from_flyte_idl(
                        _common_utils.load_proto_from_file(
                            _literals_pb2.LiteralMap, tmp_name))
            # TODO: need to convert flyte literals to python types. For now just use literals
            # self._inputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=input_map)
            self._inputs = input_map
        return self._inputs
Example #22
    def inputs(self):
        """
        Returns the inputs to the execution in the standard Python format as dictated by the type engine.
        :rtype: dict[Text, T]
        """
        if self._inputs is None:
            client = _flyte_engine.get_client()
            execution_data = client.get_node_execution_data(self.id)

            # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
            if bool(execution_data.full_inputs.literals):
                input_map = execution_data.full_inputs
            elif execution_data.inputs.bytes > 0:
                with _common_utils.AutoDeletingTempDir() as t:
                    tmp_name = _os.path.join(t.name, "inputs.pb")
                    _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name)
                    input_map = _literal_models.LiteralMap.from_flyte_idl(
                        _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
                    )
            else:
                input_map = _literal_models.LiteralMap({})

            self._inputs = _type_helpers.unpack_literal_map_to_sdk_python_std(input_map)
        return self._inputs
Example #23
def construct_literal_map_from_variable_map(variable_dict, text_args):
    """
    This function produces a map of Literals to use when creating an execution.  It reads the required values from
    a Variable map (presumably obtained from a launch plan), and then fills in the necessary inputs
    from the click args. Click args will be strings, which will be parsed into their SDK types with each
    SDK type's parse string method.

    :param dict[Text, flytekit.models.interface.Variable] variable_dict:
    :param dict[Text, Text] text_args:
    :rtype: flytekit.models.literals.LiteralMap
    """
    inputs = {}

    for var_name, variable in _six.iteritems(variable_dict):
        # Check to see if it's passed from click
        # If this is an input that has a default from the LP, it should've already been parsed into a string,
        # and inserted into the default for this option, so it should still be here.
        if var_name in text_args and text_args[var_name] is not None:
            # the SDK type is also available from the sdk workflow object's saved user inputs but let's
            # derive it here from the interface to be more rigorous.
            sdk_type = _get_sdk_type_from_literal_type(variable.type)
            inputs[var_name] = sdk_type.from_string(text_args[var_name])

    return _literals.LiteralMap(literals=inputs)
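
A hypothetical invocation, building the Variable map by hand in the style of example #2 (the input name and description are illustrative):

from flytekit.models import interface, types

variable_dict = {
    "a": interface.Variable(types.LiteralType(simple=types.SimpleType.INTEGER),
                            "an integer input"),
}
literal_map = construct_literal_map_from_variable_map(variable_dict, {"a": "42"})
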
Example #24
    def get_default_launch_plan(ctx: FlyteContext, workflow: _annotated_workflow.WorkflowBase) -> LaunchPlan:
        """
        Users should probably call the get_or_create function defined below instead. A default launch plan is the one
        that will just pick up whatever default values are defined in the workflow function signature (if any) and
        use the default auth information supplied during serialization, with no notifications or schedules.

        :param ctx: This is not flytekit.current_context(). This is an internal context object. Users familiar with
          flytekit should feel free to use this however.
        :param workflow: The workflow to create a launch plan for.
        """
        if workflow.name in LaunchPlan.CACHE:
            return LaunchPlan.CACHE[workflow.name]

        parameter_map = transform_inputs_to_parameters(ctx, workflow.python_interface)

        lp = LaunchPlan(
            name=workflow.name,
            workflow=workflow,
            parameters=parameter_map,
            fixed_inputs=_literal_models.LiteralMap(literals={}),
        )

        LaunchPlan.CACHE[workflow.name] = lp
        return lp
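
A short usage sketch, assuming a decorated @workflow function wf; the context-manager import path reflects newer flytekit layouts and may vary by version:

from flytekit.core.context_manager import FlyteContextManager

ctx = FlyteContextManager.current_context()
default_lp = LaunchPlan.get_default_launch_plan(ctx, wf)
# Subsequent calls for the same workflow return the cached plan.
assert LaunchPlan.get_default_launch_plan(ctx, wf) is default_lp
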
Example #25
import pytest

from flytekit.models import common as _common_models
from flytekit.models import execution as _execution
from flytekit.models import literals as _literals
from flytekit.models.core import execution as _core_exec
from flytekit.models.core import identifier as _identifier
from tests.flytekit.common import parameterizers as _parameterizers

_INPUT_MAP = _literals.LiteralMap({
    "a":
    _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(
        integer=1)))
})
_OUTPUT_MAP = _literals.LiteralMap({
    "b":
    _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(
        integer=2)))
})


def test_execution_metadata():
    obj = _execution.ExecutionMetadata(
        _execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1)
    assert obj.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
    assert obj.principal == "tester"
    assert obj.nesting == 1
    obj2 = _execution.ExecutionMetadata.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    assert obj2.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
    assert obj2.principal == "tester"
Example #26
    def dispatch_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
        """
        This function is mostly copied from the base PythonTask, but differs in that we have to infer the Python
        interface before executing. Also, we refer to ``self.task_template`` rather than just ``self`` like in task
        classes that derive from the base ``PythonTask``.
        """
        # Invoked before the task is executed
        new_user_params = self.pre_execute(ctx.user_space_params)

        # Create another execution context with the new user params, but let's keep the same working dir
        with ctx.new_execution_context(
                mode=ctx.execution_state.mode,
                execution_params=new_user_params,
                working_dir=ctx.execution_state.working_dir,
        ) as exec_ctx:
            # Added: Have to reverse the Python interface from the task template Flyte interface
            # See docstring for more details.
            guessed_python_input_types = TypeEngine.guess_python_types(
                self.task_template.interface.inputs)
            native_inputs = TypeEngine.literal_map_to_kwargs(
                exec_ctx, input_literal_map, guessed_python_input_types)

            logger.info(
                f"Invoking FlyteTask executor {self.task_template.id.name} with inputs: {native_inputs}"
            )
            try:
                native_outputs = self.execute(**native_inputs)
            except Exception as e:
                logger.exception(f"Exception when executing {e}")
                raise e

            logger.info(
                f"Task executed successfully in user level, outputs: {native_outputs}"
            )
            # Let's run the post_execute method. This may result in an IgnoreOutputs exception, which is
            # bubbled up to be handled at the callee layer.
            native_outputs = self.post_execute(new_user_params, native_outputs)

            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
            if isinstance(native_outputs,
                          _literal_models.LiteralMap) or isinstance(
                              native_outputs, _dynamic_job.DynamicJobSpec):
                return native_outputs

            expected_output_names = list(
                self.task_template.interface.outputs.keys())
            if len(expected_output_names) == 1:
                # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
                # length one. That convention is used for naming outputs - and single-length-NamedTuples are
                # particularly troublesome but elegant handling of them is not a high priority
                # Again, we're using the output_tuple_name as a proxy.
                # Deleted some stuff
                native_outputs_as_map = {
                    expected_output_names[0]: native_outputs
                }
            elif len(expected_output_names) == 0:
                native_outputs_as_map = {}
            else:
                native_outputs_as_map = {
                    expected_output_names[i]: native_outputs[i]
                    for i, _ in enumerate(native_outputs)
                }

            # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
            # built into the IDL that all the values of a literal map are of the same type.
            literals = {}
            for k, v in native_outputs_as_map.items():
                literal_type = self.task_template.interface.outputs[k].type
                py_type = type(v)

                if isinstance(v, tuple):
                    raise AssertionError(
                        f"Output({k}) in task {self.task_template.id.name} received a tuple {v}, instead of {py_type}"
                    )
                try:
                    literals[k] = TypeEngine.to_literal(
                        exec_ctx, v, py_type, literal_type)
                except Exception as e:
                    raise AssertionError(
                        f"failed to convert return value for var {k}") from e

            outputs_literal_map = _literal_models.LiteralMap(literals=literals)
            # After the execute has been successfully completed
            return outputs_literal_map
Example #27
    def compile_into_workflow(
        self, ctx: FlyteContext, task_function: Callable, **kwargs
    ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
        if not ctx.compilation_state:
            cs = ctx.new_compilation_state("dynamic")
        else:
            cs = ctx.compilation_state.with_params(prefix="dynamic")

        with FlyteContextManager.with_context(ctx.with_compilation_state(cs)):
            # TODO: Resolve circular import
            from flytekit.common.translator import get_serializable

            workflow_metadata = WorkflowMetadata(
                on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
            defaults = WorkflowMetadataDefaults(
                interruptible=self.metadata.interruptible
                if self.metadata.interruptible is not None else False)

            self._wf = PythonFunctionWorkflow(task_function,
                                              metadata=workflow_metadata,
                                              default_metadata=defaults)
            self._wf.compile(**kwargs)

            wf = self._wf
            model_entities = OrderedDict()
            # See comment on reference entity checking a bit down below in this function.
            # This is the only circular dependency between the translator.py module and the rest of the flytekit
            # authoring experience.
            workflow_spec: admin_workflow_models.WorkflowSpec = get_serializable(
                model_entities, ctx.serialization_settings, wf)

            # If no nodes were produced, let's just return the strict outputs
            if len(workflow_spec.template.nodes) == 0:
                return _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in workflow_spec.template.outputs
                    })

            # This is not great. The translator.py module is relied on here (see comment above) to get the tasks and
            # subworkflow definitions. However we want to ensure that reference tasks and reference sub workflows are
            # not used.
            # TODO: Replace None with a class.
            for value in model_entities.values():
                if value is None:
                    raise Exception(
                        "Reference tasks are not allowed in the dynamic - a network call is necessary "
                        "in order to retrieve the structure of the reference task."
                    )

            # Gather underlying TaskTemplates that get referenced. Launch plans are handled by propeller. Subworkflows
            # should already be in the workflow spec.
            tts = [
                v.template for v in model_entities.values()
                if isinstance(v, task_models.TaskSpec)
            ]

            if ctx.serialization_settings.should_fast_serialize():
                if (not ctx.execution_state
                        or not ctx.execution_state.additional_context
                        or not ctx.execution_state.additional_context.get(
                            "dynamic_addl_distro")):
                    raise AssertionError(
                        "Compilation for a dynamic workflow called in fast execution mode but no additional code "
                        "distribution could be retrieved")
                logger.warn(
                    f"ctx.execution_state.additional_context {ctx.execution_state.additional_context}"
                )
                for task_template in tts:
                    sanitized_args = []
                    for arg in task_template.container.args:
                        if arg == "{{ .remote_package_path }}":
                            sanitized_args.append(
                                ctx.execution_state.additional_context.get(
                                    "dynamic_addl_distro"))
                        elif arg == "{{ .dest_dir }}":
                            sanitized_args.append(
                                ctx.execution_state.additional_context.get(
                                    "dynamic_dest_dir", "."))
                        else:
                            sanitized_args.append(arg)
                    del task_template.container.args[:]
                    task_template.container.args.extend(sanitized_args)

            dj_spec = _dynamic_job.DynamicJobSpec(
                min_successes=len(workflow_spec.template.nodes),
                tasks=tts,
                nodes=workflow_spec.template.nodes,
                outputs=workflow_spec.template.outputs,
                subworkflows=workflow_spec.sub_workflows,
            )

            return dj_spec
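
The fast-serialization arg rewriting above amounts to a simple placeholder substitution. A pure-Python restatement with illustrative values:

addl_ctx = {"dynamic_addl_distro": "s3://my-bucket/distro.tar.gz"}
args = ["some-entrypoint", "{{ .remote_package_path }}", "{{ .dest_dir }}"]

sanitized = []
for arg in args:
    if arg == "{{ .remote_package_path }}":
        sanitized.append(addl_ctx.get("dynamic_addl_distro"))
    elif arg == "{{ .dest_dir }}":
        sanitized.append(addl_ctx.get("dynamic_dest_dir", "."))
    else:
        sanitized.append(arg)

assert sanitized == ["some-entrypoint", "s3://my-bucket/distro.tar.gz", "."]
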
Example #28
    def _produce_dynamic_job_spec(self, context, inputs):
        """
        Runs user code and produces future task nodes to run sub-tasks.
        :param context:
        :param flytekit.models.literals.LiteralMap inputs:
        :rtype: (_dynamic_job.DynamicJobSpec, dict[Text, flytekit.models.common.FlyteIdlEntity])
        """
        inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(
            inputs, {
                k: _type_helpers.get_sdk_type_from_literal_type(v.type)
                for k, v in _six.iteritems(self.interface.inputs)
            })
        outputs_dict = {
            name: PromiseOutputReference(
                _type_helpers.get_sdk_type_from_literal_type(variable.type))
            for name, variable in _six.iteritems(self.interface.outputs)
        }

        inputs_dict.update(outputs_dict)
        yielded_sub_tasks = [
            sub_task
            for sub_task in super(SdkDynamicTask, self)._execute_user_code(
                context, inputs_dict) or []
        ]

        upstream_nodes = list()
        output_bindings = [
            _literal_models.Binding(
                var=name,
                binding=_interface.BindingData.from_python_std(
                    b.sdk_type.to_flyte_literal_type(),
                    b.raw_value,
                    upstream_nodes=upstream_nodes))
            for name, b in _six.iteritems(outputs_dict)
        ]
        upstream_nodes = set(upstream_nodes)

        generated_files = {}
        # Keeping future-tasks in original order. We don't use upstream_nodes exclusively because the parent task can
        # yield sub-tasks that it never uses to produce final outputs but they need to execute nevertheless.
        array_job_index = {}
        tasks = []
        nodes = []
        visited_nodes = set()
        generated_ids = {}
        effective_failure_ratio = self._allowed_failure_ratio or 0.0
        for sub_task_node in _itertools.chain(yielded_sub_tasks,
                                              upstream_nodes):
            if sub_task_node in visited_nodes:
                continue
            visited_nodes.add(sub_task_node)

            # Generate an id that's unique in the document (if the same task is used multiple times with
            # different resources, executable_sdk_object.id will be the same but generated node_ids should not
            # be.
            safe_task_id = _six.text_type(
                sub_task_node.executable_sdk_object.id)
            if safe_task_id in generated_ids:
                new_count = generated_ids[
                    safe_task_id] = generated_ids[safe_task_id] + 1
            else:
                new_count = generated_ids[safe_task_id] = 0
            unique_node_id = _dnsify("{}-{}".format(safe_task_id, new_count))

            # If the task can run as an array job, group its instances together. Otherwise, keep each invocation as a
            # separate node.
            if SdkDynamicTask._can_run_as_array(
                    sub_task_node.executable_sdk_object.type):
                if sub_task_node.executable_sdk_object in array_job_index:
                    array_job, node = array_job_index[
                        sub_task_node.executable_sdk_object]
                    array_job.size += 1
                    array_job.min_successes = int(
                        math.ceil(
                            (1 - effective_failure_ratio) * array_job.size))
                else:
                    array_job = self._create_array_job(
                        inputs_prefix=unique_node_id)
                    node = sub_task_node.assign_id_and_return(unique_node_id)
                    array_job_index[sub_task_node.executable_sdk_object] = (
                        array_job, node)

                node_index = _six.text_type(array_job.size - 1)
                for k, node_output in _six.iteritems(sub_task_node.outputs):
                    if not node_output.sdk_node.id:
                        node_output.sdk_node.assign_id_and_return(node.id)
                    node_output.var = "[{}].{}".format(node_index,
                                                       node_output.var)

                # Upload inputs to working directory under /array_job.input_ref/<index>/inputs.pb
                input_path = _os.path.join(node.id, node_index,
                                           _constants.INPUT_FILE_NAME)
                generated_files[input_path] = _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in sub_task_node.inputs
                    })
            else:
                node = sub_task_node.assign_id_and_return(unique_node_id)

                tasks.append(sub_task_node.executable_sdk_object)
                nodes.append(node)

                for k, node_output in _six.iteritems(sub_task_node.outputs):
                    if not node_output.sdk_node.id:
                        node_output.sdk_node.assign_id_and_return(node.id)

                # Upload inputs to working directory under /array_job.input_ref/inputs.pb
                input_path = _os.path.join(node.id, _constants.INPUT_FILE_NAME)
                generated_files[input_path] = _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in sub_task_node.inputs
                    })

        # assign custom field to the ArrayJob properties computed.
        for task, (array_job, _) in _six.iteritems(array_job_index):
            # TODO: Reconstruct task template object instead of modifying an existing one?
            tasks.append(
                task.assign_custom_and_return(
                    array_job.to_dict()).assign_type_and_return(
                        _constants.SdkTaskType.CONTAINER_ARRAY_TASK))

        # min_successes is absolute, it's computed as the reverse of allowed_failure_ratio and multiplied by the
        # total length of tasks to get an absolute count.
        nodes.extend([
            array_job_node for (_, array_job_node) in array_job_index.values()
        ])
        dynamic_job_spec = _dynamic_job.DynamicJobSpec(
            min_successes=len(nodes),
            tasks=tasks,
            nodes=nodes,
            outputs=output_bindings,
            subworkflows=[])

        return dynamic_job_spec, generated_files
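
The array-job sizing above hinges on one piece of arithmetic: min_successes is the ceiling of the allowed success ratio times the job size. A worked instance:

import math

allowed_failure_ratio = 0.2  # at most 20% of array instances may fail
size = 10
min_successes = int(math.ceil((1 - allowed_failure_ratio) * size))
assert min_successes == 8
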
Example #29
    def dispatch_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
        """
        This method translates Flyte's type-system-based input values and invokes the actual call to the executor.
        This method is also invoked during runtime.

        * ``VoidPromise`` is returned when the task itself declares no outputs.
        * ``LiteralMap`` is returned when the task returns one or more outputs in its declaration. Individual outputs
          may be None.
        * ``DynamicJobSpec`` is returned when a dynamic workflow is executed.
        """

        # Invoked before the task is executed
        new_user_params = self.pre_execute(ctx.user_space_params)

        # Create another execution context with the new user params, but let's keep the same working dir
        with ctx.new_execution_context(
                mode=ctx.execution_state.mode,
                execution_params=new_user_params,
                working_dir=ctx.execution_state.working_dir,
        ) as exec_ctx:
            # TODO We could support default values here too - but not part of the plan right now
            # Translate the input literals to Python native
            native_inputs = TypeEngine.literal_map_to_kwargs(
                exec_ctx, input_literal_map, self.python_interface.inputs)

            # TODO: Logger should auto inject the current context information to indicate if the task is running within
            #   a workflow or a subworkflow etc
            logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
            try:
                native_outputs = self.execute(**native_inputs)
            except Exception as e:
                logger.exception(f"Exception when executing task: {e}")
                raise

            logger.info(
                f"Task executed successfully at user level, outputs: {native_outputs}"
            )
            # Let's run the post_execute method. This may raise an IgnoreOutputs exception, which is
            # bubbled up to be handled at the caller layer.
            native_outputs = self.post_execute(new_user_params, native_outputs)

            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
            if isinstance(native_outputs,
                          (_literal_models.LiteralMap,
                           _dynamic_job.DynamicJobSpec)):
                return native_outputs

            expected_output_names = list(self._outputs_interface.keys())
            if len(expected_output_names) == 1:
                # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
                # length one. That convention is used for naming outputs - and single-length-NamedTuples are
                # particularly troublesome but elegant handling of them is not a high priority
                # Again, we're using the output_tuple_name as a proxy.
                if self.python_interface.output_tuple_name and isinstance(
                        native_outputs, tuple):
                    native_outputs_as_map = {
                        expected_output_names[0]: native_outputs[0]
                    }
                else:
                    native_outputs_as_map = {
                        expected_output_names[0]: native_outputs
                    }
            elif len(expected_output_names) == 0:
                native_outputs_as_map = {}
            else:
                native_outputs_as_map = {
                    expected_output_names[i]: native_outputs[i]
                    for i, _ in enumerate(native_outputs)
                }

            # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
            # built into the IDL that all the values of a literal map are of the same type.
            literals = {}
            for k, v in native_outputs_as_map.items():
                literal_type = self._outputs_interface[k].type
                py_type = self.get_type_for_output_var(k, v)

                if isinstance(v, tuple):
                    raise AssertionError(
                        f"Output ({k}) in task {self.name} received a tuple {v} instead of {py_type}"
                    )
                try:
                    literals[k] = TypeEngine.to_literal(
                        exec_ctx, v, py_type, literal_type)
                except Exception as e:
                    raise AssertionError(
                        f"Failed to convert return value for var {k}") from e

            outputs_literal_map = _literal_models.LiteralMap(literals=literals)
            # Execute has completed successfully; return the assembled literal map.
            return outputs_literal_map
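
A minimal sketch of driving dispatch_execute by hand, assuming a task instance t1 with a single int input named a; t1 is a placeholder, and the literal-map construction simply mirrors the TypeEngine calls used above:

ctx = FlyteContextManager.current_context()
input_lm = _literal_models.LiteralMap(literals={
    "a": TypeEngine.to_literal(ctx, 3, int, TypeEngine.to_literal_type(int)),
})
# Returns a LiteralMap of outputs, or a DynamicJobSpec if t1 is a dynamic task.
outputs_lm = t1.dispatch_execute(ctx, input_lm)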
Example #30
def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str, output_prefix: str):
    """
    Dispatches execute to PythonTask
        Step1: Download inputs and load into a literal map
        Step2: Invoke task - dispatch_execute
        Step3:
            a: [Optional] Record outputs to output_prefix
            b: OR if IgnoreOutputs is raised, then ignore uploading outputs
            c: OR if an unhandled exception is retrieved - record it as an errors.pb
    """
    output_file_dict = {}
    try:
        # Step1
        local_inputs_file = _os.path.join(ctx.execution_state.working_dir, "inputs.pb")
        ctx.file_access.get_data(inputs_path, local_inputs_file)
        input_proto = _utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file)
        idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto)
        # Step2
        outputs = task_def.dispatch_execute(ctx, idl_input_literals)
        # Step3a
        if isinstance(outputs, VoidPromise):
            _logging.getLogger().warning("Task produces no outputs")
            output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})}
        elif isinstance(outputs, _literal_models.LiteralMap):
            output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs}
        elif isinstance(outputs, _dynamic_job.DynamicJobSpec):
            output_file_dict = {_constants.FUTURES_FILE_NAME: outputs}
        else:
            _logging.getLogger().error(f"SystemError: received unknown outputs from task {outputs}")
            output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                _error_models.ContainerError(
                    "UNKNOWN_OUTPUT",
                    f"Type of output received not handled {type(outputs)} outputs: {outputs}",
                    _error_models.ContainerError.Kind.RECOVERABLE,
                )
            )
    except _scoped_exceptions.FlyteScopedException as e:
        _logging.error("!! Begin Error Captured by Flyte !!")
        output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
            _error_models.ContainerError(e.error_code, e.verbose_message, e.kind)
        )
        _logging.error(e.verbose_message)
        _logging.error("!! End Error Captured by Flyte !!")
    except Exception as e:
        if isinstance(e, IgnoreOutputs):
            # Step 3b
            _logging.warning(f"IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}")
            return
        # Step 3c
        _logging.error(f"Exception when executing task {task_def.name}, reason {str(e)}")
        _logging.error("!! Begin Unknown System Error Captured by Flyte !!")
        exc_str = _traceback.format_exc()
        output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
            _error_models.ContainerError("SYSTEM:Unknown", exc_str, _error_models.ContainerError.Kind.RECOVERABLE,)
        )
        _logging.error(exc_str)
        _logging.error("!! End Error Captured by Flyte !!")

    for k, v in output_file_dict.items():
        _common_utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(ctx.execution_state.engine_dir, k))

    ctx.file_access.upload_directory(ctx.execution_state.engine_dir, output_prefix)
    _logging.info(f"Engine folder written successfully to the output prefix {output_prefix}")
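
A hypothetical local driver for the function above, assuming my_task is a PythonTask instance and both paths are reachable through ctx.file_access; this illustrates the calling convention only, not the real container entrypoint:

ctx = FlyteContextManager.current_context()
# inputs.pb must already exist at inputs_path; all result files land under output_prefix.
_dispatch_execute(
    ctx,
    my_task,
    inputs_path="/tmp/exec/inputs.pb",
    output_prefix="/tmp/exec/output",
)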