Example #1
def test_node_task_with_inputs():
    nm = _get_sample_node_metadata()
    task = _workflow.TaskNode(reference_id=_generic_id)
    bd = _literals.BindingData(scalar=_literals.Scalar(
        primitive=_literals.Primitive(integer=5)))
    bd2 = _literals.BindingData(scalar=_literals.Scalar(
        primitive=_literals.Primitive(integer=99)))
    binding = _literals.Binding(var="myvar", binding=bd)
    binding2 = _literals.Binding(var="myothervar", binding=bd2)

    obj = _workflow.Node(
        id="some:node:id",
        metadata=nm,
        inputs=[binding, binding2],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task,
    )
    assert obj.target == task
    assert obj.id == "some:node:id"
    assert obj.metadata == nm
    assert len(obj.inputs) == 2
    assert obj.inputs[0] == binding

    obj2 = _workflow.Node.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    assert obj2.target == task
    assert obj2.id == "some:node:id"
    assert obj2.metadata == nm
    assert len(obj2.inputs) == 2
    assert obj2.inputs[1] == binding2
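
The to_flyte_idl/from_flyte_idl round-trip exercised above also applies to the Binding model on its own. A minimal sketch, assuming only the flytekit.models.literals constructors already used in this test:

from flytekit.models import literals as _literals

bd = _literals.BindingData(
    scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5)))
binding = _literals.Binding(var="myvar", binding=bd)

# Model objects serialize to protobuf and parse back without loss.
assert _literals.Binding.from_flyte_idl(binding.to_flyte_idl()) == binding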
Example #2
def test_branch_node():
    nm = _get_sample_node_metadata()
    task = _workflow.TaskNode(reference_id=_generic_id)
    bd = _literals.BindingData(scalar=_literals.Scalar(
        primitive=_literals.Primitive(integer=5)))
    bd2 = _literals.BindingData(scalar=_literals.Scalar(
        primitive=_literals.Primitive(integer=99)))
    binding = _literals.Binding(var='myvar', binding=bd)
    binding2 = _literals.Binding(var='myothervar', binding=bd2)

    obj = _workflow.Node(id='some:node:id',
                         metadata=nm,
                         inputs=[binding, binding2],
                         upstream_node_ids=[],
                         output_aliases=[],
                         task_node=task)

    bn = _workflow.BranchNode(
        _workflow.IfElseBlock(
            case=_workflow.IfBlock(
                condition=_condition.BooleanExpression(
                    comparison=_condition.ComparisonExpression(
                        _condition.ComparisonExpression.Operator.EQ,
                        _condition.Operand(primitive=_literals.Primitive(integer=5)),
                        _condition.Operand(primitive=_literals.Primitive(integer=2)))),
                then_node=obj),
            other=[
                _workflow.IfBlock(
                    condition=_condition.BooleanExpression(
                        conjunction=_condition.ConjunctionExpression(
                            _condition.ConjunctionExpression.LogicalOperator.AND,
                            _condition.BooleanExpression(
                                comparison=_condition.ComparisonExpression(
                                    _condition.ComparisonExpression.Operator.EQ,
                                    _condition.Operand(primitive=_literals.Primitive(integer=5)),
                                    _condition.Operand(primitive=_literals.Primitive(integer=2)))),
                            _condition.BooleanExpression(
                                comparison=_condition.ComparisonExpression(
                                    _condition.ComparisonExpression.Operator.EQ,
                                    _condition.Operand(primitive=_literals.Primitive(integer=5)),
                                    _condition.Operand(primitive=_literals.Primitive(integer=2)))))),
                    then_node=obj)
            ],
            else_node=obj))
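
    # Hedged completion (not part of the original snippet): assuming Node
    # accepts a branch_node keyword the same way it accepts task_node above,
    # the branch can be wrapped in a Node and round-tripped like any node.
    obj2 = _workflow.Node(
        id='some:node:id',
        metadata=nm,
        inputs=[binding, binding2],
        upstream_node_ids=[],
        output_aliases=[],
        branch_node=bn,
    )
    assert _workflow.Node.from_flyte_idl(obj2.to_flyte_idl()) == obj2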
Example #3
def test_non_system_nodes():
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                         "project", "domain", "my_task",
                                         "version")

    required_input = promise.Input("required", primitives.Integer)

    n1 = my_task(a=required_input).assign_id_and_return("n1")

    n_start = nodes.SdkNode(
        "start-node",
        [],
        [
            _literals.Binding(
                "a",
                interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        None,
        sdk_task=my_task,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    non_system_nodes = workflow.SdkWorkflow.get_non_system_nodes([n1, n_start])
    assert len(non_system_nodes) == 1
    assert non_system_nodes[0].id == "n1"
Example #4
def test_blah():
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_task', 'version')

    required_input = promise.Input('required', primitives.Integer)

    n1 = my_task(a=required_input).assign_id_and_return('n1')

    n_start = nodes.SdkNode(
        'start-node',
        [],
        [
            _literals.Binding(
                'a',
                interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3)
            )
        ],
        None,
        sdk_task=my_task,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None
    )

    non_system_nodes = workflow.SdkWorkflow.get_non_system_nodes([n1, n_start])
    assert len(non_system_nodes) == 1
    assert non_system_nodes[0].id == 'n1'
Example #5
    def __init__(self,
                 inputs,
                 outputs,
                 nodes,
                 id=None,
                 metadata=None,
                 metadata_defaults=None,
                 interface=None,
                 output_bindings=None):
        """
        :param list[flytekit.common.promise.Input] inputs:
        :param list[Output] outputs:
        :param list[flytekit.common.nodes.SdkNode] nodes:
        :param flytekit.models.core.identifier.Identifier id: This is an autogenerated id by the system. The id is
            globally unique across Flyte.
        :param WorkflowMetadata metadata: This contains information on how to run the workflow.
        :param flytekit.models.interface.TypedInterface interface: Defines a strongly typed interface for the
            Workflow (inputs, outputs).  This can include some optional parameters.
        :param list[flytekit.models.literals.Binding] output_bindings: A list of output bindings that specify how to construct
            workflow outputs. Bindings can pull node outputs or specify literals. All workflow outputs specified in
            the interface field must be bound
            in order for the workflow to be validated. A workflow has an implicit dependency on all of its nodes
            to execute successfully in order to bind final outputs.

        """
        for n in nodes:
            for upstream in n.upstream_nodes:
                if upstream.id is None:
                    raise _user_exceptions.FlyteAssertion(
                        "Some nodes contained in the workflow were not found in the workflow description.  Please "
                        "ensure all nodes are either assigned to attributes within the class or an element in a "
                        "list, dict, or tuple which is stored as an attribute in the class."
                    )

        # Allow overrides if specified for all the arguments to the parent class constructor
        id = id if id is not None else _identifier.Identifier(
            _identifier_model.ResourceType.WORKFLOW,
            _internal_config.PROJECT.get(), _internal_config.DOMAIN.get(),
            _uuid.uuid4().hex, _internal_config.VERSION.get())
        metadata = metadata if metadata is not None else _workflow_models.WorkflowMetadata()
        metadata_defaults = metadata_defaults if metadata_defaults is not None else \
            _workflow_models.WorkflowMetadataDefaults()

        interface = interface if interface is not None else _interface.TypedInterface(
            {v.name: v.var for v in inputs},
            {v.name: v.var for v in outputs})

        output_bindings = output_bindings if output_bindings is not None else \
            [_literal_models.Binding(v.name, v.binding_data) for v in outputs]

        super(SdkWorkflow, self).__init__(
            id=id,
            metadata=metadata,
            metadata_defaults=metadata_defaults,
            interface=interface,
            nodes=nodes,
            outputs=output_bindings,
        )
        self._user_inputs = inputs
        self._upstream_entities = set(n.executable_sdk_object for n in nodes)
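
The output_bindings parameter documented above can carry either form of BindingData. A minimal sketch of both, assuming the flytekit.models constructors used in the surrounding tests:

from flytekit.models import literals as _literal_models
from flytekit.models import types as _type_models

# A binding that supplies a literal value directly...
literal_binding = _literal_models.Binding(
    "b",
    _literal_models.BindingData(scalar=_literal_models.Scalar(
        primitive=_literal_models.Primitive(integer=5))))

# ...and one that pulls an output from a node within the workflow.
promise_binding = _literal_models.Binding(
    "c",
    _literal_models.BindingData(
        promise=_type_models.OutputReference("my_node", "c")))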
Example #6
    def __init__(self, inputs, outputs, nodes):
        """
        :param list[flytekit.common.promise.Input] inputs:
        :param list[Output] outputs:
        :param list[flytekit.common.nodes.SdkNode] nodes:
        """
        for n in nodes:
            for upstream in n.upstream_nodes:
                if upstream.id is None:
                    raise _user_exceptions.FlyteAssertion(
                        "Some nodes contained in the workflow were not found in the workflow description.  Please "
                        "ensure all nodes are either assigned to attributes within the class or an element in a "
                        "list, dict, or tuple which is stored as an attribute in the class."
                    )

        super(SdkWorkflow, self).__init__(
            id=_identifier.Identifier(_identifier_model.ResourceType.WORKFLOW,
                                      _internal_config.PROJECT.get(),
                                      _internal_config.DOMAIN.get(),
                                      _uuid.uuid4().hex,
                                      _internal_config.VERSION.get()),
            metadata=_workflow_models.WorkflowMetadata(),
            interface=_interface.TypedInterface(
                {v.name: v.var for v in inputs},
                {v.name: v.var for v in outputs}),
            nodes=nodes,
            outputs=[
                _literal_models.Binding(v.name, v.binding_data)
                for v in outputs
            ],
        )
        self._user_inputs = inputs
        self._upstream_entities = set(n.executable_sdk_object for n in nodes)
Example #7
    def construct_from_class_definition(
        cls,
        inputs: List[_promise.Input],
        outputs: List[Output],
        nodes: List[_nodes.SdkNode],
        metadata: _workflow_models.WorkflowMetadata = None,
        metadata_defaults: _workflow_models.WorkflowMetadataDefaults = None,
        disable_default_launch_plan: bool = False,
    ) -> "SdkRunnableWorkflow":
        """
        This constructor is here to provide backwards-compatibility for class-defined Workflows

        :param list[flytekit.common.promise.Input] inputs:
        :param list[Output] outputs:
        :param list[flytekit.common.nodes.SdkNode] nodes:
        :param WorkflowMetadata metadata: This contains information on how to run the workflow.
        :param flytekit.models.core.workflow.WorkflowMetadataDefaults metadata_defaults: Defaults to be passed
            to nodes contained within workflow.
        :param bool disable_default_launch_plan: If True, do not create a default launch plan for the workflow.

        :rtype: SdkRunnableWorkflow
        """
        for n in nodes:
            for upstream in n.upstream_nodes:
                if upstream.id is None:
                    raise _user_exceptions.FlyteAssertion(
                        "Some nodes contained in the workflow were not found in the workflow description.  Please "
                        "ensure all nodes are either assigned to attributes within the class or an element in a "
                        "list, dict, or tuple which is stored as an attribute in the class."
                    )

        id = _identifier.Identifier(
            _identifier_model.ResourceType.WORKFLOW,
            _internal_config.PROJECT.get(),
            _internal_config.DOMAIN.get(),
            _uuid.uuid4().hex,
            _internal_config.VERSION.get(),
        )
        interface = _interface.TypedInterface(
            {v.name: v.var for v in inputs},
            {v.name: v.var for v in outputs})

        output_bindings = [
            _literal_models.Binding(v.name, v.binding_data) for v in outputs
        ]

        return cls(
            inputs=inputs,
            nodes=nodes,
            interface=interface,
            output_bindings=output_bindings,
            id=id,
            metadata=metadata,
            metadata_defaults=metadata_defaults,
            disable_default_launch_plan=disable_default_launch_plan,
        )
示例#8
0
def binding_from_python_std(
    ctx: _flyte_context.FlyteContext,
    var_name: str,
    expected_literal_type: _type_models.LiteralType,
    t_value: typing.Any,
    t_value_type: type,
) -> _literals_models.Binding:
    binding_data = binding_data_from_python_std(ctx, expected_literal_type,
                                                t_value, t_value_type)
    return _literals_models.Binding(var=var_name, binding=binding_data)
示例#9
    def _produce_dynamic_job_spec(self, context, inputs):
        """
        Runs user code and produces future task nodes to run sub-tasks.
        :param context:
        :param flytekit.models.literals.LiteralMap inputs:
        :rtype: flytekit.models.dynamic_job.DynamicJobSpec
        """
        inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(
            inputs,
            {
                k: _type_helpers.get_sdk_type_from_literal_type(v.type)
                for k, v in _six.iteritems(self.interface.inputs)
            },
        )
        outputs_dict = {
            name: _task_output.OutputReference(
                _type_helpers.get_sdk_type_from_literal_type(variable.type))
            for name, variable in _six.iteritems(self.interface.outputs)
        }

        # Add outputs to inputs
        inputs_dict.update(outputs_dict)

        nodes = []
        tasks = []
        # One node per query
        generated_queries = self._generate_plugin_objects(context, inputs_dict)

        # Create output bindings always - this has to happen after user code has run
        output_bindings = [
            _literal_models.Binding(
                var=name,
                binding=_interface.BindingData.from_python_std(
                    b.sdk_type.to_flyte_literal_type(), b.value),
            ) for name, b in _six.iteritems(outputs_dict)
        ]

        for i, qubole_hive_job in enumerate(generated_queries):
            hive_job_node = _create_hive_job_node("HiveQuery_{}".format(i),
                                                  qubole_hive_job.to_flyte_idl(),
                                                  self.metadata)
            nodes.append(hive_job_node)
            tasks.append(hive_job_node.executable_sdk_object)

        dynamic_job_spec = _dynamic_job.DynamicJobSpec(
            min_successes=len(nodes),
            tasks=tasks,
            nodes=nodes,
            outputs=output_bindings,
            subworkflows=[],
        )

        return dynamic_job_spec
示例#10
    def _produce_dynamic_job_spec(self, context, inputs):
        """
        Runs user code and produces future task nodes to run sub-tasks.
        :param context:
        :param flytekit.models.literals.LiteralMap inputs:
        :rtype: flytekit.models.dynamic_job.DynamicJobSpec
        """
        inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(
            inputs, {
                k: _type_helpers.get_sdk_type_from_literal_type(v.type)
                for k, v in _six.iteritems(self.interface.inputs)
            })
        outputs_dict = {
            name: _task_output.OutputReference(
                _type_helpers.get_sdk_type_from_literal_type(variable.type))
            for name, variable in _six.iteritems(self.interface.outputs)
        }

        # Add outputs to inputs
        inputs_dict.update(outputs_dict)

        # Note: Today a hive task corresponds to a dynamic job spec with one node, which contains multiple
        # queries. We may change this in future.
        nodes = []
        tasks = []
        generated_queries = self._generate_hive_queries(context, inputs_dict)

        # Create output bindings always - this has to happen after user code has run
        output_bindings = [
            _literal_models.Binding(
                var=name,
                binding=_interface.BindingData.from_python_std(
                    b.sdk_type.to_flyte_literal_type(), b.value))
            for name, b in _six.iteritems(outputs_dict)
        ]

        if len(generated_queries.query_collection.queries) > 0:
            hive_job_node = _create_hive_job_node(
                "HiveQueries", generated_queries.to_flyte_idl(), self.metadata)
            nodes.append(hive_job_node)
            tasks.append(hive_job_node.executable_sdk_object)

        dynamic_job_spec = _dynamic_job.DynamicJobSpec(
            # At most we only have one node for now; see the comment above.
            min_successes=len(nodes),
            tasks=tasks,
            nodes=nodes,
            outputs=output_bindings,
            subworkflows=[])

        return dynamic_job_spec
Example #11
def test_sdk_node_from_lp():
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    @_workflow.workflow_class
    class test_workflow(object):
        a = _workflow.Input(_types.Types.Integer)
        test = testy_test(a=a)
        b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer)

    lp = test_workflow.create_launch_plan()

    n1 = _nodes.SdkNode(
        "n1",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_launch_plan=lp,
    )

    assert n1.id == "n1"
    assert len(n1.inputs) == 1
    assert n1.inputs[0].var == "a"
    assert n1.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n1.outputs) == 1
    assert "b" in n1.outputs
    assert n1.outputs["b"].node_id == "n1"
    assert n1.outputs["b"].var == "b"
    assert n1.outputs["b"].sdk_node == n1
    assert n1.outputs["b"].sdk_type == _types.Types.Integer
    assert n1.metadata.name == "abc"
    assert n1.metadata.retries.retries == 3
    assert len(n1.upstream_nodes) == 0
    assert len(n1.upstream_node_ids) == 0
    assert len(n1.output_aliases) == 0
Example #12
def test_future_task_document(task):
    rs = _literals.RetryStrategy(0)
    nm = _workflow.NodeMetadata('node-name', _timedelta(minutes=10), rs)
    n = _workflow.Node(id="id",
                       metadata=nm,
                       inputs=[],
                       upstream_node_ids=[],
                       output_aliases=[],
                       task_node=_workflow.TaskNode(task.id))
    n.to_flyte_idl()
    doc = _dynamic_job.DynamicJobSpec(
        tasks=[task],
        nodes=[n],
        min_successes=1,
        outputs=[_literals.Binding("var", _literals.BindingData())],
        subworkflows=[])
    assert text_format.MessageToString(doc.to_flyte_idl()) == \
        text_format.MessageToString(
            _dynamic_job.DynamicJobSpec.from_flyte_idl(
                doc.to_flyte_idl()).to_flyte_idl())
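
The textual-proto comparison above is a recurring idiom in these tests. It could be factored into a small helper; a sketch, assuming every model class exposes the same to_flyte_idl/from_flyte_idl pair (the helper name is hypothetical):

from google.protobuf import text_format

def assert_idl_round_trip(model_cls, obj):
    # Serialize, parse back, re-serialize, and compare the text protos.
    idl = obj.to_flyte_idl()
    round_tripped = model_cls.from_flyte_idl(idl).to_flyte_idl()
    assert text_format.MessageToString(idl) == \
        text_format.MessageToString(round_tripped)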
Example #13
    def create_bindings_for_inputs(self, map_of_bindings):
        """
        :param dict[Text, T] map_of_bindings: The values can be scalar primitives, node output references,
            lists, etc.
        :rtype: (list[flytekit.models.literals.Binding], list[flytekit.common.nodes.SdkNode])
        :raises: flytekit.common.exceptions.user.FlyteAssertion
        """
        binding_data = dict()
        all_upstream_nodes = list()
        for k in sorted(self.inputs):
            var = self.inputs[k]
            if k not in map_of_bindings:
                raise _user_exceptions.FlyteAssertion(
                    "Input was not specified for: {} of type {}".format(
                        k, var.type))

            binding_data[k] = BindingData.from_python_std(
                var.type,
                map_of_bindings[k],
                upstream_nodes=all_upstream_nodes)

        extra_inputs = set(binding_data.keys()) ^ set(map_of_bindings.keys())
        if len(extra_inputs) > 0:
            raise _user_exceptions.FlyteAssertion(
                "Too many inputs were specified for the interface.  Extra inputs were: {}"
                .format(extra_inputs))

        seen_nodes = set()
        min_upstream = list()
        for n in all_upstream_nodes:
            if n not in seen_nodes:
                seen_nodes.add(n)
                min_upstream.append(n)

        return (
            [
                _literal_models.Binding(k, bd)
                for k, bd in _six.iteritems(binding_data)
            ],
            min_upstream,
        )
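
The seen_nodes/min_upstream loop above deduplicates upstream nodes while preserving first-seen order. On Python 3.7+ the same result can be written with dict.fromkeys; a sketch, assuming the nodes are hashable (which the set-based original already requires):

# Order-preserving deduplication, equivalent to the explicit loop above;
# dicts preserve insertion order as of Python 3.7.
min_upstream = list(dict.fromkeys(all_upstream_nodes))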
Example #14
def to_binding(p: Promise) -> _literals_models.Binding:
    return _literals_models.Binding(
        var=p.var, binding=_literals_models.BindingData(promise=p.ref))
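
A hand-built illustration of the shape to_binding returns, assuming p.ref is a node-output reference like the ones exercised in the node tests elsewhere in this section:

from flytekit.models import literals as _literals_models
from flytekit.models import types as _type_models

# Equivalent of to_binding for a promise pointing at output "out" of node "n1".
b = _literals_models.Binding(
    var="out",
    binding=_literals_models.BindingData(
        promise=_type_models.OutputReference("n1", "out")))
assert b.binding.promise.node_id == "n1"
assert b.binding.promise.var == "out"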
Example #15
    def _produce_dynamic_job_spec(self, context, inputs):
        """
        Runs user code and produces future task nodes to run sub-tasks.
        :param context:
        :param flytekit.models.literals.LiteralMap inputs:
        :rtype: (_dynamic_job.DynamicJobSpec, dict[Text, flytekit.models.common.FlyteIdlEntity])
        """
        inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(
            inputs, {
                k: _type_helpers.get_sdk_type_from_literal_type(v.type)
                for k, v in _six.iteritems(self.interface.inputs)
            })
        outputs_dict = {
            name: PromiseOutputReference(
                _type_helpers.get_sdk_type_from_literal_type(variable.type))
            for name, variable in _six.iteritems(self.interface.outputs)
        }

        # Because users declare both inputs and outputs in their functions signatures, merge them together
        # before calling user code
        inputs_dict.update(outputs_dict)
        yielded_sub_tasks = [
            sub_task
            for sub_task in super(SdkDynamicTask, self)._execute_user_code(
                context, inputs_dict) or []
        ]

        upstream_nodes = list()
        output_bindings = [
            _literal_models.Binding(
                var=name,
                binding=_interface.BindingData.from_python_std(
                    b.sdk_type.to_flyte_literal_type(),
                    b.raw_value,
                    upstream_nodes=upstream_nodes))
            for name, b in _six.iteritems(outputs_dict)
        ]
        upstream_nodes = set(upstream_nodes)

        generated_files = {}
        # Keeping future-tasks in original order. We don't use upstream_nodes exclusively because the parent task can
        # yield sub-tasks that it never uses to produce final outputs but they need to execute nevertheless.
        array_job_index = {}
        tasks = set()
        nodes = []
        sub_workflows = set()
        visited_nodes = set()
        generated_ids = {}
        effective_failure_ratio = self._allowed_failure_ratio or 0.0

        # TODO: This function needs to be cleaned up.
        # The reason we chain these two together is because we allow users to not have to explicitly "yield" the
        # node. As long as the subtask/lp/subwf has an output that's referenced, it'll get picked up.
        for sub_task_node in _itertools.chain(yielded_sub_tasks,
                                              upstream_nodes):
            if sub_task_node in visited_nodes:
                continue
            visited_nodes.add(sub_task_node)
            executable = sub_task_node.executable_sdk_object

            # If the executable object that we're dealing with is registerable (ie, SdkRunnableLaunchPlan, SdkWorkflow
            # SdkTask, or SdkRunnableTask), then it should have the ability to give itself a name. After assigning
            # itself the name, also make sure the id is properly set according to current config values.
            if isinstance(executable, _registerable.RegisterableEntity):
                executable.auto_assign_name()
                executable._id = _identifier.Identifier(
                    executable.resource_type,
                    _internal_config.TASK_PROJECT.get()
                    or _internal_config.PROJECT.get(),
                    _internal_config.TASK_DOMAIN.get()
                    or _internal_config.DOMAIN.get(),
                    executable.platform_valid_name,
                    _internal_config.TASK_VERSION.get()
                    or _internal_config.VERSION.get())

            # Generate an id that's unique in the document (if the same task is used multiple times with
            # different resources, executable_sdk_object.id will be the same but generated node_ids should
            # not be).
            safe_task_id = _six.text_type(
                sub_task_node.executable_sdk_object.id)
            if safe_task_id in generated_ids:
                new_count = generated_ids[safe_task_id] = generated_ids[safe_task_id] + 1
            else:
                new_count = generated_ids[safe_task_id] = 0
            unique_node_id = _dnsify("{}-{}".format(safe_task_id, new_count))

            # Handling case where the yielded node is launch plan
            if isinstance(sub_task_node.executable_sdk_object,
                          _launch_plan.SdkLaunchPlan):
                node = sub_task_node.assign_id_and_return(unique_node_id)
                _append_node(generated_files, node, nodes, sub_task_node)
            # Handling case where the yielded node is launching a sub-workflow
            elif isinstance(sub_task_node.executable_sdk_object,
                            _workflow.SdkWorkflow):
                node = sub_task_node.assign_id_and_return(unique_node_id)
                _append_node(generated_files, node, nodes, sub_task_node)
                # Add the workflow itself to the yielded sub-workflows
                sub_workflows.add(sub_task_node.executable_sdk_object)
                # Recursively discover statically defined upstream entities (tasks, wfs)
                SdkDynamicTask._add_upstream_entities(
                    sub_task_node.executable_sdk_object, sub_workflows, tasks)
            # Handling tasks
            else:
                # If the task can run as an array job, group its instances together. Otherwise, keep each
                # invocation as a separate node.
                if SdkDynamicTask._can_run_as_array(
                        sub_task_node.executable_sdk_object.type):
                    if sub_task_node.executable_sdk_object in array_job_index:
                        array_job, node = array_job_index[
                            sub_task_node.executable_sdk_object]
                        array_job.size += 1
                        array_job.min_successes = int(
                            math.ceil((1 - effective_failure_ratio) *
                                      array_job.size))
                    else:
                        array_job = self._create_array_job(
                            inputs_prefix=unique_node_id)
                        node = sub_task_node.assign_id_and_return(
                            unique_node_id)
                        array_job_index[
                            sub_task_node.executable_sdk_object] = (array_job,
                                                                    node)

                    node_index = _six.text_type(array_job.size - 1)
                    for k, node_output in _six.iteritems(
                            sub_task_node.outputs):
                        if not node_output.sdk_node.id:
                            node_output.sdk_node.assign_id_and_return(node.id)
                        node_output.var = "[{}].{}".format(
                            node_index, node_output.var)

                    # Upload inputs to working directory under /array_job.input_ref/<index>/inputs.pb
                    input_path = _os.path.join(node.id, node_index,
                                               _constants.INPUT_FILE_NAME)
                    generated_files[input_path] = _literal_models.LiteralMap(
                        literals={
                            binding.var: binding.binding.to_literal_model()
                            for binding in sub_task_node.inputs
                        })
                else:
                    node = sub_task_node.assign_id_and_return(unique_node_id)
                    tasks.add(sub_task_node.executable_sdk_object)
                    _append_node(generated_files, node, nodes, sub_task_node)

        # assign custom field to the ArrayJob properties computed.
        for task, (array_job, _) in _six.iteritems(array_job_index):
            # TODO: Reconstruct task template object instead of modifying an existing one?
            tasks.add(
                task.assign_custom_and_return(
                    array_job.to_dict()).assign_type_and_return(
                        _constants.SdkTaskType.CONTAINER_ARRAY_TASK))

        # min_successes is absolute; it is computed as (1 - allowed_failure_ratio) multiplied by the
        # total number of tasks to get an absolute count.
        nodes.extend([
            array_job_node for (_, array_job_node) in array_job_index.values()
        ])
        dynamic_job_spec = _dynamic_job.DynamicJobSpec(
            min_successes=len(nodes),
            tasks=list(tasks),
            nodes=nodes,
            outputs=output_bindings,
            subworkflows=list(sub_workflows))

        return dynamic_job_spec, generated_files
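
The generated_ids bookkeeping above (repeated in the next example) assigns each task id a zero-based occurrence count. A collections.defaultdict sketch that yields the same counts:

import collections

# Start every task id at -1 so its first occurrence yields count 0,
# matching the if/else bookkeeping in the loop above.
generated_ids = collections.defaultdict(lambda: -1)

def next_count(safe_task_id):
    generated_ids[safe_task_id] += 1
    return generated_ids[safe_task_id]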
Example #16
    def _produce_dynamic_job_spec(self, context, inputs):
        """
        Runs user code and produces future task nodes to run sub-tasks.
        :param context:
        :param flytekit.models.literals.LiteralMap inputs:
        :rtype: (_dynamic_job.DynamicJobSpec, dict[Text, flytekit.models.common.FlyteIdlEntity])
        """
        inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(
            inputs, {
                k: _type_helpers.get_sdk_type_from_literal_type(v.type)
                for k, v in _six.iteritems(self.interface.inputs)
            })
        outputs_dict = {
            name: PromiseOutputReference(
                _type_helpers.get_sdk_type_from_literal_type(variable.type))
            for name, variable in _six.iteritems(self.interface.outputs)
        }

        inputs_dict.update(outputs_dict)
        yielded_sub_tasks = [
            sub_task
            for sub_task in super(SdkDynamicTask, self)._execute_user_code(
                context, inputs_dict) or []
        ]

        upstream_nodes = list()
        output_bindings = [
            _literal_models.Binding(
                var=name,
                binding=_interface.BindingData.from_python_std(
                    b.sdk_type.to_flyte_literal_type(),
                    b.raw_value,
                    upstream_nodes=upstream_nodes))
            for name, b in _six.iteritems(outputs_dict)
        ]
        upstream_nodes = set(upstream_nodes)

        generated_files = {}
        # Keeping future-tasks in original order. We don't use upstream_nodes exclusively because the parent task can
        # yield sub-tasks that it never uses to produce final outputs but they need to execute nevertheless.
        array_job_index = {}
        tasks = []
        nodes = []
        visited_nodes = set()
        generated_ids = {}
        effective_failure_ratio = self._allowed_failure_ratio or 0.0
        for sub_task_node in _itertools.chain(yielded_sub_tasks,
                                              upstream_nodes):
            if sub_task_node in visited_nodes:
                continue
            visited_nodes.add(sub_task_node)

            # Generate an id that's unique in the document (if the same task is used multiple times with
            # different resources, executable_sdk_object.id will be the same but generated node_ids should
            # not be).
            safe_task_id = _six.text_type(
                sub_task_node.executable_sdk_object.id)
            if safe_task_id in generated_ids:
                new_count = generated_ids[safe_task_id] = generated_ids[safe_task_id] + 1
            else:
                new_count = generated_ids[safe_task_id] = 0
            unique_node_id = _dnsify("{}-{}".format(safe_task_id, new_count))

            # If the task can run as an array job, group its instances together. Otherwise, keep each invocation as a
            # separate node.
            if SdkDynamicTask._can_run_as_array(
                    sub_task_node.executable_sdk_object.type):
                if sub_task_node.executable_sdk_object in array_job_index:
                    array_job, node = array_job_index[
                        sub_task_node.executable_sdk_object]
                    array_job.size += 1
                    array_job.min_successes = int(
                        math.ceil(
                            (1 - effective_failure_ratio) * array_job.size))
                else:
                    array_job = self._create_array_job(
                        inputs_prefix=unique_node_id)
                    node = sub_task_node.assign_id_and_return(unique_node_id)
                    array_job_index[sub_task_node.executable_sdk_object] = (
                        array_job, node)

                node_index = _six.text_type(array_job.size - 1)
                for k, node_output in _six.iteritems(sub_task_node.outputs):
                    if not node_output.sdk_node.id:
                        node_output.sdk_node.assign_id_and_return(node.id)
                    node_output.var = "[{}].{}".format(node_index,
                                                       node_output.var)

                # Upload inputs to working directory under /array_job.input_ref/<index>/inputs.pb
                input_path = _os.path.join(node.id, node_index,
                                           _constants.INPUT_FILE_NAME)
                generated_files[input_path] = _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in sub_task_node.inputs
                    })
            else:
                node = sub_task_node.assign_id_and_return(unique_node_id)

                tasks.append(sub_task_node.executable_sdk_object)
                nodes.append(node)

                for k, node_output in _six.iteritems(sub_task_node.outputs):
                    if not node_output.sdk_node.id:
                        node_output.sdk_node.assign_id_and_return(node.id)

                # Upload inputs to working directory under /array_job.input_ref/inputs.pb
                input_path = _os.path.join(node.id, _constants.INPUT_FILE_NAME)
                generated_files[input_path] = _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in sub_task_node.inputs
                    })

        # assign custom field to the ArrayJob properties computed.
        for task, (array_job, _) in _six.iteritems(array_job_index):
            # TODO: Reconstruct task template object instead of modifying an existing one?
            tasks.append(
                task.assign_custom_and_return(
                    array_job.to_dict()).assign_type_and_return(
                        _constants.SdkTaskType.CONTAINER_ARRAY_TASK))

        # min_successes is absolute; it is computed as (1 - allowed_failure_ratio) multiplied by the
        # total number of tasks to get an absolute count.
        nodes.extend([
            array_job_node for (_, array_job_node) in array_job_index.values()
        ])
        dynamic_job_spec = _dynamic_job.DynamicJobSpec(
            min_successes=len(nodes),
            tasks=tasks,
            nodes=nodes,
            outputs=output_bindings,
            subworkflows=[])

        return dynamic_job_spec, generated_files
Example #17
def test_workflow_closure():
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")}, {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    b0 = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(
            primitive=_literals.Primitive(integer=5))))
    b1 = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    b2 = _literals.Binding(
        'c',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(name='node1',
                                           timeout=timedelta(seconds=10),
                                           retries=_literals.RetryStrategy(0))

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                              "1.0.0", "python"), timedelta(days=1),
        _literals.RetryStrategy(3), "0.1.1b0", "This is deprecated!")

    cpu_resource = _task.Resources.ResourceEntry(
        _task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project",
                               "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface, {
            'a': 1,
            'b': {
                'c': 2,
                'd': 3
            }
        },
        container=_task.Container("my_image", ["this", "is", "a", "cmd"],
                                  ["this", "is", "an", "arg"], resources, {},
                                  {}))

    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(id='my_node',
                          metadata=node_metadata,
                          inputs=[b0],
                          upstream_node_ids=[],
                          output_aliases=[],
                          task_node=task_node)

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project",
                                  "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
Example #18
def test_sdk_node_from_task():
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    n = _nodes.SdkNode(
        "n",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    assert n.id == "n"
    assert len(n.inputs) == 1
    assert n.inputs[0].var == "a"
    assert n.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n.outputs) == 1
    assert "b" in n.outputs
    assert n.outputs["b"].node_id == "n"
    assert n.outputs["b"].var == "b"
    assert n.outputs["b"].sdk_node == n
    assert n.outputs["b"].sdk_type == _types.Types.Integer
    assert n.metadata.name == "abc"
    assert n.metadata.retries.retries == 3
    assert n.metadata.interruptible is None
    assert len(n.upstream_nodes) == 0
    assert len(n.upstream_node_ids) == 0
    assert len(n.output_aliases) == 0

    n2 = _nodes.SdkNode(
        "n2",
        [n],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), n.outputs.b),
            )
        ],
        _core_workflow_models.NodeMetadata("abc2",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    assert n2.id == "n2"
    assert len(n2.inputs) == 1
    assert n2.inputs[0].var == "a"
    assert n2.inputs[0].binding.promise.var == "b"
    assert n2.inputs[0].binding.promise.node_id == "n"
    assert len(n2.outputs) == 1
    assert "b" in n2.outputs
    assert n2.outputs["b"].node_id == "n2"
    assert n2.outputs["b"].var == "b"
    assert n2.outputs["b"].sdk_node == n2
    assert n2.outputs["b"].sdk_type == _types.Types.Integer
    assert n2.metadata.name == "abc2"
    assert n2.metadata.retries.retries == 3
    assert "n" in n2.upstream_node_ids
    assert n in n2.upstream_nodes
    assert len(n2.upstream_nodes) == 1
    assert len(n2.upstream_node_ids) == 1
    assert len(n2.output_aliases) == 0

    # Test right shift operator and late binding
    n3 = _nodes.SdkNode(
        "n3",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc3",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )
    n2 >> n3
    n >> n2 >> n3
    n3 << n2
    n3 << n2 << n

    assert n3.id == "n3"
    assert len(n3.inputs) == 1
    assert n3.inputs[0].var == "a"
    assert n3.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n3.outputs) == 1
    assert "b" in n3.outputs
    assert n3.outputs["b"].node_id == "n3"
    assert n3.outputs["b"].var == "b"
    assert n3.outputs["b"].sdk_node == n3
    assert n3.outputs["b"].sdk_type == _types.Types.Integer
    assert n3.metadata.name == "abc3"
    assert n3.metadata.retries.retries == 3
    assert "n2" in n3.upstream_node_ids
    assert n2 in n3.upstream_nodes
    assert len(n3.upstream_nodes) == 1
    assert len(n3.upstream_node_ids) == 1
    assert len(n3.output_aliases) == 0

    # Test left shift operator and late binding
    n4 = _nodes.SdkNode(
        "n4",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc4",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    n4 << n3

    # Test that implicit dependencies don't cause direct dependencies
    n4 << n3 << n2 << n
    n >> n2 >> n3 >> n4

    assert n4.id == "n4"
    assert len(n4.inputs) == 1
    assert n4.inputs[0].var == "a"
    assert n4.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n4.outputs) == 1
    assert "b" in n4.outputs
    assert n4.outputs["b"].node_id == "n4"
    assert n4.outputs["b"].var == "b"
    assert n4.outputs["b"].sdk_node == n4
    assert n4.outputs["b"].sdk_type == _types.Types.Integer
    assert n4.metadata.name == "abc4"
    assert n4.metadata.retries.retries == 3
    assert "n3" in n4.upstream_node_ids
    assert n3 in n4.upstream_nodes
    assert len(n4.upstream_nodes) == 1
    assert len(n4.upstream_node_ids) == 1
    assert len(n4.output_aliases) == 0

    # Add another dependency
    n4 << n2
    assert "n3" in n4.upstream_node_ids
    assert n3 in n4.upstream_nodes
    assert "n2" in n4.upstream_node_ids
    assert n2 in n4.upstream_nodes
    assert len(n4.upstream_nodes) == 2
    assert len(n4.upstream_node_ids) == 2