def test_strategy(self, strategy):
    """Checks MULTI_HOT `CategoryEncoding` output under `strategy`.

    Builds a functional model around a `CategoryEncoding` layer inside
    `strategy.scope()`, runs `model.predict` on a two-row dataset batched by
    `batch_wrapper`, and compares against hand-computed multi-hot rows.

    Args:
      strategy: The distribution strategy under test (parameterized).
    """
    if (
        backend.is_tpu_strategy(strategy)
        and not tf_test_utils.is_mlir_bridge_enabled()
    ):
        self.skipTest("TPU tests require MLIR bridge")

    input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
    inp_dataset = tf.data.Dataset.from_tensor_slices(input_array)
    inp_dataset = batch_wrapper(inp_dataset, 2, strategy)

    # Row 0 holds tokens {1, 2, 3}; row 1 holds tokens {0, 1, 3}.
    # pyformat: disable
    expected_output = [[0, 1, 1, 1, 0, 0],
                       [1, 1, 0, 1, 0, 0]]
    # pyformat: enable
    num_tokens = 6

    # Soft device placement is a process-wide setting. Restore the previous
    # value when this test ends, even on failure, so it does not leak into
    # other tests.
    self.addCleanup(
        tf.config.set_soft_device_placement,
        tf.config.get_soft_device_placement(),
    )
    tf.config.set_soft_device_placement(True)

    with strategy.scope():
        input_data = keras.Input(shape=(4,), dtype=tf.int32)
        layer = category_encoding.CategoryEncoding(
            num_tokens=num_tokens, output_mode=category_encoding.MULTI_HOT
        )
        int_data = layer(input_data)
        model = keras.Model(inputs=input_data, outputs=int_data)
    output_dataset = model.predict(inp_dataset)
    self.assertAllEqual(expected_output, output_dataset)
def test_strategy_with_file(self, strategy):
    """Checks `IndexLookup` with a file-backed vocabulary under `strategy`.

    The vocabulary file holds four tokens; with mask token "" and one OOV
    index, the tokens map to 2..5 and the unknown "michigan" maps to the OOV
    index 1.

    Args:
      strategy: The distribution strategy under test (parameterized).
    """
    if (
        backend.is_tpu_strategy(strategy)
        and not tf_test_utils.is_mlir_bridge_enabled()
    ):
        self.skipTest("TPU tests require MLIR bridge")

    vocab_data = ["earth", "wind", "and", "fire"]
    vocab_file = self._write_to_temp_file("temp", vocab_data)

    input_array = np.array(
        [
            ["earth", "wind", "and", "fire"],
            ["fire", "and", "earth", "michigan"],
        ]
    )
    input_dataset = tf.data.Dataset.from_tensor_slices(input_array).batch(
        2, drop_remainder=True
    )
    expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]

    # Soft device placement is a process-wide setting. Restore the previous
    # value when this test ends, even on failure, so it does not leak into
    # other tests.
    self.addCleanup(
        tf.config.set_soft_device_placement,
        tf.config.get_soft_device_placement(),
    )
    tf.config.set_soft_device_placement(True)

    with strategy.scope():
        input_data = keras.Input(shape=(None,), dtype=tf.string)
        layer = index_lookup.IndexLookup(
            max_tokens=None,
            num_oov_indices=1,
            mask_token="",
            oov_token="[OOV]",
            vocabulary_dtype=tf.string,
            vocabulary=vocab_file,
        )
        int_data = layer(input_data)
        model = keras.Model(inputs=input_data, outputs=int_data)
        model.compile(loss="mse")
    output_dataset = model.predict(input_dataset)
    self.assertAllEqual(expected_output, output_dataset)
