Пример #1
0
def test_hive_task_dynamic_job_spec_generation():
    """
    Producing the dynamic job spec for ``two_queries`` with empty inputs yields
    a first output bound to a collection of exactly two schema scalars, and a
    first task whose ``custom`` field is non-empty.
    """
    with _common_utils.AutoDeletingTempDir("user_dir") as work_dir:
        exec_id = WorkflowExecutionIdentifier(
            project="unit_test", domain="unit_test", name="unit_test")
        engine_ctx = _common_engine.EngineContext(
            execution_id=exec_id,
            execution_date=_datetime.utcnow(),
            # TODO: replace stats/logging with mock objects we can inspect later.
            stats=None,
            logging=_logging,
            tmp_dir=work_dir,
        )
        spec = two_queries._produce_dynamic_job_spec(
            engine_ctx, _literals.LiteralMap(literals={}))

        # Output bindings: a two-element collection of schemas.
        collection_bindings = spec.outputs[0].binding.collection.bindings
        assert len(collection_bindings) == 2
        for binding_idx in (0, 1):
            assert isinstance(collection_bindings[binding_idx].scalar.schema,
                              Schema)

        # The generated task carries a populated custom field.
        assert len(spec.tasks[0].custom) > 0
Пример #2
0
def test_hive_task_query_generation():
    """
    ``two_queries`` generates two Hive plugin jobs (each with a single-entry
    deprecated query collection), and the ``hive_results`` output reference is
    filled with two non-empty URIs, each embedded in the matching job's query.
    """
    with _common_utils.AutoDeletingTempDir("user_dir") as work_dir:
        exec_id = WorkflowExecutionIdentifier(
            project="unit_test", domain="unit_test", name="unit_test")
        engine_ctx = _common_engine.EngineContext(
            execution_id=exec_id,
            execution_date=_datetime.utcnow(),
            # TODO: replace stats/logging with mock objects we can inspect later.
            stats=None,
            logging=_logging,
            tmp_dir=work_dir,
        )

        # One fresh OutputReference per declared task output, typed from the
        # output variable's literal type.
        output_refs = {}
        for out_name, out_var in _six.iteritems(two_queries.interface.outputs):
            sdk_type = _type_helpers.get_sdk_type_from_literal_type(out_var.type)
            output_refs[out_name] = _task_output.OutputReference(sdk_type)

        hive_jobs = two_queries._generate_plugin_objects(engine_ctx, output_refs)
        assert len(hive_jobs) == 2

        # Deprecated: query_collection is only kept for backwards compatibility.
        for job_idx in (0, 1):
            assert len(hive_jobs[job_idx].query_collection.queries) == 1

        # The output references should now hold the same fake S3 paths that
        # were substituted into the formatted queries.
        result_schemas = output_refs["hive_results"].value
        for job_idx in (0, 1):
            assert result_schemas[job_idx].uri != ""
        for job_idx in (0, 1):
            assert result_schemas[job_idx].uri in hive_jobs[job_idx].query.query
Пример #3
0
 def _execute_user_code(self, inputs):
     """
     Run the wrapped SDK task on ``inputs`` inside a temporary "user_dir"
     working directory, using a fixed unit-test execution identity.

     :param flytekit.models.literals.LiteralMap inputs:
     :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
     """
     with _common_utils.AutoDeletingTempDir("user_dir") as scratch_dir:
         test_exec_id = WorkflowExecutionIdentifier(
             project="unit_test", domain="unit_test", name="unit_test")
         engine_ctx = _common_engine.EngineContext(
             execution_id=test_exec_id,
             execution_date=_datetime.utcnow(),
             stats=MockStats(),
             # TODO: A mock logging object that we can read later.
             logging=_logging,
             tmp_dir=scratch_dir,
         )
         # Return from inside the ``with`` so the task runs while the
         # temporary directory still exists.
         return self.sdk_task.execute(engine_ctx, inputs)
Пример #4
0
def test_remote_fetch_execution(mock_client_manager):
    """
    ``FlyteRemote.fetch_execution`` returns an execution whose id equals the id
    of the ``Execution`` served by the injected mock client's ``get_execution``.

    NOTE(review): ``mock_client_manager`` is not referenced in the body; it is
    presumably a fixture that patches client construction — confirm.
    """
    served_execution = Execution(
        id=WorkflowExecutionIdentifier("p1", "d1", "n1"),
        spec=MagicMock(),
        closure=MagicMock(),
    )

    fake_client = MagicMock()
    fake_client.get_execution.return_value = served_execution

    remote = FlyteRemote(
        config=Config.auto(), default_project="p1", default_domain="d1")
    # Inject the mock in place of the remote's own client.
    remote._client = fake_client

    fetched = remote.fetch_execution(name="n1")
    assert fetched.id == served_execution.id
