def _get_input_from_iterator(iterator, with_batch_index=False):
  """Get elements from the iterator and verify the input shape and type.

  Args:
    iterator: Iterator whose next element is either a single input `x`
      (a tensor, a dict, or a CompositeTensor) or a sequence of the form
      `(x,)`, `(x, y)` or `(x, y, sample_weights)`.
    with_batch_index: If True, each element produced by `iterator` is
      unpacked as a `(batch_index, element)` pair first.

  Returns:
    A tuple `(x, y, sample_weights, batch_index)`. `y` and `sample_weights`
    are None when the element does not provide them; `batch_index` is None
    when `with_batch_index` is False.

  Raises:
    StopIteration: Propagated from `next(iterator)` when it is exhausted.
    ValueError: If the element does not contain 1, 2 or 3 components.
      Errors from `dist_utils.validate_distributed_dataset_inputs` (presumably
      a ValueError on mismatched inputs) are propagated unchanged.
  """
  next_element = next(iterator)
  if with_batch_index:
    batch_index, next_element = next_element
  else:
    batch_index = None

  # A bare tensor/dict/CompositeTensor is a lone `x`; wrap it so the
  # length-based unpacking below treats it as a 1-component element.
  if (tensor_util.is_tensor(next_element) or
      isinstance(next_element, (dict, composite_tensor.CompositeTensor))):
    next_element = [next_element]

  num_components = len(next_element)
  if num_components == 1:
    x, = next_element
    y = None
    sample_weights = None
  elif num_components == 2:
    x, y = next_element
    sample_weights = None
  elif num_components == 3:
    x, y, sample_weights = next_element
  else:
    # Report the offending length instead of a bare tuple-unpacking error.
    raise ValueError(
        'Expected the input element to contain 1, 2 or 3 components '
        '(x, y, sample_weights), but got %d.' % num_components)

  # Validate that all the elements in x and y are of the same type and shape.
  dist_utils.validate_distributed_dataset_inputs(
      distribution_strategy_context.get_strategy(), x, y, sample_weights)
  return x, y, sample_weights, batch_index
def test_validating_dataset_input_tensors_with_dtype_mismatch(
    self, distribution):
  """Per-replica components of `x` with differing dtypes raise ValueError.

  `x` pairs an int32 tensor with a float64 tensor, while `y` repeats the
  int32 tensor, so only `x` carries the dtype mismatch.
  """
  with self.cached_session():
    int_tensor = constant_op.constant(
        [1, 2], shape=(1, 2), dtype=dtypes.int32)
    float_tensor = constant_op.constant(
        [1, 2], shape=(1, 2), dtype=dtypes.float64)
    x = values.DistributedValues((int_tensor, float_tensor))
    y = values.DistributedValues((int_tensor, int_tensor))
    # Only the message prefix is matched: which device is reported next to
    # which input dtype in the error text varies from run to run.
    expected_error = ('Input tensor dtypes do not match for '
                      'distributed tensor inputs '
                      'DistributedValues:.+')
    with self.assertRaisesRegexp(ValueError, expected_error), \
        distribution.scope():
      distributed_training_utils.validate_distributed_dataset_inputs(
          distribution, x, y)
def test_validating_dataset_input_tensors_with_dtype_mismatch(
    self, distribution):
  """Per-replica components of `x` with differing dtypes raise ValueError.

  Both inputs are placed on a CPU/GPU replica device map. `x` pairs an int32
  tensor with a float64 tensor, while `y` repeats the int32 tensor, so only
  `x` carries the dtype mismatch.
  """
  with self.cached_session():
    int_tensor = constant_op.constant(
        [1, 2], shape=(1, 2), dtype=dtypes.int32)
    float_tensor = constant_op.constant(
        [1, 2], shape=(1, 2), dtype=dtypes.float64)
    replica_devices = values.ReplicaDeviceMap(
        ('/device:CPU:0', '/device:GPU:0'))
    x = values.DistributedValues(replica_devices, (int_tensor, float_tensor))
    y = values.DistributedValues(replica_devices, (int_tensor, int_tensor))
    # Only the message prefix is matched: which device is reported next to
    # which input dtype in the error text varies from run to run.
    expected_error = ('Input tensor dtypes do not match for '
                      'distributed tensor inputs '
                      'DistributedValues:.+')
    with self.assertRaisesRegexp(ValueError, expected_error), \
        distribution.scope():
      distributed_training_utils.validate_distributed_dataset_inputs(
          distribution, x, y)