Exemplo n.º 1
0
    def __call__(self, *args, **input_map):
        """
        Create a workflow node that wraps this launch plan.

        The default value of every non-required launch plan parameter is used as a starting point; any keyword
        inputs supplied by the caller then override those defaults before bindings are created through this
        object's interface.

        :param list[T] args: Do not specify.  Only keyword arguments are supported by this function.
        :param dict[Text,T] input_map: Map of inputs.  Can be statically defined or OutputReference links.
        :raises _user_exceptions.FlyteAssertion: If any positional argument is supplied.
        :rtype: flytekit.common.nodes.SdkNode
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                "When adding a launchplan as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args)))

        # Seed with the launch plan's optional-parameter defaults, then let caller-supplied values win.
        merged_inputs = {}
        for param_name, param in _six.iteritems(self.default_inputs.parameters):
            if not param.required:
                merged_inputs[param_name] = param.sdk_default
        merged_inputs.update(input_map)

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(merged_inputs)

        node_metadata = _workflow_models.NodeMetadata(
            "", _datetime.timedelta(), _literal_models.RetryStrategy(0))

        # Order the bindings by variable name so the resulting node is laid out deterministically.
        ordered_bindings = list(bindings)
        ordered_bindings.sort(key=lambda binding: binding.var)

        return _nodes.SdkNode(
            id=None,
            metadata=node_metadata,
            bindings=ordered_bindings,
            upstream_nodes=upstream_nodes,
            sdk_launch_plan=self,
        )
Exemplo n.º 2
0
def test_non_system_nodes():
    """Check that get_non_system_nodes keeps the "n1" node and drops the node whose id is "start-node"."""
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    # Give the task an explicit identifier before it is used to create nodes.
    task_id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain", "my_task", "version"
    )
    my_task._id = task_id

    workflow_input = promise.Input("required", primitives.Integer)
    user_node = my_task(a=workflow_input).assign_id_and_return("n1")

    # Hand-built node carrying the "start-node" id, bound to the literal integer 3.
    literal_three = interface.BindingData.from_python_std(
        _types.Types.Integer.to_flyte_literal_type(), 3
    )
    start_node = nodes.SdkNode(
        "start-node",
        [],
        [_literals.Binding("a", literal_three)],
        None,
        sdk_task=my_task,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    remaining = workflow.SdkWorkflow.get_non_system_nodes([user_node, start_node])
    assert len(remaining) == 1
    assert remaining[0].id == "n1"
Exemplo n.º 3
0
    def promote_from_model(cls, base_model):
        """
        Build an instance of this class from a workflow template model.

        Every node in ``base_model.nodes`` is wrapped in an SdkNode keyed by its id, created with an empty
        upstream list and with no task, workflow, or branch attached.  Each node's upstream list is then filled
        in place from its ``upstream_node_ids``.

        NOTE(review): ``upstream_node_ids`` is read from the freshly created SdkNode, which was constructed with an
        empty upstream list.  Confirm that this property can yield ids at this point; if it is derived from
        ``upstream_nodes``, the wiring step below does nothing.

        :param flytekit.models.core.workflow.WorkflowTemplate base_model:
        :rtype: SdkWorkflow
        """
        node_map = {}
        for model_node in base_model.nodes:
            # TODO: Hydrate these objects by reference from the engine.
            node_map[model_node.id] = _nodes.SdkNode(
                model_node.id,
                [],
                model_node.inputs,
                model_node.metadata,
                sdk_task=None,
                sdk_workflow=None,
                sdk_branch=None,
            )

        # Replace each node's upstream list contents in place with the matching SdkNode objects.
        for sdk_node in _six.itervalues(node_map):
            resolved_upstream = [node_map[upstream_id] for upstream_id in sdk_node.upstream_node_ids]
            sdk_node.upstream_nodes[:] = resolved_upstream

        typed_interface = _interface.TypedInterface.promote_from_model(base_model.interface)
        return cls(
            metadata=base_model.metadata,
            interface=typed_interface,
            nodes=list(node_map.values()),
            outputs=base_model.outputs,
            failure_node=None,  # TODO: Implement failure node
        )
Exemplo n.º 4
0
    def __call__(self, *args, **input_map):
        """
        Create a node that embeds this workflow inside another workflow.

        Defaults of the non-required entries in ``self.user_inputs`` are collected first and then overridden by
        any keyword inputs the caller passes.

        :param list[T] args: Do not specify.  Only keyword arguments are supported by this function.
        :param dict[Text,T] input_map: Map of inputs for the sub-workflow node.
        :raises _user_exceptions.FlyteAssertion: If any positional argument is supplied.
        :rtype: flytekit.common.nodes.SdkNode
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                "When adding a workflow as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args))
            )

        # Start from the defaults declared on the optional inputs; caller-supplied values take precedence.
        compiled_inputs = {}
        for user_input in self.user_inputs:
            if user_input.sdk_required:
                continue
            compiled_inputs[user_input.name] = user_input.sdk_default
        compiled_inputs.update(input_map)

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(compiled_inputs)

        placeholder_metadata = _workflow_models.NodeMetadata(
            "placeholder", _datetime.timedelta(), _literal_models.RetryStrategy(0)
        )

        # Bindings are ordered by variable name.
        ordered_bindings = list(bindings)
        ordered_bindings.sort(key=lambda binding: binding.var)

        return _nodes.SdkNode(
            id=None,
            metadata=placeholder_metadata,
            upstream_nodes=upstream_nodes,
            bindings=ordered_bindings,
            sdk_workflow=self
        )