def test_distribution_strategy_output(self, strategy):
    """Checks INT-mode `TextVectorization` output under `strategy`.

    With no standardization or splitting and an in-memory vocabulary, each
    known token maps to 2..5 and the unknown "michigan" maps to 1.

    Args:
      strategy: The distribution strategy under test (parameterized).
    """
    if (
        backend.is_tpu_strategy(strategy)
        and not tf_test_utils.is_mlir_bridge_enabled()
    ):
        self.skipTest("TPU tests require MLIR bridge")

    vocab_data = ["earth", "wind", "and", "fire"]
    input_array = np.array(
        [
            ["earth", "wind", "and", "fire"],
            ["fire", "and", "earth", "michigan"],
        ]
    )
    input_dataset = tf.data.Dataset.from_tensor_slices(input_array).batch(
        2, drop_remainder=True
    )
    expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]

    # Soft device placement is a process-wide setting. Restore the previous
    # value when this test ends, even on failure, so it does not leak into
    # other tests.
    self.addCleanup(
        tf.config.set_soft_device_placement,
        tf.config.get_soft_device_placement(),
    )
    tf.config.set_soft_device_placement(True)

    with strategy.scope():
        input_data = keras.Input(shape=(None,), dtype=tf.string)
        layer = text_vectorization.TextVectorization(
            max_tokens=None,
            standardize=None,
            split=None,
            output_mode=text_vectorization.INT,
            vocabulary=vocab_data,
        )
        int_data = layer(input_data)
        model = keras.Model(inputs=input_data, outputs=int_data)
    output_dataset = model.predict(input_dataset)
    self.assertAllEqual(expected_output, output_dataset)
def testV2SummaryWithKerasSubclassedModel(self):
  """Counts TensorBoard summary events written by a subclassed model on TPU.

  Trains `CustomModel` for 10 steps with `update_freq=2` and checks the
  number of scalar and histogram summary events in the train event files.
  """
  # Histogram summaries require the MLIR bridge; see b/178826597#comment107.
  # TODO(https://github.com/tensorflow/tensorboard/issues/2885): remove this
  # if histogram summaries are supported fully on non-MLIR bridge or
  # non-MLIR bridge is no longer run.
  enable_histograms = test_util.is_mlir_bridge_enabled()
  strategy = get_tpu_strategy()

  with strategy.scope():
    model = CustomModel(enable_histograms=enable_histograms)
    model.compile('sgd', 'mse')

    tensorboard_callback = callbacks.TensorBoard(
        self.summary_dir, update_freq=2)
    model.fit(
        distribute_strategy_test.get_dataset(strategy),
        steps_per_epoch=10,
        epochs=1,
        callbacks=[tensorboard_callback])

    train_event_pattern = os.path.join(self.summary_dir, 'train', 'event*')
    event_files = tf.io.gfile.glob(train_event_pattern)

    # 10 steps with summaries written every 2 batches -> 5 events per summary
    # (histograms only when they are enabled).
    if enable_histograms:
      histogram_events = 5
    else:
      histogram_events = 0
    expected_event_counts = {
        ('custom_model/layer_for_scalar_summary/'
         'custom_scalar_summary_v2'): 5,
        ('custom_model/layer_for_histogram_summary/'
         'custom_histogram_summary_v2'): histogram_events,
    }
    self.validate_recorded_sumary_file(event_files, expected_event_counts)
def testV2SummaryWithKerasSequentialModel(self):
    """Counts TensorBoard summary events written by a Sequential model on TPU.

    Trains an MNIST-style Sequential model for 10 steps with `update_freq=2`
    and checks the number of histogram and image summary events.
    """
    # Histogram summaries require the MLIR bridge; see b/178826597#comment107.
    # TODO(https://github.com/tensorflow/tensorboard/issues/2885): remove this
    # if histogram summaries are supported fully on non-MLIR bridge or
    # non-MLIR bridge is no longer run.
    enable_histograms = tf_test_utils.is_mlir_bridge_enabled()
    strategy = get_tpu_strategy()

    with strategy.scope():
        model = mnist_model((28, 28, 3), enable_histograms=enable_histograms)
        model.compile("sgd", "mse")

        tensorboard_callback = callbacks.TensorBoard(
            self.summary_dir, update_freq=2
        )
        model.fit(
            get_image_dataset(),
            steps_per_epoch=10,
            epochs=1,
            callbacks=[tensorboard_callback],
        )

        train_event_pattern = os.path.join(self.summary_dir, "train", "event*")
        event_files = tf.io.gfile.glob(train_event_pattern)

        # 10 steps with summaries written every 2 batches -> 5 events per
        # summary (histograms only when they are enabled).
        histogram_events = 5 if enable_histograms else 0
        expected_event_counts = {
            "sequential/layer_for_histogram_summary/custom_histogram_summary_v2": (
                histogram_events
            ),
            "sequential/layer_for_image_summary/custom_image_summary_v2": 5,
        }
        self.validate_recorded_sumary_file(event_files, expected_event_counts)
