def test_hive_task_dynamic_job_spec_generation():
    """Verify that two_queries produces a dynamic job spec with two schema
    output bindings and a populated plugin custom field."""
    with _common_utils.AutoDeletingTempDir("user_dir") as tmp_dir:
        ctx = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project="unit_test",
                domain="unit_test",
                name="unit_test",
            ),
            execution_date=_datetime.utcnow(),
            stats=None,  # TODO: A mock stats object that we can read later.
            logging=_logging,  # TODO: A mock logging object that we can read later.
            tmp_dir=tmp_dir,
        )
        dj_spec = two_queries._produce_dynamic_job_spec(
            ctx, _literals.LiteralMap(literals={}))

        # Bindings: the single output should bind a collection of two schemas.
        bindings = dj_spec.outputs[0].binding.collection.bindings
        assert len(bindings) == 2
        assert isinstance(bindings[0].scalar.schema, Schema)
        assert isinstance(bindings[1].scalar.schema, Schema)

        # Custom field is filled in
        assert len(dj_spec.tasks[0].custom) > 0
def test_hive_task_query_generation():
    """Verify that two_queries generates one Qubole hive job per query and
    threads the fake S3 output URIs through to the formatted query text."""
    with _common_utils.AutoDeletingTempDir("user_dir") as tmp_dir:
        ctx = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project="unit_test",
                domain="unit_test",
                name="unit_test",
            ),
            execution_date=_datetime.utcnow(),
            stats=None,  # TODO: A mock stats object that we can read later.
            logging=_logging,  # TODO: A mock logging object that we can read later.
            tmp_dir=tmp_dir,
        )
        references = {
            name: _task_output.OutputReference(
                _type_helpers.get_sdk_type_from_literal_type(variable.type))
            for name, variable in _six.iteritems(two_queries.interface.outputs)
        }

        qubole_hive_jobs = two_queries._generate_plugin_objects(ctx, references)
        assert len(qubole_hive_jobs) == 2

        # deprecated, collection is only here for backwards compatibility
        assert len(qubole_hive_jobs[0].query_collection.queries) == 1
        assert len(qubole_hive_jobs[1].query_collection.queries) == 1

        # The output references should now have the same fake S3 path as the
        # formatted queries
        first_uri = references["hive_results"].value[0].uri
        second_uri = references["hive_results"].value[1].uri
        assert first_uri != ""
        assert second_uri != ""
        assert first_uri in qubole_hive_jobs[0].query.query
        assert second_uri in qubole_hive_jobs[1].query.query
def _execute_user_code(self, inputs):
    """
    Run the wrapped SDK task inside a throwaway working directory.

    :param flytekit.models.literals.LiteralMap inputs:
    :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
    """
    with _common_utils.AutoDeletingTempDir("user_dir") as scratch_dir:
        ctx = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project="unit_test", domain="unit_test", name="unit_test"),
            execution_date=_datetime.utcnow(),
            stats=MockStats(),
            logging=_logging,  # TODO: A mock logging object that we can read later.
            tmp_dir=scratch_dir,
        )
        return self.sdk_task.execute(ctx, inputs)
def setUp(self):
    with _utils.AutoDeletingTempDir("input_dir") as scratch_dir:
        # A single integer literal input for the task under test.
        self._task_input = _literals.LiteralMap({
            "input_1": _literals.Literal(
                scalar=_literals.Scalar(
                    primitive=_literals.Primitive(integer=1)))
        })
        self._context = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project="unit_test", domain="unit_test", name="unit_test"),
            execution_date=_datetime.datetime.utcnow(),
            stats=MockStats(),
            logging=None,
            tmp_dir=scratch_dir.name,
        )

        # Defining the distributed training task without specifying an
        # output-persist predicate (so it will use the default)
        resource_config = TrainingJobResourceConfig(
            instance_type="ml.m4.xlarge",
            instance_count=2,
            volume_size_in_gb=25,
        )
        algorithm_spec = AlgorithmSpecification(
            input_mode=InputMode.FILE,
            input_content_type=InputContentType.TEXT_CSV,
            metric_definitions=[
                MetricDefinition(name="Validation error",
                                 regex="validation:error")
            ],
        )

        @inputs(input_1=Types.Integer)
        @outputs(model=Types.Blob)
        @custom_training_job_task(
            training_job_resource_config=resource_config,
            algorithm_specification=algorithm_spec,
        )
        def my_distributed_task(wf_params, input_1, model):
            pass

        self._my_distributed_task = my_distributed_task
        assert type(self._my_distributed_task) == CustomTrainingJobTask