Exemplo n.º 5
0
    def __call__(self, *args, **input_map):
        """
        Create a workflow node that runs this task.

        The node metadata copies the timeout, retries, and interruptible settings from this task's own metadata.

        :param list[T] args: Do not specify.  Only keyword arguments are supported by this function.
        :param dict[str, T] input_map: Map of inputs.  Can be statically defined or OutputReference links.
        :raises _user_exceptions.FlyteAssertion: If any positional argument is supplied.
        :rtype: flytekit.common.nodes.SdkNode
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                "When adding a task as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args)))

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(input_map)

        # TODO: Remove DEADBEEF
        task_metadata = self.metadata
        node_metadata = _workflow_model.NodeMetadata(
            "DEADBEEF",
            task_metadata.timeout,
            task_metadata.retries,
            task_metadata.interruptible,
        )

        # Bindings are ordered by variable name.
        ordered_bindings = list(bindings)
        ordered_bindings.sort(key=lambda binding: binding.var)

        # Note: this function is not overloaded at the SdkRunnableTask layer, so 'self' is sometimes an object
        # that can be executed locally and sometimes one that cannot (e.g. a pure SdkTask object fetched from
        # Admin).
        return _nodes.SdkNode(
            id=None,
            metadata=node_metadata,
            bindings=ordered_bindings,
            upstream_nodes=upstream_nodes,
            sdk_task=self,
        )
Exemplo n.º 6
0
def test_blah():
    """Check that get_non_system_nodes returns only the 'n1' node when also given a 'start-node' node."""
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    # Assign an explicit identifier to the task before creating nodes from it.
    my_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK,
        'project',
        'domain',
        'my_task',
        'version',
    )

    task_node = my_task(a=promise.Input('required', primitives.Integer)).assign_id_and_return('n1')

    # Manually assembled node with the 'start-node' id and a literal binding of 3 for input 'a'.
    integer_literal_type = _types.Types.Integer.to_flyte_literal_type()
    start_binding = _literals.Binding(
        'a',
        interface.BindingData.from_python_std(integer_literal_type, 3)
    )
    start_node = nodes.SdkNode(
        'start-node',
        [],
        [start_binding],
        None,
        sdk_task=my_task,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None
    )

    filtered = workflow.SdkWorkflow.get_non_system_nodes([task_node, start_node])
    assert len(filtered) == 1
    assert filtered[0].id == 'n1'
Exemplo n.º 7
0
def _create_hive_job_node(name, hive_job, metadata):
    """
    Wrap a Hive job in a standalone node with a freshly generated UUID as its id.

    The node has no upstream nodes and no bindings.  Its node metadata takes the given name and the task
    metadata's timeout, and uses a retry strategy of 0.

    :param Text name:
    :param _qubole.QuboleHiveJob hive_job: Hive job spec
    :param flytekit.models.task.TaskMetadata metadata: This contains information needed at runtime to determine
        behavior such as whether or not outputs are discoverable, timeouts, and retries.
    :rtype: _nodes.SdkNode
    """
    node_id = _six.text_type(_uuid.uuid4())
    node_metadata = _workflow_model.NodeMetadata(
        name,
        metadata.timeout,
        _literal_models.RetryStrategy(0),
    )
    hive_task = SdkHiveJob(hive_job, metadata)
    return _nodes.SdkNode(
        id=node_id,
        upstream_nodes=[],
        bindings=[],
        metadata=node_metadata,
        sdk_task=hive_task,
    )
Exemplo n.º 8
0
def test_sdk_node_from_lp():
    """Verify the attributes exposed by an SdkNode that is built directly around a launch plan."""
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    @_workflow.workflow_class
    class test_workflow(object):
        a = _workflow.Input(_types.Types.Integer)
        test = testy_test(a=a)
        b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer)

    launch_plan = test_workflow.create_launch_plan()

    # Bind input "a" to the literal integer 3 and attach explicit node metadata.
    bound_a = _interface.BindingData.from_python_std(
        _types.Types.Integer.to_flyte_literal_type(), 3
    )
    node_metadata = _core_workflow_models.NodeMetadata(
        "abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)
    )
    lp_node = _nodes.SdkNode(
        "n1",
        [],
        [_literals.Binding("a", bound_a)],
        node_metadata,
        sdk_launch_plan=launch_plan,
    )

    assert lp_node.id == "n1"

    # Inputs: exactly one, carrying the literal binding.
    assert len(lp_node.inputs) == 1
    first_input = lp_node.inputs[0]
    assert first_input.var == "a"
    assert first_input.binding.scalar.primitive.integer == 3

    # Outputs: the single output "b" points back at this node.
    assert len(lp_node.outputs) == 1
    assert "b" in lp_node.outputs
    output_b = lp_node.outputs["b"]
    assert output_b.node_id == "n1"
    assert output_b.var == "b"
    assert output_b.sdk_node == lp_node
    assert output_b.sdk_type == _types.Types.Integer

    # Metadata is passed through unchanged.
    assert lp_node.metadata.name == "abc"
    assert lp_node.metadata.retries.retries == 3

    # No upstream nodes or output aliases were given.
    assert len(lp_node.upstream_nodes) == 0
    assert len(lp_node.upstream_node_ids) == 0
    assert len(lp_node.output_aliases) == 0
Exemplo n.º 9
0
    def __call__(self, *args, **input_map):
        """
        Create a node that embeds this workflow inside another workflow.

        The keyword inputs are passed to this object's interface as given to produce the bindings.

        :param list[T] args: Do not specify.  Only keyword arguments are supported by this function.
        :param dict[Text,T] input_map: Map of inputs for the sub-workflow node.
        :raises _user_exceptions.FlyteAssertion: If any positional argument is supplied.
        :rtype: flytekit.common.nodes.SdkNode
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                "When adding a workflow as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args)))

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(input_map)

        placeholder_metadata = _workflow_models.NodeMetadata(
            "placeholder",
            _datetime.timedelta(),
            _literal_models.RetryStrategy(0),
        )

        # Bindings are ordered by variable name.
        ordered_bindings = list(bindings)
        ordered_bindings.sort(key=lambda binding: binding.var)

        return _nodes.SdkNode(
            id=None,
            metadata=placeholder_metadata,
            upstream_nodes=upstream_nodes,
            bindings=ordered_bindings,
            sdk_workflow=self,
        )