def test_spmd_with_summary(self):
  """Runs a summary-writing step ten times under an SPMD TPU strategy.

  Each run writes a scalar summary and increments a strategy-scoped int64
  counter. The test then checks that every per-replica component of the
  counter equals 10.
  """
  if test_util.is_mlir_bridge_enabled():
    self.skipTest("TODO(b/232580663): fix MLIR bridge")

  # Soft device placement is process-global. Register the restore as a
  # cleanup so the previous value comes back even if an assertion fails.
  self.addCleanup(config.set_soft_device_placement,
                  config.get_soft_device_placement())
  config.set_soft_device_placement(True)

  strategy, _ = get_tpu_strategy(enable_spmd=True)
  summary_dir = self.get_temp_dir()
  writer = summary_ops.create_file_writer_v2(summary_dir)

  with strategy.scope():
    step = variables.Variable(0, dtype=dtypes.int64)

  @def_function.function
  def run():
    with writer.as_default():
      summary_ops.scalar("result", step * 2, step=step)
      step.assign_add(1)

  for _ in range(10):
    strategy.run(run, args=())

  for val in step.values:
    for var in val.variables:
      self.assertAllEqual(10, var)
def test_paritioned_model_checkpointing(self):
  """Round-trips a model split across two logical devices via a checkpoint.

  The test saves the initial weights (v=2, w=3) and overwrites them with
  (v=1, w=4). It then restores the checkpoint and checks that the forward
  pass reflects each set of weights in turn.
  """
  if test_util.is_mlir_bridge_enabled():
    self.skipTest("TODO(b/238811067): fix MLIR bridge")

  class PartitionedModel(module.Module):
    """Computes w * (v * x) with v on logical device 0 and w on device 1."""

    def __init__(self, v, w):
      super().__init__()
      assert distribution_strategy_context.has_strategy()
      strategy = distribution_strategy_context.get_strategy()
      with strategy.extended.experimental_logical_device(0):
        self.v = variables.Variable(v)
      with strategy.extended.experimental_logical_device(1):
        self.w = variables.Variable(w)

    def __call__(self, x):
      replica_ctx = distribution_strategy_context.get_replica_context()
      with replica_ctx.experimental_logical_device(0):
        y = self.v * x
      with replica_ctx.experimental_logical_device(1):
        z = self.w * y
      return z

    def change_weights_op(self, v_new, w_new):
      return control_flow_ops.group(
          [self.v.assign(v_new), self.w.assign(w_new)])

  strategy, num_replicas = get_tpu_strategy()
  with strategy.scope():
    model = PartitionedModel(2., 3.)

  checkpoint_dir = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
  checkpoint = util.Checkpoint(model=model)

  def forward_sum():
    # Runs the model on x=5 in every replica and sums the replica outputs.
    per_replica = strategy.run(def_function.function(model), args=(5.0,))
    return self.evaluate(strategy.reduce("SUM", per_replica, axis=None))

  with self.cached_session() as sess:
    self.evaluate(variables.global_variables_initializer())
    checkpoint.save(file_prefix=checkpoint_prefix)

    # Overwritten weights: 1 * 4 * 5 = 20 per replica.
    self.evaluate(model.change_weights_op(1., 4.))
    self.assertEqual(20. * num_replicas, forward_sum())

    status = checkpoint.restore(
        checkpoint_management.latest_checkpoint(checkpoint_dir))
    status.run_restore_ops(sess)  # must run restore op in non-eager mode.
    status.assert_consumed()
    status.assert_existing_objects_matched()

    # Restored weights: 2 * 3 * 5 = 30 per replica.
    self.assertEqual(30. * num_replicas, forward_sum())
