Example #1
 def wrapper(*args, **kwargs):
   """Run `fn`, then rebuild the eager context no matter how `fn` exits."""
   try:
     return fn(*args, **kwargs)
   finally:
     # Reset the context through the dedicated private helpers rather
     # than clearing context._context by hand, matching the reset pattern
     # used by the other wrapper variant; this also clears JIT flags.
     # NOTE(review): assumes the pinned TF version provides
     # _reset_jit_compiler_flags/_reset_context — confirm before landing.
     context._reset_jit_compiler_flags()
     context._reset_context()
     ops.enable_eager_execution_internal()
     # A fresh context must exist after re-enabling eager execution.
     assert context._context is not None
Example #2
 def setUp(self):
   """Start each test with a fresh eager context and hard device placement."""
   super(HardDevicePlacementTest, self).setUp()
   # Rebuild the eager context from scratch so earlier tests cannot leak
   # device-placement state into this one.
   context._context = None
   ops.enable_eager_execution_internal()
   config.set_soft_device_placement(enabled=False)
   context.context().log_device_placement = True
   # Sanity-check that hard placement took effect at both the config and
   # the context level (assertFalse is the idiomatic unittest check).
   self.assertFalse(config.get_soft_device_placement())
   self.assertFalse(context.context().soft_device_placement)
Example #3
 def setUp(self):
   """Give every test a clean eager context, soft placement, and a cluster."""
   super(ClusterPlacementTest, self).setUp()
   # Drop the current eager context and build a fresh one.
   context._context = None
   ops.enable_eager_execution_internal()
   config.set_soft_device_placement(enabled=True)
   ctx = context.context()
   ctx.log_device_placement = True
   # Bring up a local two-worker cluster and attach this client to it.
   servers, _ = test_util.create_local_cluster(2, 0)
   remote.connect_to_remote_host([servers[0].target, servers[1].target])
Example #4
 def wrapper(*args, **kwargs):
     """Invoke `fn` and always rebuild the eager context afterwards.

     Returns whatever `fn` returns; the `finally` block runs even when
     `fn` raises, so every call leaves behind a freshly reset context.
     """
     try:
         return fn(*args, **kwargs)
     finally:
         # Reset the context (JIT flags first, then the context object)
         # so state from `fn` cannot leak into subsequent callers.
         context._reset_jit_compiler_flags()
         context._reset_context()
         ops.enable_eager_execution_internal()
         # enable_eager_execution_internal is expected to have installed
         # a new context object; fail loudly if it did not.
         assert context._context is not None
Example #5
    def testGradientAccumulatorDistributionStrategy(self):
        """GradientAccumulator accumulates, applies, and resets correctly
        under a two-replica MirroredStrategy."""
        # Rebuild the eager context so this test starts from a clean slate.
        context._context = None
        ops.enable_eager_execution_internal()
        # Ensure at least two logical CPU devices exist so MirroredStrategy
        # has two replicas to mirror across.
        physical_devices = tf.config.list_physical_devices("CPU")
        if len(physical_devices) == 1:
            tf.config.set_logical_device_configuration(physical_devices[0], [
                tf.config.LogicalDeviceConfiguration(),
                tf.config.LogicalDeviceConfiguration()
            ])
        devices = tf.config.list_logical_devices(device_type="CPU")
        strategy = tf.distribute.MirroredStrategy(devices=devices[:2])

        # State (accumulator, the optimized variable, and a placeholder used
        # to feed per-replica gradients) is created inside the strategy
        # scope so it is mirrored across replicas.
        with strategy.scope():
            accumulator = GradientAccumulator()
            variable = tf.Variable([4.0, 3.0])
            optimizer, _ = create_optimizer(5e-5, 10, 5)
            gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)

        # Per-replica step: feed one gradient into the accumulator.
        def accumulate_on_replica(gradient):
            accumulator([gradient])

        # Per-replica step: apply the accumulated gradients to `variable`.
        def apply_on_replica():
            optimizer.apply_gradients(
                list(zip(accumulator.gradients, [variable])))

        @tf.function
        def accumulate(grad1, grad2):
            # Write one gradient per replica into the placeholder, then run
            # the accumulation step on every replica.
            # NOTE(review): experimental_run_v2 was renamed Strategy.run in
            # later TF releases — confirm the pinned TF version still has it.
            with strategy.scope():
                local_variables = strategy.experimental_local_results(
                    gradient_placeholder)
                local_variables[0].assign(grad1)
                local_variables[1].assign(grad2)
                strategy.experimental_run_v2(accumulate_on_replica,
                                             args=(gradient_placeholder, ))

        @tf.function
        def apply_grad():
            with strategy.scope():
                strategy.experimental_run_v2(apply_on_replica)

        def _check_local_values(grad1, grad2):
            # Compare each replica's accumulated gradient (reaching into the
            # accumulator's private storage) against the expected values,
            # with a loose tolerance for float accumulation.
            values = strategy.experimental_local_results(
                accumulator._gradients[0])
            self.assertListAlmostEqual(values[0].value(), grad1, tol=1e-2)
            self.assertListAlmostEqual(values[1].value(), grad2, tol=1e-2)

        # Three accumulation steps; each replica's running sum is the
        # element-wise total of the gradients that replica received.
        accumulate([1.0, 2.0], [-1.0, 1.0])
        accumulate([3.0, -1.0], [-1.0, -1.0])
        accumulate([-2.0, 2.0], [3.0, -2.0])
        self.assertEqual(accumulator.step, 3)
        _check_local_values([2.0, 3.0], [1.0, -2.0])
        apply_grad()
        # NOTE(review): the update is small enough at this learning-rate
        # schedule that the variable is expected to stay near its initial
        # value within the 1e-2 tolerance.
        self.assertListAlmostEqual(variable.value(), [4.0, 3.0], tol=1e-2)
        # reset() clears both the step counter and the stored gradients.
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        _check_local_values([0.0, 0.0], [0.0, 0.0])
Example #6
def _reset_context():
  """Destroy the current TF eager context and install a fresh one.

  Clearing the cached context and re-enabling eager execution forces
  TensorFlow to rebuild its global state, isolating callers from whatever
  configuration previous code set up.
  """
  # See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/config_test.py
  # TODO: find a way to achieve that without relying on TensorFlow private APIs.
  context._context = None
  ops.enable_eager_execution_internal()
Example #7
 def setUp(self):
   """Give every test a fresh eager context with soft placement enabled."""
   super(SoftDevicePlacementTest, self).setUp()
   # Drop the current eager context and build a new one.
   context._context = None
   ops.enable_eager_execution_internal()
   config.set_soft_device_placement(enabled=True)
   ctx = context.context()
   ctx.log_device_placement = True