Example #1
0
def test_get_task_execution_outputs(mock_client_factory,
                                    execution_data_locations):
    """FlyteTaskExecution.get_outputs() should fetch outputs through the client.

    The mocked client's ``get_task_execution_data`` returns a
    ``TaskExecutionGetDataResponse`` built from the first two entries of
    ``execution_data_locations`` (presumably a pytest fixture). The test checks
    that exactly one output literal ``b`` with integer value 2 is returned, and
    that the client was queried once with the execution's identifier.

    NOTE(review): ``mock_client_factory`` is presumably injected by a
    ``mock.patch`` decorator that is not visible here -- confirm.
    """

    def _make_task_execution_id():
        # Return a fresh instance on every call so the final assertion compares
        # two distinct-but-equal identifiers (exercising __eq__), not one object
        # against itself.
        return identifier.TaskExecutionIdentifier(
            identifier.Identifier(identifier.ResourceType.TASK, 'project',
                                  'domain', 'task-name', 'version'),
            identifier.NodeExecutionIdentifier(
                "node-a",
                identifier.WorkflowExecutionIdentifier(
                    "project",
                    "domain",
                    "name",
                )), 0)

    mock_client = MagicMock()
    mock_client.get_task_execution_data = MagicMock(
        return_value=_execution_models.TaskExecutionGetDataResponse(
            execution_data_locations[0], execution_data_locations[1]))
    mock_client_factory.return_value = mock_client

    m = MagicMock()
    # A PropertyMock has to be attached to the mock's type, not the instance,
    # for ``m.id`` to behave like a property.
    type(m).id = PropertyMock(return_value=_make_task_execution_id())

    outputs = engine.FlyteTaskExecution(m).get_outputs()
    assert len(outputs.literals) == 1
    assert outputs.literals['b'].scalar.primitive.integer == 2
    mock_client.get_task_execution_data.assert_called_once_with(
        _make_task_execution_id())
Example #2
0
def test_get_full_task_execution_outputs(mock_client_factory):
    """get_outputs() should surface the literal map carried in the data response.

    The response is built as ``(None, None, _INPUT_MAP, _OUTPUT_MAP)``.
    Presumably the two ``None`` slots are the input/output URL locations and the
    last two are the full input/output literal maps; TODO confirm against
    ``TaskExecutionGetDataResponse``.
    """

    def expected_id():
        # A new object per call keeps the final assertion comparing two equal
        # identifiers rather than the same instance twice.
        task = identifier.Identifier(
            identifier.ResourceType.TASK, "project", "domain", "task-name", "version"
        )
        wf_exec = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
        node_exec = identifier.NodeExecutionIdentifier("node-a", wf_exec)
        return identifier.TaskExecutionIdentifier(task, node_exec, 0)

    client = MagicMock()
    data_response = _execution_models.TaskExecutionGetDataResponse(
        None, None, _INPUT_MAP, _OUTPUT_MAP
    )
    client.get_task_execution_data = MagicMock(return_value=data_response)
    mock_client_factory.return_value = client

    execution = MagicMock()
    # PropertyMock only works when set on the mock's type.
    type(execution).id = PropertyMock(return_value=expected_id())

    result = engine.FlyteTaskExecution(execution).get_outputs()
    assert len(result.literals) == 1
    assert result.literals["b"].scalar.primitive.integer == 2
    client.get_task_execution_data.assert_called_once_with(expected_id())
def test_task_node_metadata():
    """TaskNodeMetadata keeps its fields and survives a flyte-idl round trip."""
    # Task execution that the catalog entry claims as its source (retry 3).
    source_execution = identifier.TaskExecutionIdentifier(
        identifier.Identifier(identifier.ResourceType.TASK, "project",
                              "domain", "name", "version"),
        identifier.NodeExecutionIdentifier(
            "node_id",
            identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
        ),
        3,
    )
    dataset = identifier.Identifier(identifier.ResourceType.TASK, "project",
                                    "domain", "t1", "abcdef")
    expected_catalog = catalog.CatalogMetadata(
        dataset_id=dataset,
        artifact_tag=catalog.CatalogArtifactTag("my-artifact-id", "some name"),
        source_task_execution=source_execution,
    )

    metadata = node_execution_models.TaskNodeMetadata(
        cache_status=0, catalog_key=expected_catalog)
    assert metadata.cache_status == 0
    assert metadata.catalog_key == expected_catalog

    # Serializing to the IDL message and back must yield an equal object.
    round_tripped = node_execution_models.TaskNodeMetadata.from_flyte_idl(
        metadata.to_flyte_idl())
    assert round_tripped == metadata
Example #4
0
def test_task_execution_identifier():
    """TaskExecutionIdentifier exposes its parts and round-trips through flyte-idl."""
    task = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    node_execution = identifier.NodeExecutionIdentifier(
        "node_id", identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    )

    def check_fields(te_id):
        # Both the freshly built id and its deserialized copy must report the
        # exact components they were constructed from.
        assert te_id.retry_attempt == 3
        assert te_id.task_id == task
        assert te_id.node_execution_id == node_execution

    built = identifier.TaskExecutionIdentifier(task, node_execution, 3)
    check_fields(built)

    restored = identifier.TaskExecutionIdentifier.from_flyte_idl(built.to_flyte_idl())
    assert restored == built
    check_fields(restored)
Example #5
0
def test_task_execution_identifier():
    """The SDK TaskExecutionIdentifier agrees with its URN and core-model forms.

    Builds the identifier directly, then checks that it equals the identifier
    parsed from its URN and the one promoted from the core model, and that
    ``str()`` renders the same URN.
    """
    expected_urn = "te:project:domain:name:n0:project:domain:name:version:0"

    task_id = _identifier.Identifier(_core_identifier.ResourceType.TASK,
                                     "project", "domain", "name", "version")
    node_execution_id = _core_identifier.NodeExecutionIdentifier(
        node_id="n0",
        execution_id=_core_identifier.WorkflowExecutionIdentifier(
            "project", "domain", "name"))
    # Named ``te_id`` rather than ``identifier`` to avoid shadowing the module
    # name commonly used for the identifier models.
    te_id = _identifier.TaskExecutionIdentifier(
        task_id=task_id,
        node_execution_id=node_execution_id,
        retry_attempt=0,
    )
    assert te_id == _identifier.TaskExecutionIdentifier.from_urn(expected_urn)
    assert te_id == _identifier.TaskExecutionIdentifier.promote_from_model(
        _core_identifier.TaskExecutionIdentifier(task_id, node_execution_id,
                                                 0))
    # Go through the str() builtin instead of calling the dunder directly.
    assert str(te_id) == expected_urn