def execute(self, inputs, context=None):
    """
    Just execute the task and write the outputs to where they belong
    :param flytekit.models.literals.LiteralMap inputs:
    :param dict[Text, Text] context:
    :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
    """
    with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir:
        with _common_utils.AutoDeletingTempDir("task_dir") as task_dir:
            with _data_proxy.LocalWorkingDirectoryContext(task_dir):
                with _data_proxy.RemoteDataContext():
                    output_file_dict = dict()

                    # This sets the logging level for user code and is the only place an sdk setting gets
                    # used at runtime.  Optionally, Propeller can set an internal config setting which
                    # takes precedence.
                    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
                    _logging.getLogger().setLevel(log_level)

                    try:
                        # Build the engine context from internal config (execution identity,
                        # stats client, logging, scratch dir) and hand it to user task code.
                        output_file_dict = self.sdk_task.execute(
                            _common_engine.EngineContext(
                                execution_id=_identifier.WorkflowExecutionIdentifier(
                                    project=_internal_config.EXECUTION_PROJECT.get(),
                                    domain=_internal_config.EXECUTION_DOMAIN.get(),
                                    name=_internal_config.EXECUTION_NAME.get()),
                                execution_date=_datetime.utcnow(),
                                stats=_get_stats(
                                    # Stats metric path will be:
                                    # registration_project.registration_domain.app.module.task_name.user_stats
                                    # and it will be tagged with execution-level values for project/domain/wf/lp
                                    "{}.{}.{}.user_stats".format(
                                        _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                                        _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                                        _internal_config.TASK_NAME.get() or _internal_config.NAME.get()),
                                    tags={
                                        'exec_project': _internal_config.EXECUTION_PROJECT.get(),
                                        'exec_domain': _internal_config.EXECUTION_DOMAIN.get(),
                                        'exec_workflow': _internal_config.EXECUTION_WORKFLOW.get(),
                                        'exec_launchplan': _internal_config.EXECUTION_LAUNCHPLAN.get(),
                                        'api_version': _api_version
                                    }),
                                logging=_logging,
                                tmp_dir=task_dir),
                            inputs)
                    except _exception_scopes.FlyteScopedException as e:
                        # A failure already wrapped by Flyte's exception scoping: persist a
                        # structured error document instead of letting the engine crash.
                        _logging.error("!!! Begin Error Captured by Flyte !!!")
                        output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                            _error_models.ContainerError(
                                e.error_code, e.verbose_message, e.kind))
                        _logging.error(e.verbose_message)
                        _logging.error("!!! End Error Captured by Flyte !!!")
                    except Exception:
                        # Any other exception is an unknown system error; record the full
                        # traceback as a RECOVERABLE container error document.
                        _logging.error("!!! Begin Unknown System Error Captured by Flyte !!!")
                        exc_str = _traceback.format_exc()
                        output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                            _error_models.ContainerError(
                                "SYSTEM:Unknown", exc_str,
                                _error_models.ContainerError.Kind.RECOVERABLE))
                        _logging.error(exc_str)
                        _logging.error("!!! End Error Captured by Flyte !!!")
                    finally:
                        # Success or failure, serialize whatever we have (outputs or error
                        # document) as protos and upload to the caller-provided output prefix.
                        for k, v in _six.iteritems(output_file_dict):
                            _common_utils.write_proto_to_file(
                                v.to_flyte_idl(),
                                _os.path.join(temp_dir.name, k))
                        _data_proxy.Data.put_data(
                            temp_dir.name,
                            context['output_prefix'],
                            is_multipart=True)