def __init__(self, transform_fn, initialize_fn=None, module_name=None, healthcheck_fn=None):
        """Create a Flask application that serves a transformer.

        Args:
            transform_fn (function): responsible for making predictions against the model.
                It is registered as the view for ``POST /invocations`` and follows the
                signature:

                * Returns:
                    `sagemaker_containers.transformers.TransformSpec`: named tuple with
                    prediction data.

            initialize_fn (function, optional): called when the Flask application starts
                (it is registered through ``before_first_request``). It takes no arguments
                and has no return value.

            module_name (str, optional): the module name which implements the worker. When
                it is not given (or is falsy), ``env.module_name`` is used instead.

            healthcheck_fn (function, optional): view used for healthcheck calls on ``/ping``.
                When it is not given, ``default_healthcheck_fn`` is used. Signature:

                * Returns:
                    `flask.app.Response`: response object with new healthcheck response.
        """
        app_name = module_name or env.module_name
        super(Worker, self).__init__(app_name)

        # Logging is configured only after the framework library has been imported, which
        # gives the framework the chance to configure logging at import time.
        _logging.configure_logger(env.log_level)

        if initialize_fn:
            self.before_first_request(initialize_fn)

        # Prediction endpoint: only POST requests are routed to the transformer.
        self.add_url_rule(
            "/invocations",
            endpoint="invocations",
            view_func=transform_fn,
            methods=["POST"],
        )

        # Healthcheck endpoint: a caller-supplied view wins over the default one.
        ping_view = healthcheck_fn or default_healthcheck_fn
        self.add_url_rule("/ping", endpoint="ping", view_func=ping_view)

        self.request_class = Request
def train():
    """Import the configured framework module, run its entry point and report the outcome.

    ``env.framework_module`` is split on ``':'`` into a module name and the name of the
    attribute to call inside that module. The outcome is reported as follows:

    * success: a success file is written and ``_exit_processes`` receives ``SUCCESS_CODE``;
    * ``_errors.ClientError``: the error text is written to the failure file and
      ``_exit_processes`` receives ``DEFAULT_FAILURE_CODE``;
    * any other ``Exception``: the traceback and error text are written to the failure file
      and ``_exit_processes`` receives the exception's ``errno`` when it is usable as an
      integer, otherwise ``DEFAULT_FAILURE_CODE``.
    """
    try:
        # TODO: iquintero - add error handling for ImportError to let the user know
        # if the framework module is not defined.
        env = sagemaker_containers.training_env()

        framework_name, entry_point_name = env.framework_module.split(':')

        try:
            framework = importlib.import_module(framework_name)
        except Exception:
            # Record the module search path to make the import problem diagnosable, then
            # let the outer handlers do the failure reporting.
            logger.error("Import failure loading %s. sys.path=%s", framework_name, sys.path)
            raise

        # the logger is configured after importing the framework library, allowing the framework to
        # configure logging at import time.
        _logging.configure_logger(env.log_level)

        logger.info('Imported framework %s', framework_name)

        entry_point = getattr(framework, entry_point_name)

        _modules.write_env_vars(env.to_env_vars())

        entry_point()

        logger.info('Reporting training SUCCESS')
        _files.write_success_file()
        _exit_processes(SUCCESS_CODE)

    except _errors.ClientError as e:

        failure_message = str(e)
        _files.write_failure_file(failure_message)

        logger.error(failure_message)
        _exit_processes(DEFAULT_FAILURE_CODE)
    except Exception as e:
        failure_msg = 'framework error: \n%s\n%s' % (traceback.format_exc(),
                                                     str(e))

        _files.write_failure_file(failure_msg)
        logger.error('Reporting training FAILURE')

        logger.error(failure_msg)

        # ``errno`` can be present but unusable as an exit code (``OSError("msg").errno`` is
        # None, for instance), so coerce it and fall back to the default failure code.
        try:
            exit_code = int(getattr(e, 'errno', DEFAULT_FAILURE_CODE))
        except (TypeError, ValueError):
            exit_code = DEFAULT_FAILURE_CODE
        _exit_processes(exit_code)