Пример #5
0
    def execute(self, inputs, context=None):
        """
        Execute the wrapped SDK task and upload whatever it produced.

        The task runs inside a temporary local working directory with a
        remote-data context active. Its outputs are serialized as protobuf files
        into a temporary engine directory. If the task, or the construction of
        its engine context, raises an ``Exception``, an error document under
        ``_constants.ERROR_FILE_NAME`` is written instead. That directory is then
        uploaded (multipart) to ``context['output_prefix']``.

        :param flytekit.models.literals.LiteralMap inputs:
        :param dict[Text, Text] context: must contain an ``output_prefix`` key
            naming where the serialized outputs are uploaded. The ``None``
            default is kept only for signature compatibility.
        :rtype: None
        :raises ValueError: if ``context`` is None or has no ``output_prefix``.
        """
        # Validate before running user code: the prefix is needed for the
        # upload in the ``finally`` clause, and a missing one would otherwise
        # only surface there, after the task has already done its work.
        if context is None or 'output_prefix' not in context:
            raise ValueError(
                "execute() requires a context containing 'output_prefix'")
        output_prefix = context['output_prefix']

        with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir:
            with _common_utils.AutoDeletingTempDir("task_dir") as task_dir:
                with _data_proxy.LocalWorkingDirectoryContext(task_dir):
                    with _data_proxy.RemoteDataContext():
                        output_file_dict = dict()

                        # This sets the logging level for user code and is the only place an sdk setting gets
                        # used at runtime.  Optionally, Propeller can set an internal config setting which
                        # takes precedence.
                        log_level = (_internal_config.LOGGING_LEVEL.get()
                                     or _sdk_config.LOGGING_LEVEL.get())
                        _logging.getLogger().setLevel(log_level)

                        try:
                            # The engine context is built inside the ``try`` so
                            # failures reading config or creating the stats
                            # client are also captured in the error document.
                            output_file_dict = self.sdk_task.execute(
                                self._engine_context_from_config(task_dir),
                                inputs)
                        except _exception_scopes.FlyteScopedException as e:
                            _logging.error(
                                "!!! Begin Error Captured by Flyte !!!")
                            output_file_dict[
                                _constants.
                                ERROR_FILE_NAME] = _error_models.ErrorDocument(
                                    _error_models.ContainerError(
                                        e.error_code, e.verbose_message,
                                        e.kind))
                            _logging.error(e.verbose_message)
                            _logging.error(
                                "!!! End Error Captured by Flyte !!!")
                        except Exception:
                            _logging.error(
                                "!!! Begin Unknown System Error Captured by Flyte !!!"
                            )
                            exc_str = _traceback.format_exc()
                            output_file_dict[
                                _constants.
                                ERROR_FILE_NAME] = _error_models.ErrorDocument(
                                    _error_models.ContainerError(
                                        "SYSTEM:Unknown", exc_str,
                                        _error_models.ContainerError.Kind.
                                        RECOVERABLE))
                            _logging.error(exc_str)
                            _logging.error(
                                "!!! End Error Captured by Flyte !!!")
                        finally:
                            # Persist outputs (or the error document) and upload
                            # them regardless of how the task finished.
                            for k, v in _six.iteritems(output_file_dict):
                                _common_utils.write_proto_to_file(
                                    v.to_flyte_idl(),
                                    _os.path.join(temp_dir.name, k))
                            _data_proxy.Data.put_data(temp_dir.name,
                                                      output_prefix,
                                                      is_multipart=True)

    @staticmethod
    def _engine_context_from_config(task_dir):
        """
        Build the EngineContext passed to the SDK task from internal config.

        :param task_dir: the task's temporary working directory, passed through
            unchanged as ``tmp_dir``.
        :rtype: _common_engine.EngineContext
        """
        return _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project=_internal_config.EXECUTION_PROJECT.get(),
                domain=_internal_config.EXECUTION_DOMAIN.get(),
                name=_internal_config.EXECUTION_NAME.get()),
            execution_date=_datetime.utcnow(),
            stats=_get_stats(
                # Stats metric path will be:
                # registration_project.registration_domain.app.module.task_name.user_stats
                # and it will be tagged with execution-level values for project/domain/wf/lp
                "{}.{}.{}.user_stats".format(
                    _internal_config.TASK_PROJECT.get()
                    or _internal_config.PROJECT.get(),
                    _internal_config.TASK_DOMAIN.get()
                    or _internal_config.DOMAIN.get(),
                    _internal_config.TASK_NAME.get()
                    or _internal_config.NAME.get()),
                tags={
                    'exec_project':
                    _internal_config.EXECUTION_PROJECT.get(),
                    'exec_domain':
                    _internal_config.EXECUTION_DOMAIN.get(),
                    'exec_workflow':
                    _internal_config.EXECUTION_WORKFLOW.get(),
                    'exec_launchplan':
                    _internal_config.EXECUTION_LAUNCHPLAN.get(),
                    'api_version':
                    _api_version,
                }),
            logging=_logging,
            tmp_dir=task_dir)