Ejemplo n.º 1
0
    def test_validating_dataset_input_tensors_with_dtype_mismatch(
            self, distribution):
        """Per-replica inputs that disagree on dtype must be rejected.

        The replica function yields an int32 constant on the replica whose
        replica-id tensor is placed on a device ending in "GPU:0" and a
        float64 constant on every other replica. The per-replica result of
        `distribution.run` is then fed to `validate_distributed_dataset_inputs`,
        which is expected to raise a ValueError about mismatched dtypes.
        """
        with self.cached_session():

            @tf.function
            def replica_fn():
                replica_ctx = tf.distribute.get_replica_context()
                replica_device = replica_ctx.replica_id_in_sync_group.device
                on_first_gpu = replica_device.endswith("GPU:0")
                dtype = tf.int32 if on_first_gpu else tf.float64
                return tf.constant([[1, 2]], dtype=dtype)

            per_replica_inputs = distribution.run(replica_fn)

            # Only the "PerReplica:" prefix is matched: the order in which the
            # replicas (and therefore their devices and dtypes) appear in the
            # message differs from run to run.
            expected_message = ("Input tensor dtypes do not match for "
                                "distributed tensor inputs "
                                "PerReplica:.+")
            with self.assertRaisesRegex(ValueError, expected_message):
                with distribution.scope():
                    distributed_training_utils_v1.validate_distributed_dataset_inputs(
                        distribution, per_replica_inputs, None)
Ejemplo n.º 2
0
 def test_validating_dataset_input_tensors_with_dtype_mismatch(
         self, distribution):
     """Distributed inputs whose per-replica dtypes differ must be rejected.

     `x` carries an int32 tensor on replica 0 and a float64 tensor on every
     other replica, while `y` carries the int32 tensor on all replicas.
     `validate_distributed_dataset_inputs` is expected to raise a ValueError
     reporting the dtype mismatch.

     NOTE(review): the mismatch only exists when `distribution` has at least
     two replicas in sync -- confirm against the test's strategy combinations.
     """
     with self.cached_session():
         a = tf.constant([1, 2], shape=(1, 2), dtype=tf.int32)
         b = tf.constant([1, 2], shape=(1, 2), dtype=tf.float64)
         # `tf.distribute.DistributedValues` is the public abstract base type
         # (documented as never to be instantiated directly) and does not
         # accept a tuple of values, so build concrete per-replica values
         # through the strategy instead.
         x = distribution.experimental_distribute_values_from_function(
             lambda ctx: a if ctx.replica_id_in_sync_group == 0 else b)
         y = distribution.experimental_distribute_values_from_function(
             lambda ctx: a)
         # Removed device and input tensor dtype details from the error message
         # since the order of the device and the corresponding input tensor dtype
         # is not deterministic over different runs.
         with self.assertRaisesRegex(
                 ValueError, 'Input tensor dtypes do not match for '
                 'distributed tensor inputs '
                 'PerReplica:.+'):
             with distribution.scope():
                 distributed_training_utils_v1.validate_distributed_dataset_inputs(
                     distribution, x, y)