def _execute_user_code(self, inputs):
    """
    :param flytekit.models.literals.LiteralMap inputs:
    :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
    """
    with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
        return self.sdk_task.execute(
            _common_engine.EngineContext(
                execution_id=WorkflowExecutionIdentifier(
                    project="unit_test",
                    domain="unit_test",
                    name="unit_test",
                ),
                execution_date=_datetime.utcnow(),
                stats=MockStats(),
                logging=_logging,  # TODO: A mock logging object that we can read later.
                tmp_dir=user_working_directory,
            ),
            inputs,
        )
def setUp(self):
    with _utils.AutoDeletingTempDir("input_dir") as input_dir:
        self._task_input = _literals.LiteralMap(
            {
                "input_1": _literals.Literal(
                    scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1))
                )
            }
        )

        self._context = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project="unit_test",
                domain="unit_test",
                name="unit_test",
            ),
            execution_date=_datetime.datetime.utcnow(),
            stats=MockStats(),
            logging=None,
            tmp_dir=input_dir.name,
        )

        # Defining the distributed training task without specifying an output-persist
        # predicate (so it will use the default)
        @inputs(input_1=Types.Integer)
        @outputs(model=Types.Blob)
        @custom_training_job_task(
            training_job_resource_config=TrainingJobResourceConfig(
                instance_type="ml.m4.xlarge",
                instance_count=2,
                volume_size_in_gb=25,
            ),
            algorithm_specification=AlgorithmSpecification(
                input_mode=InputMode.FILE,
                input_content_type=InputContentType.TEXT_CSV,
                metric_definitions=[
                    MetricDefinition(name="Validation error", regex="validation:error")
                ],
            ),
        )
        def my_distributed_task(wf_params, input_1, model):
            pass

        self._my_distributed_task = my_distributed_task
        assert type(self._my_distributed_task) == CustomTrainingJobTask