def testUnsupportedOps(self):
  """jit_compile on string_format/string_length must raise an error."""
  # The two bridges report the missing kernels with different wording.
  if test_util.is_mlir_bridge_enabled():
    expected_error = 'legalization failed'
  else:
    expected_error = 'unsupported operations'

  with ops.device('device:{}:0'.format(self.device)):

    def format_then_measure(x):
      formatted = string_ops.string_format('{}', x)
      return string_ops.string_length(formatted)

    compiled_fn = def_function.function(format_then_measure, jit_compile=True)
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      compiled_fn(constant_op.constant([3.1, 3.2]))
def testMethodCompilationUnsupportedFunc(self):
  """A jit-compiled method using string ops must raise an error."""
  # The two bridges report the missing kernels with different wording.
  if test_util.is_mlir_bridge_enabled():
    expected_error = 'legalization failed'
  else:
    expected_error = 'unsupported operations'

  with ops.device('device:{}:0'.format(self.device)):

    class StringHolder(object):

      @def_function.function(jit_compile=True)
      def f1(self, x):
        formatted = string_ops.string_format('{}', x)
        return string_ops.string_length(formatted)

    holder = StringHolder()
    values = constant_op.constant([1, 2, 2, 3, 3])
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      holder.f1(values)
def _testReduction(self,
                   tf_reduce_fn,
                   np_reduce_fn,
                   dtype,
                   test_inputs,
                   index_dtype,
                   rtol=1e-4,
                   atol=1e-4):
  """Tests that the output of 'tf_reduce_fn' matches numpy's output.

  Args:
    tf_reduce_fn: Callable invoked as `tf_reduce_fn(data, axes)`, where both
      arguments are placeholders.
    np_reduce_fn: Numpy reference invoked as `np_reduce_fn(x, axis=...)`.
    dtype: dtype of the data placeholder.
    test_inputs: Iterable of arrays fed to the data placeholder.
    index_dtype: dtype of the reduction-axes placeholder.
    rtol: Relative tolerance for the comparison.
    atol: Absolute tolerance for the comparison.
  """
  # (axis fed to the graph, equivalent numpy axis). Treating -1 as numpy
  # axis 1 (and 2 as out of range below) presumes rank-2 inputs.
  axis_pairs = ((0, 0), (1, 1), (-1, 1))
  invalid_axes = (-33, 2)

  for test_input in test_inputs:
    with self.session() as sess:
      with self.test_scope():
        a = array_ops.placeholder(dtype)
        index = array_ops.placeholder(index_dtype)
        out = tf_reduce_fn(a, index)

      for tf_axis, np_axis in axis_pairs:
        result = sess.run(out, {a: test_input, index: [tf_axis]})
        self.assertAllClose(
            result,
            np_reduce_fn(test_input, axis=np_axis),
            rtol=rtol,
            atol=atol)

      # MLIR bridge doesn't return the same error so it can't be matched
      # directly.
      if not test_util.is_mlir_bridge_enabled():
        for bad_axis in invalid_axes:
          with self.assertRaisesWithPredicateMatch(
              errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
            sess.run(out, {a: test_input, index: [bad_axis]})
def testNestedCallUnsupportedOps(self):
  """A jit-compiled tf.unique called from a non-jit function must fail."""
  if test_util.is_mlir_bridge_enabled():
    expected_error = 'legalization failed'
  else:
    expected_error = 'not compilable'

  with ops.device('device:{}:0'.format(self.device)):

    @def_function.function(jit_compile=True)
    def unique_values(x):
      return array_ops.unique(x).y

    @def_function.function(jit_compile=False)
    def outer(x):
      return unique_values(x)

    values = constant_op.constant([1, 2, 2, 3, 3])
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      outer(values)
def testUnsupportedOps(self):
  """tf.unique works without XLA but is rejected under jit_compile."""
  if 'tpu' in self.device.lower():
    self.skipTest('XLA TPU supports tf.unique')

  if test_util.is_mlir_bridge_enabled():
    expected_error = 'legalization failed'
  else:
    expected_error = 'unsupported operations'

  with ops.device('device:{}:0'.format(self.device)):

    def unique_values(x):
      return array_ops.unique(x).y  # Unique is not supported by XLA

    plain_fn = def_function.function(unique_values, jit_compile=False)
    compiled_fn = def_function.function(unique_values, jit_compile=True)

    values = constant_op.constant([1, 2, 2, 3, 3])
    self.assertAllClose([1, 2, 3], plain_fn(values))
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      compiled_fn(values)
def testMethodCompilationUnsupportedFunc(self):
  """A jit-compiled method calling tf.unique must fail off TPU."""
  if 'tpu' in self.device.lower():
    self.skipTest('XLA TPU supports tf.unique')

  if test_util.is_mlir_bridge_enabled():
    expected_error = 'legalization failed'
  else:
    expected_error = 'unsupported operations'

  with ops.device('device:{}:0'.format(self.device)):

    class UniqueHolder(object):

      @def_function.function(jit_compile=True)
      def f1(self, x):
        return array_ops.unique(x).y

    values = constant_op.constant([1, 2, 2, 3, 3])
    holder = UniqueHolder()
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      holder.f1(values)
def test_strategy(self, strategy):
    """Checks `Hashing(num_bins=2)` output under `strategy`.

    Args:
      strategy: The distribution strategy under test (parameterized).
    """
    if (
        backend.is_tpu_strategy(strategy)
        and not tf_test_utils.is_mlir_bridge_enabled()
    ):
        self.skipTest("TPU tests require MLIR bridge")

    # Named separately from the `keras.Input` below; the original reused
    # `input_data` for both.
    input_array = np.asarray([["omar"], ["stringer"], ["marlo"], ["wire"]])
    input_dataset = tf.data.Dataset.from_tensor_slices(input_array).batch(
        2, drop_remainder=True
    )
    expected_output = [[0], [0], [1], [0]]

    # Soft device placement is a process-wide setting. Restore the previous
    # value when this test ends, even on failure, so it does not leak into
    # other tests.
    self.addCleanup(
        tf.config.set_soft_device_placement,
        tf.config.get_soft_device_placement(),
    )
    tf.config.set_soft_device_placement(True)

    with strategy.scope():
        input_data = keras.Input(shape=(None,), dtype=tf.string)
        layer = hashing.Hashing(num_bins=2)
        int_data = layer(input_data)
        model = keras.Model(inputs=input_data, outputs=int_data)
    output_dataset = model.predict(input_dataset)
    self.assertAllEqual(expected_output, output_dataset)
def testCollectiveReduceGroupAssignment(self):
  """Checks that an AssignGroup-derived group key lowers to replica groups.

  Compiles a single all-reduce whose group key comes from
  `collective_ops.assign_group_v2(group_assignment=[[0]])`. It then asserts
  that the compiler IR text contains `replica_groups={{0}}`.
  """
  if not test_util.is_mlir_bridge_enabled():
    self.skipTest('AssignGroup is only supported in the MLIR bridge.')

  with ops.device('device:{}:0'.format(self.device)):

    @def_function.function(jit_compile=True)
    def fn(x):
      group_key = collective_ops.assign_group_v2(
          group_assignment=[[0]], device_index=0)
      t0 = collective_ops.all_reduce_v2(
          t=x, group_size=1, group_key=group_key, instance_key=1)
      return t0

    inputs = constant_op.constant([1.0, 2.0, 3.0])
    # The group assignment [[0]] must show up as the replica groups of the
    # single all-reduce instruction in the compiled IR.
    hlo_str = fn.experimental_get_compiler_ir(inputs)()
    self.assertIn('replica_groups={{0}}', hlo_str)
def testNestedCallUnsupportedOps(self):
  """A jit-compiled string op called from a non-jit function must fail."""
  if 'tpu' in self.device.lower():
    self.skipTest('Outside compilation will extract string_length to CPU')

  if test_util.is_mlir_bridge_enabled():
    expected_error = 'legalization failed'
  else:
    expected_error = 'unsupported operations'

  with ops.device('device:{}:0'.format(self.device)):

    @def_function.function(jit_compile=True)
    def format_then_measure(x):
      formatted = string_ops.string_format('{}', x)
      return string_ops.string_length(formatted)

    @def_function.function(jit_compile=False)
    def outer(x):
      return format_then_measure(x)

    values = constant_op.constant([1, 2, 2, 3, 3])
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      outer(values)
def test_logical_device_assignment(self):
  """Checks logical-device placement of variables and ops under TPUStrategy.

  `v` is created without a logical-device scope and `w` under logical device
  1. The test checks their physical devices, the devices recorded for the
  two multiplications inside the replica function, and the reduced result
  (2 * 3 * 5 per replica).
  """
  if test_util.is_mlir_bridge_enabled():
    self.skipTest("TODO(b/238811067): fix MLIR bridge")

  strategy, num_replicas = get_tpu_strategy()
  with strategy.scope():
    v = variables.Variable(2.)
    with strategy.extended.experimental_logical_device(1):
      w = variables.Variable(3.)

  v_locals = strategy.experimental_local_results(v)
  w_locals = strategy.experimental_local_results(w)
  self.assertLen(v_locals, num_replicas)
  self.assertLen(w_locals, num_replicas)
  self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                   v_locals[0].device)
  self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                   w_locals[0].device)

  # (device of y, device of z), appended each time `f` is traced.
  traced_devices = []

  @def_function.function
  def f(x):
    replica_ctx = distribution_strategy_context.get_replica_context()
    with replica_ctx.experimental_logical_device(0):
      y = v * x
    with replica_ctx.experimental_logical_device(1):
      z = w * y
    traced_devices.append((y.device, z.device))
    return z

  result = strategy.run(f, args=(5.,))

  self.assertEqual(
      [("/device:TPU_REPLICATED_CORE:0", "/device:TPU_REPLICATED_CORE:1")],
      traced_devices)

  with self.cached_session():
    self.evaluate(variables.global_variables_initializer())
    total = self.evaluate(strategy.reduce("SUM", result, axis=None))
    self.assertEqual(30. * num_replicas, total)
