Пример #1
0
 def test_validating_dataset_input_tensors_with_dtype_mismatch(self):
     """Per-replica inputs whose dtypes differ must be rejected.

     `x` maps an int32 tensor to the CPU and a float64 tensor to the GPU,
     while `y` uses the same int32 tensor on both devices. Validating them
     inside the strategy scope is expected to raise a ValueError about a
     dtype mismatch.
     """
     with self.cached_session():
         strategy = mirrored_strategy.MirroredStrategy(
             ['/device:GPU:0', '/device:CPU:0'])
         a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
         b = constant_op.constant([1, 2],
                                  shape=(1, 2),
                                  dtype=dtypes.float64)
         x = values.DistributedValues({
             '/device:CPU:0': a,
             '/device:GPU:0': b
         })
         y = values.DistributedValues({
             '/device:CPU:0': a,
             '/device:GPU:0': a
         })
         with strategy.scope():
             # Only the stable prefix of the error message is matched: the
             # order in which devices and their tensor dtypes are listed in the
             # message is not deterministic across runs.
             # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
             # alias, which no longer exists in Python 3.12+.
             with self.assertRaisesRegex(
                     ValueError, 'Input tensor dtypes do not match for '
                     'distributed tensor inputs '
                     'DistributedValues:.+'):
                 distributed_training_utils.validate_distributed_dataset_inputs(
                     strategy, x, y)
Пример #2
0
 def test_validating_dataset_input_tensors_with_dtype_mismatch(
     self, distribution):
   """Per-replica inputs whose dtypes differ must be rejected.

   `x` maps an int32 tensor to the CPU and a float64 tensor to the GPU,
   while `y` uses the same int32 tensor on both devices. Validating them
   inside `distribution.scope()` is expected to raise a ValueError about a
   dtype mismatch.

   Args:
     distribution: The distribution strategy under test (supplied by the
       test parameterization).
   """
   with self.cached_session():
     a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
     b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64)
     x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b})
     y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a})
     with distribution.scope():
       # Only the stable prefix of the error message is matched: the order in
       # which devices and their tensor dtypes are listed in the message is
       # not deterministic across runs.
       # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
       # alias, which no longer exists in Python 3.12+.
       with self.assertRaisesRegex(ValueError,
                                   'Input tensor dtypes do not match for '
                                   'distributed tensor inputs '
                                   'DistributedValues:.+'):
         distributed_training_utils.validate_distributed_dataset_inputs(
             distribution, x, y)
Пример #3
0
 def test_validating_dataset_input_tensors_with_shape_mismatch(self):
   """Per-replica inputs whose shapes differ must be rejected.

   `x` maps a (1, 2) tensor to the CPU and a (2, 2) tensor to the GPU, while
   `y` uses the same (1, 2) tensor on both devices. Validating them inside
   the strategy scope is expected to raise a ValueError about a shape
   mismatch.
   """
   # `test_session` is deprecated in TensorFlow; `cached_session` is the
   # recommended replacement.
   with self.cached_session():
     strategy = mirrored_strategy.MirroredStrategy(
         ['/device:GPU:0', '/device:CPU:0'])
     a = constant_op.constant([1, 2], shape=(1, 2))
     b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2))
     x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b})
     y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a})
     with strategy.scope():
       # Only the stable prefix of the error message is matched: the order in
       # which devices and their tensor shapes are listed in the message is
       # not deterministic across runs.
       # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
       # alias, which no longer exists in Python 3.12+.
       with self.assertRaisesRegex(ValueError,
                                   'Input tensor shapes do not match for '
                                   'distributed tensor inputs '
                                   'DistributedValues:.+'):
         distributed_training_utils.validate_distributed_dataset_inputs(
             strategy, x, y)
