def test_validating_dataset_input_tensors_with_dtype_mismatch(self):
  """Input validation must raise when per-device tensors differ in dtype.

  Builds one `DistributedValues` whose CPU entry is int32 and whose GPU entry
  is float64, plus a second one that is int32 on both devices, and expects
  `validate_distributed_dataset_inputs` to raise `ValueError`.
  """
  with self.cached_session():
    mirrored = mirrored_strategy.MirroredStrategy(
        ['/device:GPU:0', '/device:CPU:0'])
    int_tensor = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
    double_tensor = constant_op.constant(
        [1, 2], shape=(1, 2), dtype=dtypes.float64)
    mismatched = values.DistributedValues({
        '/device:CPU:0': int_tensor,
        '/device:GPU:0': double_tensor,
    })
    consistent = values.DistributedValues({
        '/device:CPU:0': int_tensor,
        '/device:GPU:0': int_tensor,
    })
    # The pattern intentionally stops before the device/dtype details: the
    # order in which devices and their dtypes show up in the message is not
    # deterministic from one run to the next.
    expected_error = ('Input tensor dtypes do not match for '
                      'distributed tensor inputs '
                      'DistributedValues:.+')
    with mirrored.scope(), self.assertRaisesRegexp(ValueError, expected_error):
      distributed_training_utils.validate_distributed_dataset_inputs(
          mirrored, mismatched, consistent)
def test_validating_dataset_input_tensors_with_dtype_mismatch(self,
                                                              distribution):
  """Input validation must raise when per-device tensors differ in dtype.

  Args:
    distribution: the strategy object supplied to the test; it provides the
      scope the check runs under and is passed to the validator.
  """
  with self.cached_session():
    int_tensor = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
    double_tensor = constant_op.constant(
        [1, 2], shape=(1, 2), dtype=dtypes.float64)
    # CPU and GPU entries disagree (int32 vs float64) in the first value only.
    mismatched = values.DistributedValues({
        '/device:CPU:0': int_tensor,
        '/device:GPU:0': double_tensor,
    })
    consistent = values.DistributedValues({
        '/device:CPU:0': int_tensor,
        '/device:GPU:0': int_tensor,
    })
    # The pattern intentionally stops before the device/dtype details: the
    # order in which devices and their dtypes show up in the message is not
    # deterministic from one run to the next.
    expected_error = ('Input tensor dtypes do not match for '
                      'distributed tensor inputs '
                      'DistributedValues:.+')
    with distribution.scope(), self.assertRaisesRegexp(ValueError,
                                                       expected_error):
      distributed_training_utils.validate_distributed_dataset_inputs(
          distribution, mismatched, consistent)
def test_validating_dataset_input_tensors_with_shape_mismatch(self):
  """Input validation must raise when per-device tensors differ in shape.

  Builds one `DistributedValues` whose CPU entry has shape (1, 2) and whose
  GPU entry has shape (2, 2), plus a second one that is (1, 2) on both
  devices, and expects `validate_distributed_dataset_inputs` to raise
  `ValueError`.
  """
  # `cached_session` replaces the deprecated `test_session`, matching the
  # sibling dtype-mismatch tests.
  with self.cached_session():
    strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
                                                   '/device:CPU:0'])
    a = constant_op.constant([1, 2], shape=(1, 2))
    b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2))
    x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b})
    y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a})
    with strategy.scope():
      # Device and input tensor shape details are left out of the expected
      # message because the order of the devices and their corresponding
      # shapes is not deterministic over different runs.
      with self.assertRaisesRegexp(ValueError,
                                   'Input tensor shapes do not match for '
                                   'distributed tensor inputs '
                                   'DistributedValues:.+'):
        distributed_training_utils.validate_distributed_dataset_inputs(
            strategy, x, y)
def test_validating_dataset_input_tensors_with_dtype_mismatch(self,
                                                              distribution):
  """Input validation must raise when per-device tensors differ in dtype.

  Args:
    distribution: the strategy object supplied to the test; it provides the
      scope the check runs under and is passed to the validator.
  """
  with self.cached_session():
    int_tensor = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
    double_tensor = constant_op.constant(
        [1, 2], shape=(1, 2), dtype=dtypes.float64)
    devices = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0'))
    # First value mixes int32 and float64; second is int32 in both slots.
    mismatched = values.DistributedValues(devices, (int_tensor, double_tensor))
    consistent = values.DistributedValues(devices, (int_tensor, int_tensor))
    # The pattern intentionally stops before the device/dtype details: the
    # order in which devices and their dtypes show up in the message is not
    # deterministic from one run to the next.
    expected_error = ('Input tensor dtypes do not match for '
                      'distributed tensor inputs '
                      'DistributedValues:.+')
    # The assertion context is entered first so it also wraps entering and
    # leaving the strategy scope.
    with self.assertRaisesRegexp(ValueError,
                                 expected_error), distribution.scope():
      distributed_training_utils.validate_distributed_dataset_inputs(
          distribution, mismatched, consistent)
def test_validating_dataset_input_tensors_with_shape_mismatch(
    self, distribution):
  """Input validation must raise when per-device tensors differ in shape.

  Args:
    distribution: the strategy object supplied to the test; it provides the
      scope the check runs under and is passed to the validator.
  """
  with self.cached_session():
    row_tensor = constant_op.constant([1, 2], shape=(1, 2))
    square_tensor = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2))
    devices = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0'))
    # First value mixes shapes (1, 2) and (2, 2); second is (1, 2) in both
    # slots.
    mismatched = values.DistributedValues(devices, (row_tensor, square_tensor))
    consistent = values.DistributedValues(devices, (row_tensor, row_tensor))
    # The pattern intentionally stops before the device/shape details: the
    # order in which devices and their shapes show up in the message is not
    # deterministic from one run to the next.
    expected_error = ('Input tensor shapes do not match for '
                      'distributed tensor inputs '
                      'DistributedValues:.+')
    # The assertion context is entered first so it also wraps entering and
    # leaving the strategy scope.
    with self.assertRaisesRegexp(ValueError,
                                 expected_error), distribution.scope():
      distributed_training_utils.validate_distributed_dataset_inputs(
          distribution, mismatched, consistent)
def _get_input_from_iterator(iterator, model):
  """Fetch the next element from `iterator` and split it for `model`.

  The element is interpreted by comparing its flattened length against the
  model's input/output counts:
    * equal to `len(model.inputs)`: the element is the inputs only;
    * equal to `len(model.inputs) + len(model.outputs)`: an
      `(inputs, targets)` pair;
    * anything else: an `(inputs, targets, sample_weights)` triple.

  Args:
    iterator: object exposing `get_next()`.
    model: object exposing `inputs`, `outputs` and `_distribution_strategy`.

  Returns:
    A tuple `(x, y, sample_weights)`; `y` and `sample_weights` are `None`
    when the element did not carry them.
  """
  element = iterator.get_next()
  flat_count = len(nest.flatten(element))

  # Targets and weights are optional; start from "absent" and fill in
  # whatever the element actually provides.
  targets = None
  weights = None
  if flat_count == len(model.inputs):
    features = element
  elif flat_count == len(model.inputs) + len(model.outputs):
    features, targets = element
  else:
    features, targets, weights = element

  # Check the distributed inputs before handing them back; per the original
  # note, this confirms that the elements of x and y agree in type and shape,
  # so later code can rely on the first element of each.
  distributed_training_utils.validate_distributed_dataset_inputs(
      model._distribution_strategy, features, targets, weights)
  return features, targets, weights