def testTensorArrayErrorMessage(self):
  """Checks the error raised when concatenating an unset TensorArray in XLA.

  The old and new (MLIR) bridges attribute the failure to different ops, so
  each bridge has its own expected-message marker.
  """
  with ops.device('device:{}:0'.format(self.device)):

    @def_function.function(jit_compile=True)
    def f():
      # The error messages of the old and new bridge differ in which op they
      # flag. One points to the creation of the uninitialized tensor array,
      # the other to the use of the uninitialized tensor array.
      # NOTE(review): the assertRaisesRegex patterns below are exactly the
      # trailing marker comments on the next statement and on the `return`
      # line. Presumably the error text quotes the offending source line, so
      # do not move, rename, or reflow those marker comments.
      ta = tensor_array_ops.TensorArray(  # EXPECTED_MESSAGE_NEW
          dtype=dtypes.float32,
          size=2,
          dynamic_size=True,
          element_shape=(None, ))
      return ta.concat()  # EXPECTED_MESSAGE_OLD

    # Select the marker that matches the bridge this test runs under.
    if test_util.is_mlir_bridge_enabled():
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  'EXPECTED_MESSAGE_NEW'):
        f()
    else:
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  'EXPECTED_MESSAGE_OLD'):
        f()
def __init__(self, method_name='runTest'):
  """Configures the XLA test case from command-line FLAGS.

  Sets the MLIR bridge mode and records the target device. Derives the TF
  and numpy type sets from `FLAGS.types` and loads the optional
  disabled-test manifest. Finally, forwards `FLAGS.tf_xla_flags` to the
  `TF_XLA_FLAGS` environment variable.

  Args:
    method_name: Name of the test method to run (unittest convention).
  """
  super(XLATestCase, self).__init__(method_name)
  # NOTE(review): is_mlir_bridge_enabled() may return None when the bridge
  # was neither explicitly enabled nor disabled. Only touch the context
  # setting when an explicit choice was made, instead of assigning None.
  mlir_bridge_enabled = test_util.is_mlir_bridge_enabled()
  if mlir_bridge_enabled is not None:
    context.context().enable_mlir_bridge = bool(mlir_bridge_enabled)
  self.device = FLAGS.test_device
  self.has_custom_call = (self.device == 'XLA_CPU')

  # TensorFlow dtypes named by the comma-separated FLAGS.types list.
  self._all_tf_types = {
      dtypes.as_dtype(types_pb2.DataType.Value(name))
      for name in FLAGS.types.split(',')
  }
  self.int_tf_types = {dt for dt in self._all_tf_types if dt.is_integer}
  self._float_tf_types = {dt for dt in self._all_tf_types if dt.is_floating}
  self.complex_tf_types = {dt for dt in self._all_tf_types if dt.is_complex}
  self._numeric_tf_types = (
      self.int_tf_types | self._float_tf_types | self.complex_tf_types)
  self.quantized_tf_types = {
      dt for dt in self._all_tf_types if dt.is_quantized
  }

  # Quantized types don't have a numpy equivalent, include them in
  # all_tf_types but not in all_types.
  # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
  # and remove all_types.
  self._all_types = {
      dt.as_numpy_dtype for dt in self._all_tf_types if not dt.is_quantized
  }
  self._int_types = {dt.as_numpy_dtype for dt in self.int_tf_types}
  self.signed_int_types = {
      dt.as_numpy_dtype for dt in self.int_tf_types if not dt.is_unsigned
  }
  self.unsigned_int_types = {
      dt.as_numpy_dtype for dt in self.int_tf_types if dt.is_unsigned
  }
  self._float_types = {dt.as_numpy_dtype for dt in self._float_tf_types}
  self.complex_types = {dt.as_numpy_dtype for dt in self.complex_tf_types}
  self._numeric_types = (
      self._int_types | self._float_types | self.complex_types)

  # Parse the manifest file, if any, into a regex identifying tests to
  # disable
  # TODO(xpan): Make it text proto if it doesn't scale.
  # Each line of the manifest file specifies an entry. The entry can be
  # 1) TestNameRegex  // E.g. CumprodTest.* Or
  # 2) TestName TypeName  // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
  # The 1) disables the entire test. While 2) only filter some numeric types
  # so that they are not used in those tests.
  self.disabled_regex = None
  self._method_types_filter = {}

  if FLAGS.disabled_manifest is not None:
    with open(FLAGS.disabled_manifest, 'r') as manifest_file:
      disabled_regex, self._method_types_filter = (
          parse_disabled_manifest(manifest_file.read()))
    if disabled_regex:
      self.disabled_regex = re.compile(disabled_regex)

  if FLAGS.tf_xla_flags is not None:
    os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags
