Example #1
0
def test_get_task_execution_outputs(mock_client_factory,
                                    execution_data_locations):
    """FlyteTaskExecution.get_outputs() returns the outputs served by the client.

    The client produced by ``mock_client_factory`` is replaced with a mock whose
    ``get_task_execution_data`` returns a ``TaskExecutionGetDataResponse`` built
    positionally from ``execution_data_locations[0]`` and ``[1]`` (presumably
    the input and output data locations -- confirm against the fixture).

    The test asserts three things:

    * The outputs contain exactly one literal.
    * The single literal, ``b``, holds the integer 2.
    * The client was called exactly once with the task execution identifier.
    """

    def _task_execution_id():
        # Build a fresh, value-equal identifier on every call so the final
        # assertion checks equality of identifiers, not object identity.
        return identifier.TaskExecutionIdentifier(
            identifier.Identifier(identifier.ResourceType.TASK, 'project',
                                  'domain', 'task-name', 'version'),
            identifier.NodeExecutionIdentifier(
                "node-a",
                identifier.WorkflowExecutionIdentifier(
                    "project",
                    "domain",
                    "name",
                )), 0)

    mock_client = MagicMock()
    mock_client.get_task_execution_data = MagicMock(
        return_value=_execution_models.TaskExecutionGetDataResponse(
            execution_data_locations[0], execution_data_locations[1]))
    mock_client_factory.return_value = mock_client

    m = MagicMock()
    # ``id`` is exposed as a property on the mock's (per-instance) type.
    type(m).id = PropertyMock(return_value=_task_execution_id())

    # This is the result of get_outputs(), so name it accordingly.
    outputs = engine.FlyteTaskExecution(m).get_outputs()
    assert len(outputs.literals) == 1
    assert outputs.literals['b'].scalar.primitive.integer == 2
    mock_client.get_task_execution_data.assert_called_once_with(
        _task_execution_id())
Example #2
0
def test_get_full_task_execution_outputs(mock_client_factory):
    """Check get_outputs() when the data response carries literal maps inline.

    The mocked client's ``get_task_execution_data`` returns a
    ``TaskExecutionGetDataResponse`` built from ``None, None`` followed by
    ``_INPUT_MAP`` and ``_OUTPUT_MAP``.

    The outputs reported by ``FlyteTaskExecution`` must contain exactly one
    literal, ``b``, holding the integer 2. The client must have been queried
    exactly once with the task execution identifier.
    """

    def build_task_execution_id():
        # A new, value-equal identifier per call, so the final assertion
        # compares identifiers by equality rather than by identity.
        task_id = identifier.Identifier(
            identifier.ResourceType.TASK,
            "project",
            "domain",
            "task-name",
            "version",
        )
        workflow_execution_id = identifier.WorkflowExecutionIdentifier(
            "project",
            "domain",
            "name",
        )
        node_execution_id = identifier.NodeExecutionIdentifier(
            "node-a",
            workflow_execution_id,
        )
        return identifier.TaskExecutionIdentifier(task_id, node_execution_id, 0)

    fake_client = MagicMock()
    data_response = _execution_models.TaskExecutionGetDataResponse(
        None, None, _INPUT_MAP, _OUTPUT_MAP
    )
    fake_client.get_task_execution_data = MagicMock(return_value=data_response)
    mock_client_factory.return_value = fake_client

    task_execution_model = MagicMock()
    # Expose ``id`` as a property on the mock's own (per-instance) type.
    type(task_execution_model).id = PropertyMock(
        return_value=build_task_execution_id()
    )

    result = engine.FlyteTaskExecution(task_execution_model).get_outputs()

    assert len(result.literals) == 1
    assert result.literals["b"].scalar.primitive.integer == 2
    fake_client.get_task_execution_data.assert_called_once_with(
        build_task_execution_id()
    )