def test_validating_dataset_input_tensors_with_dtype_mismatch(
    self, distribution):
  """Validation raises ValueError when per-replica tensor dtypes differ."""
  with self.cached_session():

    @tf.function
    def make_replica_value():
      # Return int32 on the first GPU replica and float64 everywhere else,
      # so the resulting PerReplica holds tensors with mismatched dtypes.
      replica_ctx = tf.distribute.get_replica_context()
      if replica_ctx.replica_id_in_sync_group.device.endswith("GPU:0"):
        return tf.constant([[1, 2]], dtype=tf.int32)
      return tf.constant([[1, 2]], dtype=tf.float64)

    per_replica_x = distribution.run(make_replica_value)

    # The device -> dtype pairing is nondeterministic across runs, so the
    # expected message deliberately omits device and dtype details.
    expected_message = (
        "Input tensor dtypes do not match for "
        "distributed tensor inputs "
        "PerReplica:.+")
    with self.assertRaisesRegex(ValueError, expected_message):
      with distribution.scope():
        distributed_training_utils_v1.validate_distributed_dataset_inputs(
            distribution, per_replica_x, None)
def test_validating_dataset_input_tensors_with_dtype_mismatch(
    self, distribution):
  """Validation raises ValueError when DistributedValues dtypes differ."""
  with self.cached_session():
    int_tensor = tf.constant([1, 2], shape=(1, 2), dtype=tf.int32)
    float_tensor = tf.constant([1, 2], shape=(1, 2), dtype=tf.float64)
    # x mixes int32 and float64 across components; y is uniformly int32.
    mismatched_x = tf.distribute.DistributedValues((int_tensor, float_tensor))
    matched_y = tf.distribute.DistributedValues((int_tensor, int_tensor))

    # The device -> dtype pairing is nondeterministic across runs, so the
    # expected message deliberately omits device and dtype details.
    expected_message = (
        'Input tensor dtypes do not match for '
        'distributed tensor inputs '
        'DistributedValues:.+')
    with self.assertRaisesRegex(ValueError, expected_message):
      with distribution.scope():
        distributed_training_utils_v1.validate_distributed_dataset_inputs(
            distribution, mismatched_x, matched_y)