def test_get_task_execution_outputs(mock_client_factory, execution_data_locations):
    """Check FlyteTaskExecution.get_outputs against a mocked client.

    The mocked client's ``get_task_execution_data`` returns a
    ``TaskExecutionGetDataResponse`` built positionally from
    ``execution_data_locations[0]`` and ``execution_data_locations[1]``
    (presumably the input and output blob locations -- confirm against the
    fixture). The test asserts that the outputs contain a single literal
    ``b == 2`` and that the client was queried exactly once with the task
    execution's identifier.
    """

    def _task_exec_id():
        # Build a fresh identifier on each call, so the final call assertion
        # compares by value rather than relying on object identity.
        return identifier.TaskExecutionIdentifier(
            identifier.Identifier(
                identifier.ResourceType.TASK,
                "project",
                "domain",
                "task-name",
                "version",
            ),
            identifier.NodeExecutionIdentifier(
                "node-a",
                identifier.WorkflowExecutionIdentifier(
                    "project",
                    "domain",
                    "name",
                ),
            ),
            0,
        )

    mock_client = MagicMock()
    mock_client.get_task_execution_data = MagicMock(
        return_value=_execution_models.TaskExecutionGetDataResponse(
            execution_data_locations[0], execution_data_locations[1]
        )
    )
    mock_client_factory.return_value = mock_client

    m = MagicMock()
    # Set `id` on the mock's type: PropertyMock only takes effect as a class attribute.
    type(m).id = PropertyMock(return_value=_task_exec_id())

    # These are the task's *outputs*; the previous name `inputs` was misleading.
    outputs = engine.FlyteTaskExecution(m).get_outputs()
    assert len(outputs.literals) == 1
    assert outputs.literals["b"].scalar.primitive.integer == 2

    mock_client.get_task_execution_data.assert_called_once_with(_task_exec_id())
def test_task_execution_data_response():
    """Round-trip a TaskExecutionGetDataResponse through its flyte IDL form.

    The response is built from two UrlBlob values, serialized with
    ``to_flyte_idl`` and rebuilt with ``from_flyte_idl``. The rebuilt object
    must equal the original and expose the same ``inputs``/``outputs`` blobs.
    """
    in_blob = _common_models.UrlBlob("in", 1)
    out_blob = _common_models.UrlBlob("out", 2)
    original = _execution.TaskExecutionGetDataResponse(in_blob, out_blob)

    idl_message = original.to_flyte_idl()
    restored = _execution.TaskExecutionGetDataResponse.from_flyte_idl(idl_message)

    assert original == restored
    assert restored.inputs == in_blob
    assert restored.outputs == out_blob
def test_get_full_task_execution_outputs(mock_client_factory):
    """Check FlyteTaskExecution.get_outputs when the data response has literal maps.

    The mocked client's ``get_task_execution_data`` returns a
    ``TaskExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP)``,
    meaning no blob locations and only the literal maps. The test asserts
    that the outputs contain a single literal ``b == 2`` and that the client
    was queried exactly once with the task execution's identifier.
    """

    def _task_exec_id():
        # Fresh instance per call, so the call assertion compares by value.
        wf_exec = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
        node_exec = identifier.NodeExecutionIdentifier("node-a", wf_exec)
        task_id = identifier.Identifier(
            identifier.ResourceType.TASK,
            "project",
            "domain",
            "task-name",
            "version",
        )
        return identifier.TaskExecutionIdentifier(task_id, node_exec, 0)

    client = MagicMock()
    data_response = _execution_models.TaskExecutionGetDataResponse(
        None, None, _INPUT_MAP, _OUTPUT_MAP
    )
    client.get_task_execution_data = MagicMock(return_value=data_response)
    mock_client_factory.return_value = client

    task_exec = MagicMock()
    # PropertyMock must be installed on the mock's type to act as a property.
    type(task_exec).id = PropertyMock(return_value=_task_exec_id())

    outputs = engine.FlyteTaskExecution(task_exec).get_outputs()
    assert len(outputs.literals) == 1
    assert outputs.literals["b"].scalar.primitive.integer == 2

    client.get_task_execution_data.assert_called_once_with(_task_exec_id())