def train():
    """Run training through a framework module or the user entry point and report the outcome.

    An intermediate-output sync is started first; whatever ``start_sync`` returns is joined
    (when truthy) before ``_exit_processes`` is called. If ``env.framework_module`` is set it
    is split on ``':'`` into a module name and the name of the attribute to call; otherwise
    the user entry point is run with the MPI runner when the ``_params.MPI_ENABLED``
    framework parameter is truthy and with the process runner when it is not.

    The exit code handed to ``_exit_processes`` is ``SUCCESS_CODE`` on success,
    ``DEFAULT_FAILURE_CODE`` for ``_errors.ClientError``, and for any other ``Exception`` the
    exception's ``errno`` when it is usable as an integer (``DEFAULT_FAILURE_CODE`` otherwise).
    """
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        # TODO: iquintero - add error handling for ImportError to let the user know
        # if the framework module is not defined.
        env = sagemaker_containers.training_env()

        # TODO: [issue#144] There is a bug in the logic -
        # we need os.environ.get(_params.REGION_NAME_ENV)
        # in certain regions, but it is not going to be available unless
        # TrainingEnvironment has been initialized. It shouldn't be environment variable.
        region = os.environ.get('AWS_REGION', os.environ.get(_params.REGION_NAME_ENV))
        s3_endpoint_url = os.environ.get("S3_ENDPOINT_URL")
        intermediate_sync = _intermediate_output.start_sync(
            env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url)

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(':')

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing the framework to
            # configure logging at import time.
            _logging.configure_logger(env.log_level)
            logger.info('Imported framework %s', framework_name)

            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            _logging.configure_logger(env.log_level)

            mpi_enabled = env.additional_framework_parameters.get(_params.MPI_ENABLED)
            runner_type = _runner.RunnerType.MPI if mpi_enabled else _runner.RunnerType.Process

            entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(),
                            env.to_env_vars(), runner=runner_type)

        logger.info('Reporting training SUCCESS')

        _files.write_success_file()
    except _errors.ClientError as e:
        # Record the failure code before doing any reporting: if reporting itself raises,
        # the ``finally`` clause must not exit with the success code.
        exit_code = DEFAULT_FAILURE_CODE

        failure_message = str(e)
        _files.write_failure_file(failure_message)

        logger.error(failure_message)
        # The intermediate sync is joined once, in the ``finally`` clause below.
    except Exception as e:
        # Format the traceback first, while this exception is still the one being handled.
        failure_msg = 'framework error: \n%s\n%s' % (traceback.format_exc(), str(e))

        # ``errno`` can be present but unusable as an exit code (``OSError("msg").errno`` is
        # None, for instance), so coerce it and fall back to the default failure code. This
        # is done before reporting so a reporting error cannot leave the success code set.
        try:
            exit_code = int(getattr(e, 'errno', DEFAULT_FAILURE_CODE))
        except (TypeError, ValueError):
            exit_code = DEFAULT_FAILURE_CODE

        _files.write_failure_file(failure_msg)
        logger.error('Reporting training FAILURE')

        logger.error(failure_msg)
    finally:
        if intermediate_sync:
            intermediate_sync.join()

        _exit_processes(exit_code)
def train():
    """Run training through a framework module or the user entry point and report the outcome.

    An intermediate-output sync is started first; whatever ``start_sync`` returns is joined
    (when truthy) before ``_exit_processes`` is called. If ``env.framework_module`` is set it
    is split on ``":"`` into a module name and the name of the attribute to call; otherwise
    the user entry point is run with the MPI runner when the ``_params.MPI_ENABLED``
    framework parameter is truthy and with the process runner when it is not.

    The exit code handed to ``_exit_processes`` is ``SUCCESS_CODE`` on success,
    ``DEFAULT_FAILURE_CODE`` for ``_errors.ClientError``, and for any other ``Exception`` the
    result of ``_get_valid_failure_exit_code`` applied to the exception's ``errno`` (or to
    ``DEFAULT_FAILURE_CODE`` when the exception has no ``errno`` attribute).
    """
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        env = sagemaker_containers.training_env()

        region = os.environ.get("AWS_REGION",
                                os.environ.get(_params.REGION_NAME_ENV))
        s3_endpoint_url = os.environ.get(_params.S3_ENDPOINT_URL)
        intermediate_sync = _intermediate_output.start_sync(
            env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url)

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(":")

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing
            # the framework to configure logging at import time.
            _logging.configure_logger(env.log_level)
            logger.info("Imported framework %s", framework_name)

            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            _logging.configure_logger(env.log_level)

            mpi_enabled = env.additional_framework_parameters.get(
                _params.MPI_ENABLED)
            runner_type = _runner.RunnerType.MPI if mpi_enabled else _runner.RunnerType.Process

            entry_point.run(
                env.module_dir,
                env.user_entry_point,
                env.to_cmd_args(),
                env.to_env_vars(),
                runner=runner_type,
            )

        logger.info("Reporting training SUCCESS")

        _files.write_success_file()
    except _errors.ClientError as e:
        # Record the failure code before doing any reporting: if reporting itself raises,
        # the ``finally`` clause must not exit with the success code.
        exit_code = DEFAULT_FAILURE_CODE

        failure_message = str(e)
        _files.write_failure_file(failure_message)

        logger.error(failure_message)
        # The intermediate sync is joined once, in the ``finally`` clause below.
    except Exception as e:  # pylint: disable=broad-except
        # Format the traceback first, while this exception is still the one being handled.
        failure_msg = "framework error: \n%s\n%s" % (traceback.format_exc(),
                                                     str(e))

        # Settle the exit code before reporting so that an error raised while reporting
        # cannot leave the success code in place.
        error_number = getattr(e, "errno", DEFAULT_FAILURE_CODE)
        exit_code = _get_valid_failure_exit_code(error_number)

        _files.write_failure_file(failure_msg)
        logger.error("Reporting training FAILURE")

        logger.error(failure_msg)
    finally:
        if intermediate_sync:
            intermediate_sync.join()

        _exit_processes(exit_code)