Example #1
def test_hive_task_dynamic_job_spec_generation():
    with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
        context = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project="unit_test", domain="unit_test", name="unit_test"
            ),
            execution_date=_datetime.utcnow(),
            stats=None,  # TODO: A mock stats object that we can read later.
            logging=_logging,  # TODO: A mock logging object that we can read later.
            tmp_dir=user_working_directory,
        )
        dj_spec = two_queries._produce_dynamic_job_spec(context, _literals.LiteralMap(literals={}))

        # Bindings
        assert len(dj_spec.outputs[0].binding.collection.bindings) == 2
        assert isinstance(dj_spec.outputs[0].binding.collection.bindings[0].scalar.schema, Schema)
        assert isinstance(dj_spec.outputs[0].binding.collection.bindings[1].scalar.schema, Schema)

        # Custom field is filled in
        assert len(dj_spec.tasks[0].custom) > 0
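
Both this test and Example #2 exercise a module-level two_queries hive-task fixture that is not shown on this page. The sketch below is an assumption about what it looks like, based on flytekit 0.x's qubole_hive_task decorator and the Schema.create_from_hive_query helper, not code copied from the project:

# Hedged sketch of the assumed two_queries fixture (flytekit 0.x APIs).
from flytekit.common.types.impl.schema import Schema
from flytekit.sdk.tasks import outputs, qubole_hive_task
from flytekit.sdk.types import Types


@outputs(hive_results=[Types.Schema()])
@qubole_hive_task
def two_queries(wf_params, hive_results):
    # create_from_hive_query allocates a fake remote location for each result
    # schema and rewrites the query to write into it, which is why the tests
    # can assert that each reference URI appears in the formatted query text.
    schema_1, query_1 = Schema.create_from_hive_query(select_query="SELECT 1")
    schema_2, query_2 = Schema.create_from_hive_query(select_query="SELECT 2")
    hive_results.set([schema_1, schema_2])
    return [query_1, query_2]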
Example #2
def test_hive_task_query_generation():
    with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
        context = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project="unit_test", domain="unit_test", name="unit_test"
            ),
            execution_date=_datetime.utcnow(),
            stats=None,  # TODO: A mock stats object that we can read later.
            logging=_logging,  # TODO: A mock logging object that we can read later.
            tmp_dir=user_working_directory,
        )
        references = {
            name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type))
            for name, variable in _six.iteritems(two_queries.interface.outputs)
        }

        qubole_hive_jobs = two_queries._generate_plugin_objects(
            context, references)
        assert len(qubole_hive_jobs) == 2

        # deprecated, collection is only here for backwards compatibility
        assert len(qubole_hive_jobs[0].query_collection.queries) == 1
        assert len(qubole_hive_jobs[1].query_collection.queries) == 1

        # The output references should now have the same fake S3 path as the formatted queries
        assert references["hive_results"].value[0].uri != ""
        assert references["hive_results"].value[1].uri != ""
        assert references["hive_results"].value[0].uri in qubole_hive_jobs[
            0].query.query
        assert references["hive_results"].value[1].uri in qubole_hive_jobs[
            1].query.query
Example #3
    def _execute_user_code(self, inputs):
        """
        :param flytekit.models.literals.LiteralMap inputs:
        :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
        """
        with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
            return self.sdk_task.execute(
                _common_engine.EngineContext(
                    execution_id=WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test"),
                    execution_date=_datetime.utcnow(),
                    stats=MockStats(),
                    logging=_logging,  # TODO: A mock logging object that we can read later.
                    tmp_dir=user_working_directory,
                ),
                inputs,
            )
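
The MockStats object used here (and in the setUp below) is a lightweight stand-in for the engine's stats client; its definition is not shown on this page. The following is a hedged sketch of a minimal recorder, with the method set assumed from a typical taggable stats interface:

# Hypothetical MockStats sketch: record every metric call so that a test can
# read the calls back later, as the TODO comments above ask for.
class MockStats(object):
    def __init__(self):
        self.calls = []

    def incr(self, metric, count=1, tags=None):
        self.calls.append(("incr", metric, count, tags))

    def gauge(self, metric, value, tags=None):
        self.calls.append(("gauge", metric, value, tags))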
Example #4
    def setUp(self):
        with _utils.AutoDeletingTempDir("input_dir") as input_dir:

            self._task_input = _literals.LiteralMap({
                "input_1": _literals.Literal(
                    scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1))
                )
            })

            self._context = _common_engine.EngineContext(
                execution_id=WorkflowExecutionIdentifier(
                    project="unit_test", domain="unit_test", name="unit_test"
                ),
                execution_date=_datetime.datetime.utcnow(),
                stats=MockStats(),
                logging=None,
                tmp_dir=input_dir.name,
            )

            # Defining the distributed training task without specifying an output-persist
            # predicate (so it will use the default)
            @inputs(input_1=Types.Integer)
            @outputs(model=Types.Blob)
            @custom_training_job_task(
                training_job_resource_config=TrainingJobResourceConfig(
                    instance_type="ml.m4.xlarge",
                    instance_count=2,
                    volume_size_in_gb=25,
                ),
                algorithm_specification=AlgorithmSpecification(
                    input_mode=InputMode.FILE,
                    input_content_type=InputContentType.TEXT_CSV,
                    metric_definitions=[
                        MetricDefinition(name="Validation error",
                                         regex="validation:error")
                    ],
                ),
            )
            def my_distributed_task(wf_params, input_1, model):
                pass

            self._my_distributed_task = my_distributed_task
            assert type(self._my_distributed_task) == CustomTrainingJobTask
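
A test built on this setUp would then hand the prepared context and input map to the task, mirroring _execute_user_code in Example #3. A minimal sketch, assuming the same execute(context, inputs) signature the other examples use (the test name is hypothetical):

    def test_execute_runs_user_code(self):
        # Hypothetical follow-on test: execute the task body with the context
        # and inputs prepared in setUp; execute returns the task's outputs.
        outputs = self._my_distributed_task.execute(self._context, self._task_input)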
Example #5
    def execute(self, inputs, context=None):
        """
        Just execute the task and write the outputs to where they belong
        :param flytekit.models.literals.LiteralMap inputs:
        :param dict[Text, Text] context:
        :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
        """

        with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir:
            with _common_utils.AutoDeletingTempDir("task_dir") as task_dir:
                with _data_proxy.LocalWorkingDirectoryContext(task_dir):
                    with _data_proxy.RemoteDataContext():
                        output_file_dict = dict()

                        # This sets the logging level for user code and is the only place an sdk setting gets
                        # used at runtime.  Optionally, Propeller can set an internal config setting which
                        # takes precedence.
                        log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
                        _logging.getLogger().setLevel(log_level)

                        try:
                            output_file_dict = self.sdk_task.execute(
                                _common_engine.EngineContext(
                                    execution_id=_identifier.WorkflowExecutionIdentifier(
                                        project=_internal_config.EXECUTION_PROJECT.get(),
                                        domain=_internal_config.EXECUTION_DOMAIN.get(),
                                        name=_internal_config.EXECUTION_NAME.get(),
                                    ),
                                    execution_date=_datetime.utcnow(),
                                    stats=_get_stats(
                                        # Stats metric path will be:
                                        # registration_project.registration_domain.app.module.task_name.user_stats
                                        # and it will be tagged with execution-level values for project/domain/wf/lp
                                        "{}.{}.{}.user_stats".format(
                                            _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                                            _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                                            _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
                                        ),
                                        tags={
                                            'exec_project': _internal_config.EXECUTION_PROJECT.get(),
                                            'exec_domain': _internal_config.EXECUTION_DOMAIN.get(),
                                            'exec_workflow': _internal_config.EXECUTION_WORKFLOW.get(),
                                            'exec_launchplan': _internal_config.EXECUTION_LAUNCHPLAN.get(),
                                            'api_version': _api_version,
                                        },
                                    ),
                                    logging=_logging,
                                    tmp_dir=task_dir,
                                ),
                                inputs,
                            )
                        except _exception_scopes.FlyteScopedException as e:
                            _logging.error("!!! Begin Error Captured by Flyte !!!")
                            output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                                _error_models.ContainerError(e.error_code, e.verbose_message, e.kind)
                            )
                            _logging.error(e.verbose_message)
                            _logging.error("!!! End Error Captured by Flyte !!!")
                        except Exception:
                            _logging.error("!!! Begin Unknown System Error Captured by Flyte !!!")
                            exc_str = _traceback.format_exc()
                            output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                                _error_models.ContainerError(
                                    "SYSTEM:Unknown", exc_str, _error_models.ContainerError.Kind.RECOVERABLE
                                )
                            )
                            _logging.error(exc_str)
                            _logging.error("!!! End Error Captured by Flyte !!!")
                        finally:
                            for k, v in _six.iteritems(output_file_dict):
                                _common_utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(temp_dir.name, k))
                            _data_proxy.Data.put_data(temp_dir.name, context['output_prefix'], is_multipart=True)
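
When the task body raises, the except blocks turn the failure into an ErrorDocument stored under _constants.ERROR_FILE_NAME, and the finally block still uploads it with put_data, so the error lands next to any outputs under output_prefix. A hedged sketch of reading such a file back, assuming flytekit 0.x's load_proto_from_file counterpart to the write_proto_to_file call above and flyteidl's errors_pb2 module:

# Hypothetical sketch: parse an ErrorDocument written by a failed execution.
from flyteidl.core import errors_pb2 as _errors_pb2
from flytekit.common import utils as _common_utils
from flytekit.models.core import errors as _error_models


def read_error_document(local_path):
    # load_proto_from_file deserializes the proto that write_proto_to_file
    # produced; from_flyte_idl lifts it back into the model class.
    proto = _common_utils.load_proto_from_file(_errors_pb2.ErrorDocument, local_path)
    return _error_models.ErrorDocument.from_flyte_idl(proto)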