Esempio n. 1
0
    def _execute_user_code(self, context, inputs):
        """
        Call the user's task function with distributed-training execution parameters plus the given inputs.

        :param flytekit.engines.common.tasks.sagemaker.distribution.DistributedTrainingEngineContext context:
            Source of the execution date, execution id, stats, logging, working directory and distributed
            training context that are handed to the user code.
        :param dict[Text, T] inputs: Despite the name, this carries both inputs and outputs.  It is unpacked as
            keyword arguments to the user-defined function.  Input values are Python std values, or the Flyte
            classes when no native Python type exists (like Schema or Blob).  Output values are OutputReferences.
            (Note that these are not the same OutputReferences as in BindingData's)
        :rtype: Any: the returned object from user code.
        :returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities.  The
            engine uses these entities to pass data from node to node, populate metadata, and so on; each engine
            behaves differently.  For instance, the Flyte engine uploads the entities to a remote working
            directory (with the names provided) so that Flyte Propeller can push the workflow along, whereas the
            local engine merely feeds the outputs directly into the next node.
        """
        # Wrap the task function before anything else is evaluated, mirroring the call order of a single
        # chained expression.
        guarded_task_function = _exception_scopes.user_entry_point(self.task_function)

        execution_date = context.execution_date
        # TODO: it might be better to consider passing the full struct
        execution_id_text = _six.text_type(WorkflowExecutionIdentifier.promote_from_model(context.execution_id))

        user_params = _sm_distribution.DistributedTrainingExecutionParam(
            execution_date=execution_date,
            execution_id=execution_id_text,
            stats=context.stats,
            logging=context.logging,
            tmp_dir=context.working_directory,
            distributed_training_context=context.distributed_training_context,
        )

        # The execution parameters are the sole positional argument; everything else is passed by keyword.
        return guarded_task_function(user_params, **inputs)
Esempio n. 2
0
 def initialize():
     """
     Rebuild the default context and make it the only entry of ``FlyteContextManager._OBJS``.

     Whatever was stored in ``FlyteContextManager._OBJS`` before the call is discarded.
     """
     # Placeholder identifier so that tasks relying on Flyte provided param functionality do not fail when run
     # locally.
     local_execution_id = _identifier.WorkflowExecutionIdentifier(project="local", domain="local", name="local")
     # _SdkWorkflowExecutionIdentifier is used purely for formatting into the ex:project:domain:name format users
     # are already acquainted with.
     formatted_execution_id = str(_SdkWorkflowExecutionIdentifier.promote_from_model(local_execution_id))
     user_space_params = ExecutionParameters(
         execution_id=formatted_execution_id,
         execution_date=_datetime.datetime.utcnow(),
         stats=_mock_stats.MockStats(),
         logging=_logging,
         tmp_dir=os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space"),
     )

     # Start from a context that only knows about local file access, then attach an execution state carrying
     # the user-space parameters built above.
     base_context = FlyteContext(file_access=_data_proxy.default_local_file_access_provider)
     execution_state = base_context.new_execution_state().with_params(user_space_params=user_space_params)
     root_context = base_context.with_execution_state(execution_state).build()
     root_context.set_stackframe(s=FlyteContextManager.get_origin_stackframe())

     # Replace (rather than extend) the stack so that no earlier context survives.
     FlyteContextManager._OBJS = [root_context]
    def setUp(self):
        """
        Build the fixtures stored on the test instance.

        Sets ``self._task_input`` (a literal map with ``input_1`` = 1), ``self._context`` (an engine context for a
        "unit_test" execution) and ``self._my_distributed_task`` (a custom training job task), then checks that
        the decorated function has exactly the ``CustomTrainingJobTask`` type.
        """
        # NOTE(review): the context keeps ``sandbox_dir.name`` after this ``with`` block exits; presumably the
        # directory no longer exists by then - confirm this is intended.
        with _utils.AutoDeletingTempDir("input_dir") as sandbox_dir:
            integer_one = _literals.Primitive(integer=1)
            self._task_input = _literals.LiteralMap(
                {"input_1": _literals.Literal(scalar=_literals.Scalar(primitive=integer_one))}
            )

            unit_test_execution_id = WorkflowExecutionIdentifier(
                project="unit_test", domain="unit_test", name="unit_test"
            )
            self._context = _common_engine.EngineContext(
                execution_id=unit_test_execution_id,
                execution_date=_datetime.datetime.utcnow(),
                stats=MockStats(),
                logging=None,
                tmp_dir=sandbox_dir.name,
            )

            # The distributed training task is declared without an output-persist predicate,
            # so the default one applies.
            @inputs(input_1=Types.Integer)
            @outputs(model=Types.Blob)
            @custom_training_job_task(
                training_job_resource_config=TrainingJobResourceConfig(
                    instance_type="ml.m4.xlarge", instance_count=2, volume_size_in_gb=25
                ),
                algorithm_specification=AlgorithmSpecification(
                    input_mode=InputMode.FILE,
                    input_content_type=InputContentType.TEXT_CSV,
                    metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")],
                ),
            )
            def my_distributed_task(wf_params, input_1, model):
                pass

            self._my_distributed_task = my_distributed_task
            assert type(self._my_distributed_task) == CustomTrainingJobTask
Esempio n. 4
0
        elif self._parent is not None:
            return self._parent.flyte_client
        else:
            raise Exception("No flyte_client initialized")


# Hack... we'll think of something better in the future
class FlyteEntities(object):
    """Namespace holding one class-level list, ``entities``, shared by the class and all of its instances."""

    # Starts empty; the same list object is visible through the class and through every instance.
    entities = list()


# Module-level defaults, built once when this module is executed.
#
# The "local" execution id exists so that tasks relying on Flyte provided param functionality do not fail when
# run locally.
default_execution_id = _identifier.WorkflowExecutionIdentifier(
    project="local", domain="local", name="local"
)

# _SdkWorkflowExecutionIdentifier is used purely to render the id in the ex:project:domain:name format users are
# already acquainted with.
default_user_space_params = ExecutionParameters(
    execution_id=str(_SdkWorkflowExecutionIdentifier.promote_from_model(default_execution_id)),
    execution_date=_datetime.datetime.utcnow(),
    stats=_mock_stats.MockStats(),
    logging=_logging,
    tmp_dir=os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space"),
)

# Register the default context on FlyteContext.OBJS.
default_context = FlyteContext(
    user_space_params=default_user_space_params,
    file_access=_data_proxy.default_local_file_access_provider,
)
FlyteContext.OBJS.append(default_context)