def run_fn():
  with lsgt.LossScaleGradientTape(outer_loss_scale) as g:
    with lsgt.LossScaleGradientTape(inner_loss_scale) as gg:
      y = x * x
    dy_dx = gg.gradient(y, x)
  d2y_dx2 = g.gradient(dy_dx, x)
  return dy_dx, d2y_dx2

def run_fn():
  x = constant_op.constant(3.0)
  with lsgt.LossScaleGradientTape(outer_loss_scale) as g:
    g.watch(x)
    with lsgt.LossScaleGradientTape(inner_loss_scale) as gg:
      gg.watch(x)
      y = x * x
    dy_dx = gg.gradient(y, x)
  d2y_dx2 = g.gradient(dy_dx, x)
  return dy_dx, d2y_dx2

def test_nested_tapes(self, loss_scale):
  x = constant_op.constant(3.0)
  with lsgt.LossScaleGradientTape(loss_scale(32)) as g:
    g.watch(x)
    with lsgt.LossScaleGradientTape(loss_scale(32)) as gg:
      gg.watch(x)
      y = x * x
    dy_dx = gg.gradient(y, x)
  self.assertEqual(self.evaluate(dy_dx), 6.0)
  d2y_dx2 = g.gradient(dy_dx, x)
  self.assertEqual(self.evaluate(d2y_dx2), 2.0)

def run_fn():
  with lsgt.LossScaleGradientTape(ls, persistent=True) as g:
    y = x * x
    z = y * y
  dz_dx = g.gradient(z, x)
  dy_dx = g.gradient(y, x)
  return dz_dx, dy_dx

def test_nested_sources(self, loss_scale):
  x = (variables.Variable(19.0), (variables.Variable(8.),
                                  variables.Variable(9.)))
  with lsgt.LossScaleGradientTape(loss_scale(32)) as g:
    y = x * 13
  dy_dx = g.gradient(y, x)
  self.assertEqual(self.evaluate(dy_dx), (13., (13., 13.)))

def test_basic_tapes_eager_mode(self, loss_scale):
  x = constant_op.constant(3.0)
  with lsgt.LossScaleGradientTape(loss_scale(32)) as g:
    g.watch(x)
    y = x * x
  dy_dx = g.gradient(y, x)
  self.assertEqual(self.evaluate(dy_dx), 6.0)

def test_scaling_nan_gradient(self, loss_scale):
  x = constant_op.constant(1.0)
  with lsgt.LossScaleGradientTape(loss_scale(32)) as g:
    g.watch(x)
    y = x * np.nan
  dy_dx = g.gradient(y, x)
  self.assertTrue(np.isnan(self.evaluate(dy_dx)))

def test_scaling_inf_gradient(self, loss_scale):
  x = constant_op.constant(1.0)
  with lsgt.LossScaleGradientTape(loss_scale(32)) as g:
    g.watch(x)
    y = x * np.inf
  dy_dx = g.gradient(y, x)
  self.assertEqual(self.evaluate(dy_dx), np.inf)

def test_jacobian_raises_error(self):
  loss_scale = loss_scale_module.FixedLossScale(2.)
  x = variables.Variable([1.0, 2.0])
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    y = x * 2
  with self.assertRaisesRegexp(
      NotImplementedError,
      'LossScaleGradientTape.jacobian is not yet implemented'):
    g.jacobian(y, x)

  x = variables.Variable([[1.0, 2.0], [3.0, 4.0]])
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    y = x * 2
  with self.assertRaisesRegexp(
      NotImplementedError,
      'LossScaleGradientTape.batch_jacobian is not yet implemented'):
    g.batch_jacobian(y, x)

def test_nested_sources(self, loss_scale):
  x = (constant_op.constant(19.0), (constant_op.constant(8.),
                                    constant_op.constant(9.)))
  with lsgt.LossScaleGradientTape(loss_scale(32)) as g:
    g.watch(x)
    y = x * 13
  dy_dx = g.gradient(y, x)
  self.assertEqual(self.evaluate(dy_dx), (13., (13., 13.)))

def test_nested_targets(self, loss_scale):
  w = variables.Variable(3.0)
  with lsgt.LossScaleGradientTape(loss_scale(32)) as g:
    x = w * 5
    y = w * 7
    z = w * 11
  grad = g.gradient([x, (y, z)], w)
  self.assertEqual(self.evaluate(grad), 23)

def test_nested_targets(self, loss_scale):
  w = constant_op.constant(3.0)
  with lsgt.LossScaleGradientTape(loss_scale(32)) as g:
    g.watch(w)
    x = w * 5
    y = w * 7
    z = w * 11
  grad = g.gradient([x, (y, z)], w)
  self.assertEqual(self.evaluate(grad), 23)

def test_non_persistent_tapes_error(self, loss_scale):
  x = constant_op.constant(3.0)
  with lsgt.LossScaleGradientTape(loss_scale(32), persistent=False) as g:
    g.watch(x)
    y = x * x
    z = y * y
  g.gradient(z, x)
  with self.assertRaisesRegexp(RuntimeError, 'persistent'):
    g.gradient(y, x)

