def test_get_task_execution_outputs(mock_client_factory, execution_data_locations):
    """FlyteTaskExecution.get_outputs() decodes the data returned by the client.

    The mocked client returns a ``TaskExecutionGetDataResponse`` built from
    the first two entries of ``execution_data_locations`` (presumably a pytest
    fixture -- confirm). The test checks two things: the decoded literal map
    holds a single entry ``b`` with integer value 2, and the client was queried
    exactly once with the execution's identifier.
    """

    def _make_task_exec_id():
        # Return a fresh, equal identifier on each call. The final assertion
        # therefore compares by TaskExecutionIdentifier equality rather than
        # object identity, as the original duplicated construction did.
        return identifier.TaskExecutionIdentifier(
            identifier.Identifier(
                identifier.ResourceType.TASK, "project", "domain", "task-name", "version"
            ),
            identifier.NodeExecutionIdentifier(
                "node-a",
                identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
            ),
            0,
        )

    mock_client = MagicMock()
    mock_client.get_task_execution_data = MagicMock(
        return_value=_execution_models.TaskExecutionGetDataResponse(
            execution_data_locations[0], execution_data_locations[1]
        )
    )
    mock_client_factory.return_value = mock_client

    m = MagicMock()
    type(m).id = PropertyMock(return_value=_make_task_exec_id())

    # This value is what get_outputs() returns, so it is named ``outputs``
    # (previously misnamed ``inputs``).
    outputs = engine.FlyteTaskExecution(m).get_outputs()
    assert len(outputs.literals) == 1
    assert outputs.literals["b"].scalar.primitive.integer == 2
    mock_client.get_task_execution_data.assert_called_once_with(_make_task_exec_id())
def test_get_full_task_execution_outputs(mock_client_factory):
    """get_outputs() works when the data response carries literal maps directly.

    The mocked client answers with a ``TaskExecutionGetDataResponse`` whose
    first two arguments are ``None`` and whose last two are ``_INPUT_MAP`` and
    ``_OUTPUT_MAP``. The test verifies the decoded output literals and that the
    client was asked once for the execution's identifier.
    """

    def build_exec_id():
        # Fresh object per call; the final assertion relies on equality.
        task_ref = identifier.Identifier(
            identifier.ResourceType.TASK, "project", "domain", "task-name", "version"
        )
        node_ref = identifier.NodeExecutionIdentifier(
            "node-a",
            identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
        )
        return identifier.TaskExecutionIdentifier(task_ref, node_ref, 0)

    client = MagicMock()
    response = _execution_models.TaskExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP)
    client.get_task_execution_data = MagicMock(return_value=response)
    mock_client_factory.return_value = client

    execution = MagicMock()
    type(execution).id = PropertyMock(return_value=build_exec_id())

    result = engine.FlyteTaskExecution(execution).get_outputs()
    assert len(result.literals) == 1
    assert result.literals["b"].scalar.primitive.integer == 2
    client.get_task_execution_data.assert_called_once_with(build_exec_id())
def test_task_node_metadata():
    """TaskNodeMetadata keeps its fields and survives a to/from IDL round trip."""
    source_exec = identifier.TaskExecutionIdentifier(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        identifier.NodeExecutionIdentifier(
            "node_id",
            identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
        ),
        3,
    )
    dataset = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "t1", "abcdef")
    expected_key = catalog.CatalogMetadata(
        dataset_id=dataset,
        artifact_tag=catalog.CatalogArtifactTag("my-artifact-id", "some name"),
        source_task_execution=source_exec,
    )

    metadata = node_execution_models.TaskNodeMetadata(cache_status=0, catalog_key=expected_key)
    assert metadata.cache_status == 0
    assert metadata.catalog_key == expected_key

    # Serialising and re-parsing must yield an equal object.
    round_tripped = node_execution_models.TaskNodeMetadata.from_flyte_idl(metadata.to_flyte_idl())
    assert round_tripped == metadata
def test_task_execution_identifier():
    """TaskExecutionIdentifier exposes its components and survives an IDL round trip."""
    task_ref = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    node_ref = identifier.NodeExecutionIdentifier(
        "node_id",
        identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
    )

    def assert_components(te_id):
        # The built and the round-tripped identifier must expose the same parts.
        assert te_id.retry_attempt == 3
        assert te_id.task_id == task_ref
        assert te_id.node_execution_id == node_ref

    built = identifier.TaskExecutionIdentifier(task_ref, node_ref, 3)
    assert_components(built)

    restored = identifier.TaskExecutionIdentifier.from_flyte_idl(built.to_flyte_idl())
    assert restored == built
    assert_components(restored)
def test_task_execution_identifier_urn():
    """Check URN parsing, model promotion and str() rendering for TaskExecutionIdentifier.

    The ``_identifier`` TaskExecutionIdentifier must equal:
    - the one parsed from its URN,
    - the one promoted from the core model with the same parts,
    and it must render back to that URN.

    This function was previously also called ``test_task_execution_identifier``.
    Defining that name a second time at module level rebinds it, so pytest would
    collect only one of the two tests.
    """
    urn = "te:project:domain:name:n0:project:domain:name:version:0"
    task_id = _identifier.Identifier(
        _core_identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    node_execution_id = _core_identifier.NodeExecutionIdentifier(
        node_id="n0",
        execution_id=_core_identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
    )
    # Named task_exec_id rather than ``identifier`` so the local does not shadow
    # the ``identifier`` module name.
    task_exec_id = _identifier.TaskExecutionIdentifier(
        task_id=task_id,
        node_execution_id=node_execution_id,
        retry_attempt=0,
    )

    assert task_exec_id == _identifier.TaskExecutionIdentifier.from_urn(urn)
    assert task_exec_id == _identifier.TaskExecutionIdentifier.promote_from_model(
        _core_identifier.TaskExecutionIdentifier(task_id, node_execution_id, 0)
    )
    assert str(task_exec_id) == urn