def test_enqueue_with_outside_compilation_in_control_flow(self, use_mlir):
  """Enqueueing from a nested tf.function inside `strategy.run` must fail.

  The enqueue is wrapped in its own `def_function.function` and invoked from
  the replica function, and the test expects a RuntimeError whose message
  mentions a graph that contains a TPUReplicateContext.
  """
  if use_mlir:
    config.enable_mlir_bridge()

  strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
  dataset = self._create_sparse_dataset(strategy)
  input_options = distribute_lib.InputOptions(
      experimental_prefetch_to_device=False)
  distributed_iter = iter(
      strategy.experimental_distribute_dataset(dataset, options=input_options))

  # Wrapping the enqueue in a separate tf.function is one way to push it into
  # some control flow: nested tf.functions are not inlined into the calling
  # tf.function. Placing the enqueue in a switch_v2 (or similar) would be an
  # alternative way to achieve the same thing.
  @def_function.function
  def enqueue_fn(features):
    mid_level_api.enqueue(features, training=False)

  @def_function.function
  def enqueue_with_outside_compilation():

    def replica_step(features):
      enqueue_fn(features)
      return mid_level_api.dequeue()

    # `next` is called inside the tf.function body, so the batch is fetched
    # as part of tracing this function.
    return strategy.run(replica_step, args=(next(distributed_iter),))

  with self.assertRaisesRegex(
      RuntimeError, 'does not match graph which contains TPUReplicateContext'):
    enqueue_with_outside_compilation()
def test_enqueue_with_outside_compilation(self, use_mlir):
  """Enqueueing inside vs. outside `strategy.run` gives the same activations.

  The same batch is enqueued once from within the replica function (outside
  compilation) and once in the tf.function body before `strategy.run` is
  called; the dequeued activations of replica 0 must match.
  """
  if use_mlir:
    config.enable_mlir_bridge()

  strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
  dataset = self._create_sparse_dataset(strategy)
  input_options = distribute_lib.InputOptions(
      experimental_prefetch_to_device=False)
  distributed_iter = iter(
      strategy.experimental_distribute_dataset(dataset, options=input_options))

  @def_function.function
  def enqueue_with_outside_compilation(data):

    def replica_fn(features):
      # Enqueue from inside the replicated computation.
      mid_level_api.enqueue(features, training=False)
      return mid_level_api.dequeue()

    return strategy.run(replica_fn, args=(data,))

  @def_function.function
  def enqueue_without_outside_compilation(data):

    def replica_fn():
      return mid_level_api.dequeue()

    # Enqueue before entering the replicated computation.
    mid_level_api.enqueue(data, training=False)
    return strategy.run(replica_fn)

  batch = next(distributed_iter)
  # Run the outside-compilation variant first, then the plain one, on the
  # same batch.
  activations_oc = enqueue_with_outside_compilation(batch)
  activations = enqueue_without_outside_compilation(batch)

  # Extract per-core numpy arrays for replica 0 and compare them.
  replica0_oc, replica0 = (
      self._get_replica_numpy(result, strategy, 0)
      for result in (activations_oc, activations))
  self.assertAllClose(replica0_oc, replica0)
def testEnableMlirBridge(self):
  """Toggling the MLIR bridge is reflected in the context's config."""

  def _bridge_enabled():
    # Re-read the config on every call so each check sees the current state.
    return context.context().config.experimental.enable_mlir_bridge

  # The bridge starts out disabled.
  self.assertFalse(_bridge_enabled())

  # Enabling the bridge must be visible in the context config.
  config.enable_mlir_bridge()
  self.assertTrue(_bridge_enabled())

  # Disabling it again must also be visible.
  config.disable_mlir_bridge()
  self.assertFalse(_bridge_enabled())
def testEnableMlirBridge(self):
  """Toggling the MLIR bridge updates both the flag and the rollout state."""

  def _check(expected_enabled, expected_rollout):
    # Fetch the config afresh so the check reflects the latest toggle.
    experimental = context.context().config.experimental
    if expected_enabled:
      self.assertTrue(experimental.enable_mlir_bridge)
    else:
      self.assertFalse(experimental.enable_mlir_bridge)
    self.assertEqual(experimental.mlir_bridge_rollout, expected_rollout)

  rollout = config_pb2.ConfigProto.Experimental

  # By default the bridge is off and the rollout is unspecified.
  _check(False, rollout.MLIR_BRIDGE_ROLLOUT_UNSPECIFIED)

  # Enabling the bridge flips the flag and marks the rollout as enabled.
  config.enable_mlir_bridge()
  _check(True, rollout.MLIR_BRIDGE_ROLLOUT_ENABLED)

  # Disabling it clears the flag and marks the rollout as disabled.
  config.disable_mlir_bridge()
  _check(False, rollout.MLIR_BRIDGE_ROLLOUT_DISABLED)
def test_enqueue_cpu_tensor_with_outside_compilation(self, use_mlir):
  """Enqueueing this input from inside `strategy.run` raises ValueError.

  The input comes from `distribute_datasets_from_function` over the dense
  input_fn; the enqueue inside the replica function is expected to fail
  with a message mentioning a TPU input device.
  """
  if use_mlir:
    config.enable_mlir_bridge()

  strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
  input_fn = self._create_dense_input_fn(strategy)
  # Named for the dense input_fn it is built from.
  dense_iter = iter(strategy.distribute_datasets_from_function(input_fn))

  @def_function.function
  def test_fn():

    def replica_fn(features):
      mid_level_api.enqueue(features, training=False)
      return mid_level_api.dequeue()

    return strategy.run(replica_fn, args=(next(dense_iter),))

  with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'):
    test_fn()