def test_dynamic_loss_scaling_down_loop(self):
  loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32)
  x = constant_op.constant(1.0)
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    g.watch(x)
    y = x * (3.0 * (10**37))  # grad will be inf after scaling
  dy_dx = g.gradient(y, x)
  self.assertEqual(self.evaluate(loss_scale()), 8.0)
  self.assertAllClose(self.evaluate(dy_dx), (3.0 * (10**37)), atol=1e-06)

def test_dynamic_scale_to_one_on_non_finite_gradient(self, non_finite_term):
  loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32)
  x = constant_op.constant(1.0)
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    g.watch(x)
    y = x * non_finite_term
  g.gradient(y, x)
  self.assertEqual(self.evaluate(loss_scale()), 1.0)

def run_fn():
  x = constant_op.constant(3.0)
  with lsgt.LossScaleGradientTape(ls, persistent=True) as g:
    g.watch(x)
    y = x * x
    z = y * y
  dz_dx = g.gradient(z, x)
  dy_dx = g.gradient(y, x)
  return dz_dx, dy_dx

def run_fn():
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    y1 = x1 * math_ops.cast(x2, 'float16') * math_ops.cast(x3, 'float16')
    y2 = math_ops.cast(x1, 'float32') * x2 * math_ops.cast(x3, 'float32')
    y3 = math_ops.cast(x1, 'float64') * math_ops.cast(x2, 'float64') * x3
  return g.gradient([y1, y2, y3], [x1, x2, x3])

def test_non_persistent_tapes_error(self):
  x = variables.Variable(3.0)
  with lsgt.LossScaleGradientTape(loss_scale_module.FixedLossScale(32),
                                  persistent=False) as g:
    y = x * x
    z = y * y
  g.gradient(z, x)
  with self.assertRaisesRegexp(RuntimeError, 'persistent'):
    g.gradient(y, x)

def run_fn():
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    # The gradient will be finite on the first replica, and infinite on the
    # second.
    rep_ctx = distribution_strategy_context.get_replica_context()
    if rep_ctx.replica_id_in_sync_group == rep_ctx.num_replicas_in_sync - 1:
      y = x * np.inf
    else:
      y = x * 2
  return g.gradient(y, x)

def test_fixed_scaling_no_change_non_finite_gradient(self, non_finite_term,
                                                     is_non_finite):
  loss_scale = loss_scale_module.FixedLossScale(32)
  x = constant_op.constant(1.0)
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    g.watch(x)
    y = x * non_finite_term
  dy_dx = g.gradient(y, x)
  self.assertTrue(is_non_finite(self.evaluate(dy_dx)))
  self.assertEqual(self.evaluate(loss_scale()), 32.0)

def test_dynamic_loss_scaling_inf_target_post_scale(self):
  loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32.0)
  x = constant_op.constant(3.0 * (10**37))
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    g.watch(x)
    y = x * 3.0  # target will be inf after scaling
  dy_dx = g.gradient(y, x)
  self.assertAllClose(self.evaluate(dy_dx), 3.0)
  self.assertEqual(self.evaluate(loss_scale()), 32.0)

def test_persistent_tapes(self, loss_scale):
  x = constant_op.constant(3.0)
  with lsgt.LossScaleGradientTape(loss_scale(32), persistent=True) as g:
    g.watch(x)
    y = x * x
    z = y * y
  dz_dx = g.gradient(z, x)
  self.assertEqual(self.evaluate(dz_dx), 108.0)
  dy_dx = g.gradient(y, x)
  self.assertEqual(self.evaluate(dy_dx), 6.0)

def run_fn():
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    y = x * (3.0 * (10**37))  # grad will be inf after scaling
  return g.gradient(y, x)

def run_fn():
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    y = x * non_finite_term
  return g.gradient(y, x)

def run_fn():
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    # x6 will have a None gradient because we do not watch it
    g.watch(x5)
    y = x1 * x3 * x5 * x6
  return g.gradient(y, [x1, x2, [x3, [x4], x5], x6])

def run_fn():
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    y = x * 3.0  # target will be inf after scaling
  return g.gradient(y, x)

def test_passing_non_loss_scale_raises_error(self):
  with self.assertRaisesRegexp(
      ValueError,
      '`loss_scale` must be an instance of LossScale, but got: 2.0'):
    lsgt.LossScaleGradientTape(2.0)

def run_fn():
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    y = x * x
  return g.gradient(y, x, output_gradients=constant_op.constant(2.0))

def run_fn():
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    y = x * x
  return g.gradient(y, x)

def run_fn():
  with lsgt.LossScaleGradientTape(loss_scale) as g:
    g.watch(x5)
    y = x1 * x2 * x3 * x4 * x5
  return g.gradient(y, [x1, x2, x3, x4, x5])