Пример #4
0
 def test_validating_dataset_input_tensors_with_dtype_mismatch(
     self, distribution):
   """Per-replica inputs whose dtypes differ must be rejected.

   Using a `ReplicaDeviceMap` over the CPU and GPU, `x` holds an int32 and a
   float64 tensor, while `y` holds the same int32 tensor for both replicas.
   Validating them inside `distribution.scope()` is expected to raise a
   ValueError about a dtype mismatch.

   Args:
     distribution: The distribution strategy under test (supplied by the
       test parameterization).
   """
   with self.cached_session():
     a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
     b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64)
     device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0'))
     x = values.DistributedValues(device_map, (a, b))
     y = values.DistributedValues(device_map, (a, a))
     # Only the stable prefix of the error message is matched: the order in
     # which devices and their tensor dtypes are listed in the message is not
     # deterministic across runs.
     # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
     # which no longer exists in Python 3.12+.
     with self.assertRaisesRegex(ValueError,
                                 'Input tensor dtypes do not match for '
                                 'distributed tensor inputs '
                                 'DistributedValues:.+'):
       with distribution.scope():
         distributed_training_utils.validate_distributed_dataset_inputs(
             distribution, x, y)
Пример #5
0
 def test_validating_dataset_input_tensors_with_shape_mismatch(
     self, distribution):
   """Per-replica inputs whose shapes differ must be rejected.

   Using a `ReplicaDeviceMap` over the CPU and GPU, `x` holds a (1, 2) and a
   (2, 2) tensor, while `y` holds the same (1, 2) tensor for both replicas.
   Validating them inside `distribution.scope()` is expected to raise a
   ValueError about a shape mismatch.

   Args:
     distribution: The distribution strategy under test (supplied by the
       test parameterization).
   """
   with self.cached_session():
     a = constant_op.constant([1, 2], shape=(1, 2))
     b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2))
     device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0'))
     x = values.DistributedValues(device_map, (a, b))
     y = values.DistributedValues(device_map, (a, a))
     # Only the stable prefix of the error message is matched: the order in
     # which devices and their tensor shapes are listed in the message is not
     # deterministic across runs.
     # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
     # which no longer exists in Python 3.12+.
     with self.assertRaisesRegex(
         ValueError, 'Input tensor shapes do not match for '
         'distributed tensor inputs '
         'DistributedValues:.+'):
       with distribution.scope():
         distributed_training_utils.validate_distributed_dataset_inputs(
             distribution, x, y)
Пример #6
0
def _get_input_from_iterator(iterator, model):
    """Get elements from the iterator and verify the input shape and type.

    The next element is split into `(x, y, sample_weights)` according to how
    many flattened tensors it contains, compared with the model's inputs and
    outputs:

    * exactly `len(model.inputs)` tensors: the whole element is `x`;
    * `len(model.inputs) + len(model.outputs)` tensors: it is unpacked as
      `(x, y)`;
    * any other count: it is unpacked as `(x, y, sample_weights)`.

    Args:
      iterator: Object whose `get_next()` returns the next dataset element.
      model: Model exposing `inputs`, `outputs` and `_distribution_strategy`.

    Returns:
      A tuple `(x, y, sample_weights)`. `y` and `sample_weights` are `None`
      when the element does not carry them.

    Raises:
      ValueError: presumably from `validate_distributed_dataset_inputs` when
        per-replica inputs disagree (TODO confirm), or from tuple unpacking
        when the element does not have the expected number of components.
    """
    next_element = iterator.get_next()
    # Flatten once: the leaf count alone decides how the element is unpacked.
    num_flat_elements = len(nest.flatten(next_element))

    if num_flat_elements == len(model.inputs):
        x = next_element
        y = None
        sample_weights = None
    elif num_flat_elements == len(model.inputs) + len(model.outputs):
        x, y = next_element
        sample_weights = None
    else:
        # NOTE(review): assumes the element is a 3-tuple here; any other
        # structure fails with Python's generic unpacking error.
        x, y, sample_weights = next_element

    # Validate that all the elements in x and y are of the same type and shape.
    # We can then pass the first element of x and y to `_standardize_weights`
    # below and be confident of the output.
    distributed_training_utils.validate_distributed_dataset_inputs(
        model._distribution_strategy, x, y, sample_weights)
    return x, y, sample_weights