def testGradientTraining(self, data_format):
  """Compares FusedBatchNormGrad (is_training=True) with a numpy reference.

  Args:
    data_format: Layout the op runs in. Reference values are computed in
      NHWC and converted to this layout before comparison.
  """
  on_gpu = self.device == "XLA_GPU"
  # disable_mlir_bridge for GPUs as there is no legalization for GPU with
  # MLIR.
  # TODO(b/189039456): Customize FusedBatchNorm legalization for GPU in MLIR.
  if test_util.is_mlir_bridge_enabled() and on_gpu:
    self.skipTest("b/189039456")

  # TODO(b/64270657): Use gradient_checker here in addition to comparing with
  # this reference implementation.
  channel = 3
  x_shape = [2, 2, 6, channel]
  scale_shape = [channel]

  def random_float32(shape):
    return np.random.random_sample(shape).astype(np.float32)

  # Draw order (grad, x, scale, mean, var) fixes which random values each
  # array receives.
  grad_val = random_float32(x_shape)
  x_val = random_float32(x_shape)
  scale_val = random_float32(scale_shape)
  mean_val = random_float32(scale_shape)
  var_val = random_float32(scale_shape)
  epsilon = 0.001

  # In training mode, FusedBatchNormGrad takes two inputs whose values are
  # implementation defined. Strictly, the only correct values are the
  # reserve_space_{1|2} outputs of the forward FusedBatchNorm training op.
  # In practice we rely on the first being the mean on {C|G}PU. We rely on
  # the second being the variance on CPU and inverse(sqrt(variance +
  # epsilon)) on GPU; that assumption is tested separately.
  reserve_space_1_val = mean_val
  if on_gpu:
    reserve_space_2_val = np.reciprocal(np.sqrt(var_val + epsilon))
  else:
    reserve_space_2_val = var_val

  data_format_src = "NHWC"
  grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
      x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src)

  def to_target_format(array):
    return test_utils.ConvertBetweenDataFormats(array, data_format_src,
                                                data_format)

  def float_input(shape, name):
    return array_ops.placeholder(np.float32, shape=shape, name=name)

  with self.session() as sess, self.test_scope():
    grad_val_converted = to_target_format(grad_val)
    x_val_converted = to_target_format(x_val)
    grad_x_ref_converted = to_target_format(grad_x_ref)

    grad = float_input(x_val_converted.shape, "grad")
    x = float_input(x_val_converted.shape, "x")
    reserve_space_1 = float_input(scale_shape, "reserve_space_1")
    reserve_space_2 = float_input(scale_shape, "reserve_space_2")
    scale = float_input(scale_shape, "scale")

    grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
        grad,
        x,
        scale,
        reserve_space_1,
        reserve_space_2,
        data_format=data_format,
        is_training=True)

    feed = {
        grad: grad_val_converted,
        x: x_val_converted,
        reserve_space_1: reserve_space_1_val,
        reserve_space_2: reserve_space_2_val,
        scale: scale_val,
    }
    grad_x_val, grad_scale_val, grad_offset_val = sess.run(
        [grad_x, grad_scale, grad_offset], feed)

    self.assertAllClose(grad_x_val, grad_x_ref_converted, atol=1e-2)
    self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
    self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)
