def __init__(
    self,
    project: str,
    domain: str,
    name: str,
    version: str,
    inputs: Dict[str, Type],
    outputs: Dict[str, Type],
):
    """Create a task reference identified by project, domain, name and version.

    :param project: Project of the referenced task.
    :param domain: Domain of the referenced task.
    :param name: Name of the referenced task.
    :param version: Version of the referenced task.
    :param inputs: Mapping of input name to Python type.
    :param outputs: Mapping of output name to Python type.
    """
    # Bundle the four identifying coordinates up front so the parent call
    # reads as (reference, inputs, outputs).
    task_ref = TaskReference(project, domain, name, version)
    super().__init__(task_ref, inputs, outputs)

    # Reference tasks shouldn't really run the parent constructor, but the
    # parent constructor is what sets the task resolver, so clear it again.
    self._task_resolver = None
def test_ref_plain_two_outputs():
    """A two-output reference compiles to one node and can be mocked in a workflow."""
    r1 = ReferenceEntity(
        TaskReference("proj", "domain", "some.name", "abc"),
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    # Call the reference under a fresh compilation state; both outputs are
    # expected to come from the same single node.
    current_ctx = context_manager.FlyteContext.current_context()
    compile_builder = current_ctx.with_new_compilation_state()
    with context_manager.FlyteContextManager.with_context(compile_builder):
        x_promise, y_promise = r1(a="five", b=6)
        # Misnomer caveat: .ref.node holds core.Nodes, not SdkNodes.
        assert x_promise.ref.node is y_promise.ref.node
        assert x_promise.var == "x"
        assert y_promise.var == "y"
        assert x_promise.ref.node_id == "n0"
        assert len(x_promise.ref.node.bindings) == 2

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1(a: str, b: int) -> str:
        x_val, y_val = r1(a=a, b=b)
        return t2(q=x_val, r=y_val)

    @patch(r1)
    def run_with_mock(mocked_ref):
        # The patched reference yields a fixed (x, y) tuple.
        mocked_ref.return_value = (False, 30)
        outcome = wf1(a="hello", b=10)
        assert outcome == "q: False r: 30"

    run_with_mock()
def test_ref_plain_no_outputs():
    """An output-less reference raises when unmocked and works when patched in workflows."""
    r1 = ReferenceEntity(
        TaskReference("proj", "domain", "some.name", "abc"),
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    # Reference entities should always raise an exception when not mocked out.
    with pytest.raises(Exception) as e:
        r1(a="fdsa", b=3)
    # Inspect the raised exception itself (e.value), not the string form of
    # pytest's ExceptionInfo wrapper.
    assert "You must mock this out" in str(e.value)

    @workflow
    def wf1(a: str, b: int):
        r1(a=a, b=b)

    @patch(r1)
    def inner_test(ref_mock):
        ref_mock.return_value = None
        x = wf1(a="fdsa", b=3)
        assert x is None

    inner_test()

    # Field-list functional syntax; the keyword-argument form of
    # typing.NamedTuple is deprecated since Python 3.13.
    nt1 = typing.NamedTuple("DummyNamedTuple", [("t1_int_output", int), ("c", str)])

    @task
    def t1(a: int) -> nt1:
        bumped = a + 2
        return bumped, "world-" + str(bumped)

    @workflow
    def wf2(a: int):
        t1_int, c = t1(a=a)
        r1(a=c, b=t1_int)

    @patch(r1)
    def inner_test2(ref_mock):
        ref_mock.return_value = None
        x = wf2(a=3)
        assert x is None
        # t1(a=3) yields (5, "world-5"), which wf2 forwards into the reference.
        ref_mock.assert_called_with(a="world-5", b=5)

    inner_test2()

    # The reference call (second node) must list t1's node (first node) upstream.
    node_r1 = wf2._nodes[1]
    assert node_r1._upstream_nodes[0] is wf2._nodes[0]
def __init__(
    self,
    project: str,
    domain: str,
    name: str,
    version: str,
    inputs: Dict[str, Type],
    outputs: Dict[str, Type],
):
    """Create a task reference identified by project, domain, name and version.

    :param project: Project of the referenced task.
    :param domain: Domain of the referenced task.
    :param name: Name of the referenced task.
    :param version: Version of the referenced task.
    :param inputs: Mapping of input name to Python type.
    :param outputs: Mapping of output name to Python type.
    """
    reference = TaskReference(project, domain, name, version)
    # NOTE(review): this variant does not reset any resolver state the parent
    # constructor may set — confirm that is intended for this class.
    super().__init__(reference, inputs, outputs)