Exemplo n.º 1
0
 def _execute_user_code(self, inputs):
     """
     Run the wrapped SDK task once, locally, under a fixed unit-test execution context.

     A scratch directory (an ``AutoDeletingTempDir`` with the "user_dir" prefix) is held open
     for the whole call and passed to the task as the context's ``tmp_dir``.

     :param flytekit.models.literals.LiteralMap inputs: literal inputs, forwarded unchanged to
         ``self.sdk_task.execute``.
     :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
     """
     with _common_utils.AutoDeletingTempDir("user_dir") as scratch_dir:
         # Every local run is reported under the same placeholder execution identity.
         placeholder_id = WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test")
         engine_context = _common_engine.EngineContext(
             execution_id=placeholder_id,
             execution_date=_datetime.utcnow(),
             stats=MockStats(),
             logging=_logging,  # TODO: A mock logging object that we can read later.
             tmp_dir=scratch_dir,
         )
         # The call stays inside the ``with`` so the scratch directory still exists while the
         # task executes; it is only released after execute() returns or raises.
         return self.sdk_task.execute(engine_context, inputs)
    def setUp(self):
        """Build the fixtures shared by every test in this case.

        Sets:
          * ``self._task_input`` -- a LiteralMap binding ``input_1`` to the integer 1.
          * ``self._context`` -- an EngineContext whose ``tmp_dir`` is a scratch
            directory kept alive until the current test has finished.
          * ``self._my_distributed_task`` -- a custom training job task declared
            without an output-persist predicate, so the default one is used.
        """
        import contextlib

        # self._context keeps this directory's path, so the directory has to outlive
        # setUp(). Entering the AutoDeletingTempDir in a plain ``with`` block here
        # would clean it up as soon as setUp() returned, handing every test a
        # dangling tmp_dir. Defer the exit to unittest's cleanup phase instead
        # (registered cleanups also run when setUp() itself fails).
        resources = contextlib.ExitStack()
        self.addCleanup(resources.close)
        input_dir = resources.enter_context(_utils.AutoDeletingTempDir("input_dir"))

        self._task_input = _literals.LiteralMap({
            "input_1":
            _literals.Literal(scalar=_literals.Scalar(
                primitive=_literals.Primitive(integer=1)))
        })

        self._context = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(project="unit_test",
                                                     domain="unit_test",
                                                     name="unit_test"),
            execution_date=_datetime.datetime.utcnow(),
            stats=MockStats(),
            logging=None,
            tmp_dir=input_dir.name,
        )

        # Defining the distributed training task without specifying an output-persist
        # predicate (so it will use the default)
        @inputs(input_1=Types.Integer)
        @outputs(model=Types.Blob)
        @custom_training_job_task(
            training_job_resource_config=TrainingJobResourceConfig(
                instance_type="ml.m4.xlarge",
                instance_count=2,
                volume_size_in_gb=25,
            ),
            algorithm_specification=AlgorithmSpecification(
                input_mode=InputMode.FILE,
                input_content_type=InputContentType.TEXT_CSV,
                metric_definitions=[
                    MetricDefinition(name="Validation error",
                                     regex="validation:error")
                ],
            ),
        )
        def my_distributed_task(wf_params, input_1, model):
            pass

        self._my_distributed_task = my_distributed_task
        # Exact-type check (deliberately not isinstance): the decorator stack must
        # produce a CustomTrainingJobTask itself. Unlike a bare ``assert``, this is
        # not stripped under ``python -O`` and reports both types on failure.
        self.assertIs(type(self._my_distributed_task), CustomTrainingJobTask)