Exemplo n.º 1
0
  def test_enqueue_with_outside_compilation_in_control_flow(self, use_mlir):
    """Enqueueing from a nested tf.function inside `strategy.run` must fail.

    The enqueue is issued from its own `def_function.function`, called from
    the replica function, and is expected to raise a RuntimeError about the
    graph not matching the one holding the TPUReplicateContext.

    Args:
      use_mlir: If true, `config.enable_mlir_bridge()` is called first.
    """
    # NOTE(review): nothing in this method turns the MLIR bridge back off;
    # confirm the fixture resets it between parameterized cases.
    if use_mlir:
      config.enable_mlir_bridge()

    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    sparse_dataset = self._create_sparse_dataset(strategy)
    prefetch_off = distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)
    dataset_iter = iter(
        strategy.experimental_distribute_dataset(sparse_dataset,
                                                 options=prefetch_off))

    # Wrapping the enqueue in its own tf.function is what puts it inside
    # control flow: tf.functions are not inlined into the tf.function that
    # calls them. Placing the enqueue in a switch_v2 (or similar) would be an
    # alternative way to get the same effect.
    @def_function.function
    def nested_enqueue(features):
      mid_level_api.enqueue(features, training=False)

    @def_function.function
    def enqueue_with_outside_compilation():

      def replica_step(features):
        nested_enqueue(features)
        return mid_level_api.dequeue()

      return strategy.run(replica_step, args=(next(dataset_iter),))

    expected_error = 'does not match graph which contains TPUReplicateContext'
    with self.assertRaisesRegex(RuntimeError, expected_error):
      enqueue_with_outside_compilation()
Exemplo n.º 2
0
  def test_enqueue_with_outside_compilation(self, use_mlir):
    """Enqueue inside vs. outside `strategy.run` must give the same result.

    The same batch of features is fed through two tf.functions: one enqueues
    from within the replica function, the other enqueues before calling
    `strategy.run` and only dequeues inside it. Replica 0's activations from
    both paths are compared.

    Args:
      use_mlir: If true, `config.enable_mlir_bridge()` is called first.
    """
    # NOTE(review): nothing in this method turns the MLIR bridge back off;
    # confirm the fixture resets it between parameterized cases.
    if use_mlir:
      config.enable_mlir_bridge()

    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    sparse_dataset = self._create_sparse_dataset(strategy)
    prefetch_off = distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)
    dataset_iter = iter(
        strategy.experimental_distribute_dataset(sparse_dataset,
                                                 options=prefetch_off))

    @def_function.function
    def enqueue_with_outside_compilation(data):

      def enqueue_then_dequeue(features):
        mid_level_api.enqueue(features, training=False)
        return mid_level_api.dequeue()

      return strategy.run(enqueue_then_dequeue, args=(data,))

    @def_function.function
    def enqueue_without_outside_compilation(data):

      def dequeue_only():
        return mid_level_api.dequeue()

      mid_level_api.enqueue(data, training=False)
      return strategy.run(dequeue_only)

    batch = next(dataset_iter)
    per_replica_oc = enqueue_with_outside_compilation(batch)
    per_replica = enqueue_without_outside_compilation(batch)

    # Extract replica 0's values as numpy arrays and compare them.
    self.assertAllClose(
        self._get_replica_numpy(per_replica_oc, strategy, 0),
        self._get_replica_numpy(per_replica, strategy, 0))
Exemplo n.º 3
0
  def testEnableMlirBridge(self):
    """Toggles the MLIR bridge and checks the context config tracks it."""

    def bridge_enabled():
      # Re-read the context config on every call so each toggle is observed.
      return context.context().config.experimental.enable_mlir_bridge

    # The bridge starts out disabled.
    self.assertFalse(bridge_enabled())

    # Turning it on is reflected in the config.
    config.enable_mlir_bridge()
    self.assertTrue(bridge_enabled())

    # Turning it off again is reflected as well.
    config.disable_mlir_bridge()
    self.assertFalse(bridge_enabled())
Exemplo n.º 4
0
  def testEnableMlirBridge(self):
    """Checks `enable_mlir_bridge` and `mlir_bridge_rollout` across toggles."""
    rollout = config_pb2.ConfigProto.Experimental

    def check_bridge_state(expect_enabled, expected_rollout):
      # Snapshot the experimental config as it is right now.
      experimental = context.context().config.experimental
      if expect_enabled:
        self.assertTrue(experimental.enable_mlir_bridge)
      else:
        self.assertFalse(experimental.enable_mlir_bridge)
      self.assertEqual(experimental.mlir_bridge_rollout, expected_rollout)

    # Default: bridge off, rollout left unspecified.
    check_bridge_state(False, rollout.MLIR_BRIDGE_ROLLOUT_UNSPECIFIED)

    # Enabling flips the flag and marks the rollout as enabled.
    config.enable_mlir_bridge()
    check_bridge_state(True, rollout.MLIR_BRIDGE_ROLLOUT_ENABLED)

    # Disabling clears the flag and marks the rollout as disabled.
    config.disable_mlir_bridge()
    check_bridge_state(False, rollout.MLIR_BRIDGE_ROLLOUT_DISABLED)
Exemplo n.º 5
0
  def test_enqueue_cpu_tensor_with_outside_compilation(self, use_mlir):
    """Enqueueing the dense input inside `strategy.run` must raise.

    Features come from `_create_dense_input_fn` via
    `distribute_datasets_from_function`. Enqueueing them from within the
    replica function is expected to fail with a ValueError mentioning a TPU
    input device.

    Args:
      use_mlir: If true, `config.enable_mlir_bridge()` is called first.
    """
    # NOTE(review): nothing in this method turns the MLIR bridge back off;
    # confirm the fixture resets it between parameterized cases.
    if use_mlir:
      config.enable_mlir_bridge()

    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    dense_input_fn = self._create_dense_input_fn(strategy)
    dense_iter = iter(strategy.distribute_datasets_from_function(dense_input_fn))

    @def_function.function
    def test_fn():

      def enqueue_then_dequeue(features):
        mid_level_api.enqueue(features, training=False)
        return mid_level_api.dequeue()

      return strategy.run(enqueue_then_dequeue, args=(next(dense_iter),))

    with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'):
      test_fn()