コード例 #1
0
def _get_input_from_iterator(iterator, with_batch_index=False):
    """Pull the next item from `iterator` and split it into model inputs.

    Args:
        iterator: Iterator whose items are either a single input (a tensor,
            a dict or a composite tensor) or a sequence shaped like `(x,)`,
            `(x, y)` or `(x, y, sample_weights)`.
        with_batch_index: When true, every item is expected to be a
            `(batch_index, element)` pair and the index is split off first.

    Returns:
        A `(x, y, sample_weights, batch_index)` tuple. Parts that the element
        does not provide, and `batch_index` when `with_batch_index` is false,
        are `None`.
    """
    element = next(iterator)
    batch_index = None
    if with_batch_index:
        batch_index, element = element

    # A lone tensor, dict or composite tensor is the features by itself, so
    # wrap it in a list to share the sequence handling below.
    is_single_input = tensor_util.is_tensor(element) or isinstance(
        element, (dict, composite_tensor.CompositeTensor))
    if is_single_input:
        element = [element]

    y = None
    sample_weights = None
    num_parts = len(element)
    if num_parts == 1:
        (x,) = element
    elif num_parts == 2:
        x, y = element
    else:
        # Any other length must unpack into exactly three parts.
        x, y, sample_weights = element

    # Run the distributed-input validation against the current strategy
    # before handing the pieces back to the caller.
    strategy = distribution_strategy_context.get_strategy()
    dist_utils.validate_distributed_dataset_inputs(
        strategy, x, y, sample_weights)
    return x, y, sample_weights, batch_index
コード例 #2
0
 def test_validating_dataset_input_tensors_with_dtype_mismatch(
     self, distribution):
   # Two tensors with the same shape but different dtypes are wrapped in one
   # DistributedValues; validating them is expected to raise ValueError.
   #
   # The expected pattern omits the device and dtype details because the
   # order in which devices and their tensor dtypes are reported is not
   # deterministic across runs.
   expected_error = ('Input tensor dtypes do not match for '
                     'distributed tensor inputs '
                     'DistributedValues:.+')
   with self.cached_session():
     int_tensor = constant_op.constant(
         [1, 2], shape=(1, 2), dtype=dtypes.int32)
     float_tensor = constant_op.constant(
         [1, 2], shape=(1, 2), dtype=dtypes.float64)
     mismatched_x = values.DistributedValues((int_tensor, float_tensor))
     matching_y = values.DistributedValues((int_tensor, int_tensor))
     with self.assertRaisesRegexp(
         ValueError, expected_error), distribution.scope():
       distributed_training_utils.validate_distributed_dataset_inputs(
           distribution, mismatched_x, matching_y)
コード例 #3
0
 def test_validating_dataset_input_tensors_with_dtype_mismatch(
     self, distribution):
   # Two tensors with the same shape but different dtypes are placed on a
   # CPU/GPU replica device map inside one DistributedValues; validating
   # them is expected to raise ValueError.
   #
   # The expected pattern omits the device and dtype details because the
   # order in which devices and their tensor dtypes are reported is not
   # deterministic across runs.
   expected_error = ('Input tensor dtypes do not match for '
                     'distributed tensor inputs '
                     'DistributedValues:.+')
   with self.cached_session():
     int_tensor = constant_op.constant(
         [1, 2], shape=(1, 2), dtype=dtypes.int32)
     float_tensor = constant_op.constant(
         [1, 2], shape=(1, 2), dtype=dtypes.float64)
     replica_devices = values.ReplicaDeviceMap(
         ('/device:CPU:0', '/device:GPU:0'))
     mismatched_x = values.DistributedValues(
         replica_devices, (int_tensor, float_tensor))
     matching_y = values.DistributedValues(
         replica_devices, (int_tensor, int_tensor))
     with self.assertRaisesRegexp(
         ValueError, expected_error), distribution.scope():
       distributed_training_utils.validate_distributed_dataset_inputs(
           distribution, mismatched_x, matching_y)