Example #1
0
 def _evaluate_iterations_per_loop_in_seconds(self, value, expected_value,
                                              expected_unit):
     """Round-trips `value` through a RunConfig and checks how it parses.

     Builds a RunConfig whose TPUConfig carries `value` as
     iterations_per_loop, confirms the config stored it unchanged, then
     verifies util_lib.parse_iterations_per_loop splits it into the
     expected (value, unit) pair.
     """
     tpu_settings = tpu_config_lib.TPUConfig(iterations_per_loop=value)
     run_config = tpu_config_lib.RunConfig(tpu_config=tpu_settings)
     self.assertEqual(run_config.tpu_config.iterations_per_loop, value)
     parsed = util_lib.parse_iterations_per_loop(
         run_config.tpu_config.iterations_per_loop)
     self.assertEqual(expected_value, parsed.value)
     self.assertEqual(expected_unit, parsed.unit)
Example #2
0
    def __new__(
            cls,
            iterations_per_loop=2,
            num_shards=None,
            num_cores_per_replica=None,
            per_host_input_for_training=True,
            tpu_job_name=None,
            initial_infeed_sleep_secs=None,
            input_partition_dims=None,
            eval_training_input_configuration=InputPipelineConfig.PER_HOST_V1,
            experimental_host_call_every_n_steps=1):
        """Validates all arguments, then builds the immutable TPUConfig.

        Each argument is checked in turn; the first invalid one raises
        ValueError (directly or via util_lib helpers). Legacy boolean values
        of `per_host_input_for_training` are mapped onto the corresponding
        InputPipelineConfig constants before construction.

        Raises:
          ValueError: if any argument fails validation.
        """
        # iterations_per_loop must be parseable (count or time-based value).
        util_lib.parse_iterations_per_loop(iterations_per_loop)

        if num_shards is not None:
            util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')

        if input_partition_dims is not None:
            # Partition dims only make sense with 1 or 2 entries, in
            # PER_HOST_V2 mode, and with an explicit core count per replica.
            if len(input_partition_dims) not in (1, 2):
                raise ValueError(
                    'input_partition_dims must be a list/tuple with one or two'
                    ' elements.')
            if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:
                raise ValueError(
                    'input_partition_dims is only supported in PER_HOST_V2 mode.'
                )
            if num_cores_per_replica is None:
                raise ValueError(
                    'input_partition_dims requires setting num_cores_per_replica.'
                )

        # num_cores_per_replica, when given, is restricted to powers of two.
        if (num_cores_per_replica is not None
                and num_cores_per_replica not in (1, 2, 4, 8, 16, 32, 64, 128)):
            raise ValueError(
                'num_cores_per_replica must be 1, 2, 4, 8, 16, 32, 64, 128; '
                'got {}'.format(str(num_cores_per_replica)))

        if eval_training_input_configuration not in (
                InputPipelineConfig.PER_HOST_V1, InputPipelineConfig.SLICED):
            raise ValueError(
                'eval_training_input_configuration must be PER_HOST_V1 or SLICED;'
                ' got {}'.format(str(eval_training_input_configuration)))

        # Legacy callers pass True/False; translate those to the enum values.
        if per_host_input_for_training is True:
            per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
        elif per_host_input_for_training is False:
            per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1

        # A falsy value (None or 0) deliberately skips the positivity check.
        if initial_infeed_sleep_secs:
            util_lib.check_positive_integer(
                initial_infeed_sleep_secs,
                'TPUConfig initial_infeed_sleep_secs')

        tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()

        return super(TPUConfig, cls).__new__(
            cls,
            iterations_per_loop=iterations_per_loop,
            num_shards=num_shards,
            num_cores_per_replica=num_cores_per_replica,
            per_host_input_for_training=per_host_input_for_training,
            tpu_job_name=tpu_job_name,
            initial_infeed_sleep_secs=initial_infeed_sleep_secs,
            input_partition_dims=input_partition_dims,
            eval_training_input_configuration=eval_training_input_configuration,
            experimental_host_call_every_n_steps=
            experimental_host_call_every_n_steps)
Example #3
0
 def _parse_and_validate_iterations_per_loop(self, value, expected_value,
                                             expected_unit):
     """Parses `value` and checks the resulting (value, unit) pair.

     Asserts that util_lib.parse_iterations_per_loop returns a truthy
     result whose value and unit fields match the expectations.
     """
     parsed = util_lib.parse_iterations_per_loop(value)
     self.assertTrue(parsed)
     self.assertEqual(parsed.value, expected_value)
     self.assertEqual(parsed.unit, expected_unit)