Exemplo n.º 10
0
    def __call__(self, *args, **input_map):
        """
        Create a workflow node that runs this task.

        The node metadata copies the timeout and retries from this task's own metadata.

        :param list[T] args: Do not specify.  Only keyword arguments are supported by this function.
        :param dict[str, T] input_map: Map of inputs.  Can be statically defined or OutputReference links.
        :raises _user_exceptions.FlyteAssertion: If any positional argument is supplied.
        :rtype: flytekit.common.nodes.SdkNode
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                "When adding a task as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args)))

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(input_map)

        # TODO: Remove DEADBEEF
        node_metadata = _workflow_model.NodeMetadata(
            "DEADBEEF",
            self.metadata.timeout,
            self.metadata.retries,
        )

        # Bindings are ordered by variable name.
        ordered_bindings = list(bindings)
        ordered_bindings.sort(key=lambda binding: binding.var)

        return _nodes.SdkNode(
            id=None,
            metadata=node_metadata,
            bindings=ordered_bindings,
            upstream_nodes=upstream_nodes,
            sdk_task=self,
        )
Exemplo n.º 11
0
def test_sdk_node_from_task():
    """
    Exercise SdkNode construction around a task and the ``>>`` / ``<<`` dependency operators.

    Four nodes (n, n2, n3, n4) are built from the same task.  The test checks each node's id, inputs, outputs,
    and metadata, and verifies how upstream nodes are recorded when a dependency comes from a promise binding
    (n2 on n) or from the shift operators (n3 on n2, n4 on n3 and later n2).
    """
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    # n: no upstream nodes, input "a" bound to the literal integer 3.
    n = _nodes.SdkNode(
        "n",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    assert n.id == "n"
    assert len(n.inputs) == 1
    assert n.inputs[0].var == "a"
    assert n.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n.outputs) == 1
    assert "b" in n.outputs
    assert n.outputs["b"].node_id == "n"
    assert n.outputs["b"].var == "b"
    assert n.outputs["b"].sdk_node == n
    assert n.outputs["b"].sdk_type == _types.Types.Integer
    assert n.metadata.name == "abc"
    assert n.metadata.retries.retries == 3
    # interruptible was not passed to NodeMetadata above, and is expected to read back as None.
    assert n.metadata.interruptible is None
    assert len(n.upstream_nodes) == 0
    assert len(n.upstream_node_ids) == 0
    assert len(n.output_aliases) == 0

    # n2: lists n as its upstream node and binds input "a" to n's output "b" (a promise, not a literal).
    n2 = _nodes.SdkNode(
        "n2",
        [n],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), n.outputs.b),
            )
        ],
        _core_workflow_models.NodeMetadata("abc2",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    assert n2.id == "n2"
    assert len(n2.inputs) == 1
    assert n2.inputs[0].var == "a"
    assert n2.inputs[0].binding.promise.var == "b"
    assert n2.inputs[0].binding.promise.node_id == "n"
    assert len(n2.outputs) == 1
    assert "b" in n2.outputs
    assert n2.outputs["b"].node_id == "n2"
    assert n2.outputs["b"].var == "b"
    assert n2.outputs["b"].sdk_node == n2
    assert n2.outputs["b"].sdk_type == _types.Types.Integer
    assert n2.metadata.name == "abc2"
    assert n2.metadata.retries.retries == 3
    assert "n" in n2.upstream_node_ids
    assert n in n2.upstream_nodes
    assert len(n2.upstream_nodes) == 1
    assert len(n2.upstream_node_ids) == 1
    assert len(n2.output_aliases) == 0

    # Test right shift operator and late binding
    n3 = _nodes.SdkNode(
        "n3",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc3",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )
    # The same n2 -> n3 dependency is declared several times, in both directions and as part of longer chains;
    # the assertions below expect n3 to end up with exactly one upstream node (n2).
    n2 >> n3
    n >> n2 >> n3
    n3 << n2
    n3 << n2 << n

    assert n3.id == "n3"
    assert len(n3.inputs) == 1
    assert n3.inputs[0].var == "a"
    assert n3.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n3.outputs) == 1
    assert "b" in n3.outputs
    assert n3.outputs["b"].node_id == "n3"
    assert n3.outputs["b"].var == "b"
    assert n3.outputs["b"].sdk_node == n3
    assert n3.outputs["b"].sdk_type == _types.Types.Integer
    assert n3.metadata.name == "abc3"
    assert n3.metadata.retries.retries == 3
    assert "n2" in n3.upstream_node_ids
    assert n2 in n3.upstream_nodes
    assert len(n3.upstream_nodes) == 1
    assert len(n3.upstream_node_ids) == 1
    assert len(n3.output_aliases) == 0

    # Test left shift operator and late binding
    n4 = _nodes.SdkNode(
        "n4",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc4",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    n4 << n3

    # Test that implicit dependencies don't cause direct dependencies
    n4 << n3 << n2 << n
    n >> n2 >> n3 >> n4

    # After the chains above, n4 is expected to depend directly on n3 only.
    assert n4.id == "n4"
    assert len(n4.inputs) == 1
    assert n4.inputs[0].var == "a"
    assert n4.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n4.outputs) == 1
    assert "b" in n4.outputs
    assert n4.outputs["b"].node_id == "n4"
    assert n4.outputs["b"].var == "b"
    assert n4.outputs["b"].sdk_node == n4
    assert n4.outputs["b"].sdk_type == _types.Types.Integer
    assert n4.metadata.name == "abc4"
    assert n4.metadata.retries.retries == 3
    assert "n3" in n4.upstream_node_ids
    assert n3 in n4.upstream_nodes
    assert len(n4.upstream_nodes) == 1
    assert len(n4.upstream_node_ids) == 1
    assert len(n4.output_aliases) == 0

    # Add another dependency
    # Declaring n2 explicitly as upstream of n4 is expected to add it alongside n3.
    n4 << n2
    assert "n3" in n4.upstream_node_ids
    assert n3 in n4.upstream_nodes
    assert "n2" in n4.upstream_node_ids
    assert n2 in n4.upstream_nodes
    assert len(n4.upstream_nodes) == 2
    assert len(n4.upstream_node_ids) == 2