def setUp(self):
  """Enables the TPU op rewrite only on TPU devices with the MLIR bridge."""
  super().setUp()
  # NOTE(review): is_mlir_bridge_enabled() may return None (bridge neither
  # explicitly enabled nor disabled). Coerce so the flag is always a real
  # bool, like the False default it replaces.
  self.rewrite_ops_for_tpu = bool(
      "TPU" in self.device and test_util.is_mlir_bridge_enabled())
def __init__(self, method_name='runTest'):
  """Configures the XLA test case from command-line FLAGS.

  Enables XLA devices when requested and applies an explicit MLIR bridge
  choice. Records the target device, then derives the TF and numpy type
  sets from `FLAGS.types` and loads the optional disabled-test manifest.
  Finally, forwards `FLAGS.tf_xla_flags` to the `TF_XLA_FLAGS` environment
  variable.

  Args:
    method_name: Name of the test method to run (unittest convention).
  """
  super(XLATestCase, self).__init__(method_name)
  if 'XLA' in FLAGS.test_device:
    context.context().enable_xla_devices()

  # Check if the mlir bridge has been explicitly enabled or disabled. If
  # is_mlir_bridge_enabled() returns None, the user did not explicitly enable
  # or disable the bridge so do not update enable_mlir_bridge.
  mlir_bridge_enabled = test_util.is_mlir_bridge_enabled()
  if mlir_bridge_enabled:
    context.context().enable_mlir_bridge = True
  elif mlir_bridge_enabled is not None:
    context.context().enable_mlir_bridge = False

  self.device = FLAGS.test_device
  self.has_custom_call = (self.device == 'XLA_CPU')

  # Some tests (e.g. ftrl_ops) only work if the program goes through the
  # _TPUCompileMLIR op. They will set this flag to True.
  # TODO(kramm): Flip to true (and enable MLIR bridge) for more tests.
  self.rewrite_ops_for_tpu = False

  # TensorFlow dtypes named by the comma-separated FLAGS.types list.
  self._all_tf_types = {
      dtypes.as_dtype(types_pb2.DataType.Value(name))
      for name in FLAGS.types.split(',')
  }
  self.int_tf_types = {dt for dt in self._all_tf_types if dt.is_integer}
  self._float_tf_types = {dt for dt in self._all_tf_types if dt.is_floating}
  self.complex_tf_types = {dt for dt in self._all_tf_types if dt.is_complex}
  self._numeric_tf_types = (
      self.int_tf_types | self._float_tf_types | self.complex_tf_types)
  self.quantized_tf_types = {
      dt for dt in self._all_tf_types if dt.is_quantized
  }

  # Quantized types don't have a numpy equivalent, include them in
  # all_tf_types but not in all_types.
  # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
  # and remove all_types.
  self._all_types = {
      dt.as_numpy_dtype for dt in self._all_tf_types if not dt.is_quantized
  }
  self._int_types = {dt.as_numpy_dtype for dt in self.int_tf_types}
  self.signed_int_types = {
      dt.as_numpy_dtype for dt in self.int_tf_types if not dt.is_unsigned
  }
  self.unsigned_int_types = {
      dt.as_numpy_dtype for dt in self.int_tf_types if dt.is_unsigned
  }
  self._float_types = {dt.as_numpy_dtype for dt in self._float_tf_types}
  self.complex_types = {dt.as_numpy_dtype for dt in self.complex_tf_types}
  self._numeric_types = (
      self._int_types | self._float_types | self.complex_types)

  # Parse the manifest file, if any, into a regex identifying tests to
  # disable
  # TODO(xpan): Make it text proto if it doesn't scale.
  # Each line of the manifest file specifies an entry. The entry can be
  # 1) TestNameRegex  // E.g. CumprodTest.* Or
  # 2) TestName TypeName  // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
  # The 1) disables the entire test. While 2) only filter some numeric types
  # so that they are not used in those tests.
  self.disabled_regex = None
  self._method_types_filter = {}

  if FLAGS.disabled_manifest is not None:
    with open(FLAGS.disabled_manifest, 'r') as manifest_file:
      disabled_regex, self._method_types_filter = (
          parse_disabled_manifest(manifest_file.read()))
    if disabled_regex:
      self.disabled_regex = re.compile(disabled_regex)

  if FLAGS.tf_xla_flags is not None:
    os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags