Example #1
0
def matching_args(fn, dictionary):
    """Select the entries of ``dictionary`` whose keys name arguments of ``fn``.

    If the argument spec of ``fn`` reports a var-keyword parameter
    (``keywords``, i.e. a ``**kwargs`` catch-all), every key is acceptable and
    ``dictionary`` itself is returned unchanged (the same object, not a copy).

    Example::

        def train(channel_dirs, model_dir): pass

        dictionary = {'channel_dirs': {}, 'model_dir': '/opt/ml/model', 'other_args': None}

        args = functions.matching_args(train, dictionary)
        # args == {'channel_dirs': {}, 'model_dir': '/opt/ml/model'}

        train(**args)

    Args:
        fn (function): The function whose arguments are inspected.
        dictionary (dict): Candidate keyword arguments to filter.

    Returns:
        (dict) ``dictionary`` itself when ``fn`` takes var-keyword arguments;
        otherwise a dictionary with only the entries matching ``fn``'s named
        arguments.
    """
    spec = getargspec(fn)
    accepts_var_keywords = bool(spec.keywords)

    if not accepts_var_keywords:
        named_args = spec.args
        return mapping.split_by_criteria(dictionary, named_args).included

    return dictionary
    def __init__(self,
                 resource_config=None,
                 input_data_config=None,
                 hyperparameters=None):
        """Initialize a read-only snapshot of the container environment.

        Settings are collected from environment variables, from the resource and
        input-data configuration, and from the training hyperparameters.
        As a side effect, ``params.JOB_NAME_ENV``, ``params.CURRENT_HOST_ENV``
        and ``params.REGION_NAME_ENV`` are written to ``os.environ``.

        Args:
            resource_config (dict[string, object]): The contents from
                /opt/ml/input/config/resourceconfig.json. If falsy (None or
                empty), it is loaded with ``read_resource_config()``.
                It has the following keys:
                    - current_host: The name of the current container on the container network.
                        For example, 'algo-1'.
                    - hosts: The list of names of all containers on the container network,
                        sorted lexicographically. For example, `['algo-1', 'algo-2', 'algo-3']`
                        for a three-node cluster.
                    - network_interface_name (optional): Defaults to 'eth0' when absent.

            input_data_config (dict[string, object]): The contents from /opt/ml/input/config/inputdataconfig.json.
                If falsy, it is loaded with ``read_input_data_config()``.
                For example, suppose that you specify three data channels (train, evaluation, and
                validation) in your request. This dictionary will contain:

                {'train': {
                    'ContentType':  'trainingContentType',
                    'TrainingInputMode': 'File',
                    'S3DistributionType': 'FullyReplicated',
                    'RecordWrapperType': 'None'
                },
                'evaluation' : {
                    'ContentType': 'evalContentType',
                    'TrainingInputMode': 'File',
                    'S3DistributionType': 'FullyReplicated',
                    'RecordWrapperType': 'None'
                },
                'validation': {
                    'TrainingInputMode': 'File',
                    'S3DistributionType': 'FullyReplicated',
                    'RecordWrapperType': 'None'
                }}

                You can find more information about /opt/ml/input/config/inputdataconfig.json here:
                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-running-container-inputdataconfig

            hyperparameters (dict[string, object]): The training job hyperparameters.
                If falsy, they are loaded with ``read_hyperparameters()``. They are
                split with ``mapping.split_by_criteria`` on
                ``params.SAGEMAKER_HYPERPARAMETERS`` / ``params.SAGEMAKER_PREFIX``.
                The included entries configure the container. The excluded
                entries are kept as the user hyperparameters.

        Raises:
            KeyError: If ``resource_config`` has no ``"current_host"`` or ``"hosts"`` key.
            IndexError: If ``resource_config["hosts"]`` is empty.
            ValueError: If the ``params.LOG_LEVEL_ENV`` environment variable is not an integer.
        """
        module_name = os.environ.get(params.USER_PROGRAM_ENV, None)
        module_dir = os.environ.get(params.SUBMIT_DIR_ENV, code_dir)
        log_level = int(os.environ.get(params.LOG_LEVEL_ENV, logging.INFO))

        self._num_gpus = num_gpus()
        self._num_cpus = num_cpus()
        self._module_name = module_name
        self._user_entry_point = module_name
        # NOTE(review): _module_dir and _log_level are unconditionally replaced by
        # the hyperparameter-based values further down. The environment values
        # read above therefore only matter for validation (int() may raise).
        # Confirm whether SUBMIT_DIR_ENV / LOG_LEVEL_ENV were meant as fallbacks.
        self._module_dir = module_dir
        self._log_level = log_level
        self._model_dir = model_dir

        resource_config = resource_config or read_resource_config()
        input_data_config = input_data_config or read_input_data_config()
        all_hyperparameters = hyperparameters or read_hyperparameters()

        current_host = resource_config["current_host"]
        hosts = resource_config["hosts"]

        split_result = mapping.split_by_criteria(
            all_hyperparameters,
            keys=params.SAGEMAKER_HYPERPARAMETERS,
            prefix=params.SAGEMAKER_PREFIX,
        )
        sagemaker_hyperparameters = split_result.included

        # Collect the included entries that are not among the known
        # SAGEMAKER_HYPERPARAMETERS keys (presumably those matched only by
        # SAGEMAKER_PREFIX).
        additional_framework_parameters = {
            key: value
            for key, value in sagemaker_hyperparameters.items()
            if key not in params.SAGEMAKER_HYPERPARAMETERS
        }

        # Fall back to a boto3 session only when no region was supplied. Passing
        # it as a dict.get() default would build a Session even when the value
        # is unused.
        if params.REGION_NAME_PARAM in sagemaker_hyperparameters:
            sagemaker_region = sagemaker_hyperparameters[params.REGION_NAME_PARAM]
        else:
            sagemaker_region = boto3.session.Session().region_name

        os.environ[params.JOB_NAME_ENV] = sagemaker_hyperparameters.get(
            params.JOB_NAME_PARAM, "")
        os.environ[params.CURRENT_HOST_ENV] = current_host
        # sagemaker_region may be None (e.g. no region configured for boto3);
        # os.environ only accepts strings.
        os.environ[params.REGION_NAME_ENV] = sagemaker_region or ""

        self._hosts = hosts

        # eth0 is the default network interface defined by SageMaker with VPC
        # support and local mode. It is used when resource_config does not name
        # an interface.
        self._network_interface_name = resource_config.get(
            "network_interface_name", "eth0")

        self._hyperparameters = split_result.excluded
        self._additional_framework_parameters = additional_framework_parameters
        self._resource_config = resource_config
        self._input_data_config = input_data_config
        self._output_data_dir = output_data_dir
        self._output_intermediate_dir = output_intermediate_dir
        # Map each channel name in input_data_config to channel_path(channel).
        self._channel_input_dirs = {
            channel: channel_path(channel)
            for channel in input_data_config
        }
        self._current_host = current_host

        # Hyperparameters override the values derived from the environment above.
        if self._module_name is None:
            # NOTE(review): str(None) yields the string "None" when the
            # hyperparameter is also missing; confirm callers expect that.
            self._module_name = str(
                sagemaker_hyperparameters.get(params.USER_PROGRAM_PARAM, None))
        self._user_entry_point = self._user_entry_point or sagemaker_hyperparameters.get(
            params.USER_PROGRAM_PARAM)

        self._module_dir = str(
            sagemaker_hyperparameters.get(params.SUBMIT_DIR_PARAM, code_dir))
        self._log_level = sagemaker_hyperparameters.get(
            params.LOG_LEVEL_PARAM, logging.INFO)
        self._sagemaker_s3_output = sagemaker_hyperparameters.get(
            params.S3_OUTPUT_LOCATION_PARAM, None)
        self._framework_module = os.environ.get(
            params.FRAMEWORK_TRAINING_MODULE_ENV, None)

        self._input_dir = input_dir
        self._input_config_dir = input_config_dir
        self._output_dir = output_dir
        self._job_name = os.environ.get(params.TRAINING_JOB_ENV.upper(), None)

        # The first entry of resource_config["hosts"] is treated as the master.
        self._master_hostname = list(hosts)[0]
        self._is_master = current_host == self._master_hostname
Example #3
0
def test_split_by_criteria_with_keys_and_criteria(dictionary, keys, prefix,
                                                  expected):
    """split_by_criteria called with both ``keys`` and ``prefix`` returns the expected split."""
    actual = mapping.split_by_criteria(dictionary, prefix=prefix, keys=keys)
    assert actual == expected
Example #4
0
def test_split_by_criteria_with_prefix(dictionary, prefix, expected):
    """split_by_criteria called with only ``prefix`` returns the expected split."""
    actual = mapping.split_by_criteria(dictionary, prefix=prefix)
    assert actual == expected
Example #5
0
def test_split_by_criteria_with_keys(dictionary, keys, expected):
    """split_by_criteria called with only ``keys`` returns the expected split."""
    actual = mapping.split_by_criteria(dictionary, keys=keys)
    assert actual == expected