Пример #1
0
class MixedPrecisionTest(absltest.TestCase):
    """Tests for per-module-class mixed precision policies.

    Each test builds modules inside a function that is handed to
    `transform_and_run_once`, then compares the reported values against
    the dtypes named in the policy string.
    """

    @with_policy(InnerModule, jmp.get_policy('p=f16,c=f32,o=f16'))
    def test_set_global_policy(self):
        """A policy attached to `InnerModule` is visible in all three dtypes."""

        def build_and_call():
            inner = InnerModule()
            output = inner()
            return output, inner.w

        init_params, reported = transform_and_run_once(build_and_call)
        output_dtype, compute_dtype = reported

        # Output dtype (o=f16), compute dtype seen inside the module (c=f32)
        # and the dtype of the stored parameter (p=f16).
        self.assertEqual(output_dtype, jnp.float16)
        self.assertEqual(compute_dtype, jnp.float32)
        self.assertEqual(init_params['inner_module'], {'w': jnp.float16})

    @with_policy(InnerModule, jmp.get_policy('p=f16,c=f32,o=f16'))
    def test_clear_global_policy(self):
        """Once the policy is cleared every reported dtype is bfloat16."""

        def build_and_call():
            inner = InnerModule()
            output = inner()
            return output, inner.w

        # Remove the policy that the decorator installed before running.
        mixed_precision.clear_policy(InnerModule)

        init_params, reported = transform_and_run_once(build_and_call)
        output_dtype, compute_dtype = reported

        # NOTE(review): bfloat16 is presumably InnerModule's own default
        # dtype; its definition is not visible here.
        self.assertEqual(output_dtype, jnp.bfloat16)
        self.assertEqual(compute_dtype, jnp.bfloat16)
        self.assertEqual(init_params['inner_module'], {'w': jnp.bfloat16})

    @with_policy(OuterModule, jmp.get_policy('p=f32,c=f16,o=f32'))
    @with_policy(InnerModule, jmp.get_policy('p=f16,c=f32,o=f32'))
    def test_set_global_policy_nested(self):
        """Outer and inner modules each follow their own policy."""

        def build_and_call():
            outer = OuterModule()
            outer_output = outer()
            return (outer_output, outer.inner_ret, outer.w, outer.inner.w)

        init_params, reported = transform_and_run_once(build_and_call)
        outer_out, inner_out, outer_compute, inner_compute = reported

        # Values returned from a module use that module's output type.
        self.assertEqual(outer_out, jnp.float32)
        self.assertEqual(inner_out, jnp.float32)
        # Within a module the policy's compute type is in effect.
        self.assertEqual(outer_compute, jnp.float16)
        self.assertEqual(inner_compute, jnp.float32)
        # Parameters produced by init use the policy's param type.
        self.assertEqual(init_params['outer_module'], {'w': jnp.float32})
        self.assertEqual(
            init_params['outer_module/inner_module'], {'w': jnp.float16})
Пример #2
0
  def test_current_policy(self):
    """Checks what `current_policy()` reports at each level of nesting.

    A policy is registered for `Bar` only. `Baz` (which calls `Bar`) must see
    no policy, while `Bar` and the `Foo` it calls must both see Bar's policy.
    """
    expected = jmp.get_policy('p=f16,c=f32,o=f16')

    def expect_policy_active():
      # Closes over the test case's `self`, not the module's.
      self.assertEqual(mixed_precision.current_policy(), expected)

    def expect_no_policy():
      self.assertIsNone(mixed_precision.current_policy())

    class Foo(module.Module):
      # Has no policy of its own; only ever called from inside Bar.

      def __call__(self):
        expect_policy_active()

    class Bar(module.Module):
      # The only class with a registered policy.

      def __call__(self):
        expect_policy_active()
        Foo()()
        # The policy must still be reported after the nested call returns.
        expect_policy_active()

    class Baz(module.Module):
      # Outermost module; no policy before or after calling Bar.

      def __call__(self):
        expect_no_policy()
        Bar()()
        expect_no_policy()

    mixed_precision.set_policy(Bar, expected)
    Baz()()
Пример #3
0
 def test_get_policy(self):
   """Round-trips a policy through set_policy / get_policy / clear_policy.

   `get_policy` must return None before a policy is set and again after it
   is cleared, and must return the installed policy in between.
   """
   self.assertIsNone(mixed_precision.get_policy(InnerModule))
   policy = jmp.get_policy('p=f16,c=f32,o=f16')
   mixed_precision.set_policy(InnerModule, policy)
   try:
     self.assertEqual(mixed_precision.get_policy(InnerModule), policy)
   finally:
     # Always remove the policy, even if the assertion above fails, so that
     # it cannot leak into other tests that use InnerModule.
     mixed_precision.clear_policy(InnerModule)
   self.assertIsNone(mixed_precision.get_policy(InnerModule))
Пример #4
0
  def test_set_policy_factory(self):
    """Two classes produced by the same factory keep separate policies.

    Both classes share the name `MyModule`; each is given a different output
    dtype and the dtype of each instance's output is checked.
    """

    def factory():
      # Every call creates a brand new identity-module class.
      class MyModule(module.Module):

        def __call__(self, x):
          return x

      return MyModule

    f16_cls, bf16_cls = factory(), factory()

    mixed_precision.set_policy(f16_cls, jmp.get_policy('o=f16'))
    mixed_precision.set_policy(bf16_cls, jmp.get_policy('o=bf16'))

    scalar = jnp.ones([])
    f16_out = f16_cls()(scalar)
    self.assertEqual(f16_out.dtype, jnp.float16)
    bf16_out = bf16_cls()(scalar)
    self.assertEqual(bf16_out.dtype, jnp.bfloat16)
Пример #5
0
  def test_policy_for_reloaded_class(self):
    """A policy set on `ConvND` still applies after its module is reloaded.

    The policy is installed on the class object from before the reload; the
    class object obtained after `importlib.reload` is the one instantiated.
    """
    conv_local = conv

    policy = jmp.get_policy('p=f16,c=f32,o=f16')
    mixed_precision.set_policy(conv_local.ConvND, policy)
    conv_local = importlib.reload(conv)

    params, y = transform_and_run_once(
        lambda: conv_local.ConvND(2, 1, 1)(jnp.ones([1, 1, 1, 1])))

    # `jax.tree_map` was a deprecated alias that newer JAX releases removed;
    # `jax.tree_util.tree_map` is the supported spelling.
    jax.tree_util.tree_map(
        lambda p: self.assertEqual(p, jnp.float16), params)
    self.assertEqual(y, jnp.float16)
Пример #6
0
  def test_policy_with_interceptor(self):
    """A user interceptor and the mixed precision policy work together.

    The interceptor records every context it is invoked with and otherwise
    forwards the call unchanged; the policy must still set the output dtype.
    """
    recorded_contexts = []

    def recording_interceptor(next_f, args, kwargs, context):
      recorded_contexts.append(context)
      return next_f(*args, **kwargs)

    # Resetting the thread local state guarantees that the mixed precision
    # interceptor gets installed by the set_policy call below (installation
    # only happens on the first call).
    mixed_precision.reset_thread_local_state_for_test()

    f16_policy = jmp.get_policy('p=f16,c=f32,o=f16')
    with module.intercept_methods(recording_interceptor):
      mixed_precision.set_policy(OuterModule, f16_policy)
      output = OuterModule()()
      self.assertEqual(output.dtype, jnp.float16)

    # One record each for: Outer.init, Outer.call, Inner.init, Inner.call
    self.assertLen(recorded_contexts, 4)
Пример #7
0
# Boolean command-line flags, all defaulting to False and registered with
# empty help text.
# NOTE(review): presumably controls skipping updates with non-finite values
# under mixed precision -- confirm at the use site (not visible here).
flags.DEFINE_bool('mp_skip_nonfinite', False, help='')
# Read by `_forward` to decide how the input images are handled.
flags.DEFINE_bool('dataset_transpose', False, help='')
# NOTE(review): presumably substitutes an all-zeros dataset -- confirm at the
# use site (not visible here).
flags.DEFINE_bool('dataset_zeros', False, help='')
# Shorthand for the global flag values object.
FLAGS = flags.FLAGS

# Type alias: a mapping from a string name to a JAX array.
Scalars = Mapping[str, jnp.ndarray]


class TrainState(NamedTuple):
    # Immutable bundle of the values that make up the training state.
    # Haiku parameters.
    params: hk.Params
    # Haiku state.
    state: hk.State
    # Optax optimizer state.
    opt_state: optax.OptState
    # jmp loss scale (see `get_initial_loss_scale` for its initial value).
    loss_scale: jmp.LossScale


def get_policy():
    """Return the jmp policy parsed from the `mp_policy` flag."""
    return jmp.get_policy(FLAGS.mp_policy)


def get_bn_policy():
    """Return the jmp policy parsed from the `mp_bn_policy` flag."""
    return jmp.get_policy(FLAGS.mp_bn_policy)


def get_initial_loss_scale() -> jmp.LossScale:
    """Build the initial loss scale selected by the `mp_scale_type` flag.

    The class is looked up on `jmp` as `<mp_scale_type>LossScale`.
    `jmp.NoOpLossScale` is constructed without arguments; any other class is
    constructed with the value of the `mp_scale_value` flag.
    """
    scale_cls = getattr(jmp, f'{FLAGS.mp_scale_type}LossScale')
    if scale_cls is jmp.NoOpLossScale:
        return scale_cls()
    return scale_cls(FLAGS.mp_scale_value)


def _forward(
    batch: dataset.Batch,
    is_training: bool,
) -> jnp.ndarray:
    """Forward application of the resnet."""
    images = batch['images']
    if FLAGS.dataset_transpose:
Пример #8
0
class MixedPrecisionTest(absltest.TestCase):
  """Tests for setting, reading and clearing per-class precision policies."""

  def test_get_policy(self):
    """Round-trips a policy through set_policy / get_policy / clear_policy."""
    self.assertIsNone(mixed_precision.get_policy(InnerModule))
    policy = jmp.get_policy('p=f16,c=f32,o=f16')
    mixed_precision.set_policy(InnerModule, policy)
    try:
      self.assertEqual(mixed_precision.get_policy(InnerModule), policy)
    finally:
      # Always remove the policy, even if the assertion above fails, so that
      # it cannot leak into other tests that use InnerModule.
      mixed_precision.clear_policy(InnerModule)
    self.assertIsNone(mixed_precision.get_policy(InnerModule))

  @test_utils.transform_and_run
  def test_current_policy(self):
    """Checks what `current_policy()` reports at each level of nesting.

    Only `Bar` has a policy: `Baz` must see none, while `Bar` and the `Foo`
    it calls must both see Bar's policy.
    """
    policy = jmp.get_policy('p=f16,c=f32,o=f16')
    test = self

    class Foo(module.Module):

      def __call__(self):
        test.assertEqual(mixed_precision.current_policy(), policy)

    class Bar(module.Module):

      def __call__(self):
        test.assertEqual(mixed_precision.current_policy(), policy)
        Foo()()
        test.assertEqual(mixed_precision.current_policy(), policy)

    class Baz(module.Module):

      def __call__(self):
        test.assertIsNone(mixed_precision.current_policy())
        Bar()()
        test.assertIsNone(mixed_precision.current_policy())

    mixed_precision.set_policy(Bar, policy)
    Baz()()

  def test_set_global_policy(self):
    """Policy set on a module-level class."""
    self.assertGlobalPolicy(InnerModule)

  def test_set_global_policy_inner_class(self):
    """Policy set on a class nested inside another class."""
    self.assertGlobalPolicy(InnerModule.InnerInnerModule)

  def test_set_global_policy_local_class(self):
    """Policy set on a subclass defined locally inside the test."""
    class LocalModule(InnerModule):
      pass

    self.assertGlobalPolicy(LocalModule)

  def assertGlobalPolicy(self, cls):
    """Runs `assertGlobalPolicy_inner` for `cls` wrapped in `with_policy`.

    Args:
      cls: module class that receives the 'p=f16,c=f32,o=f16' policy.
    """
    policy = jmp.get_policy('p=f16,c=f32,o=f16')
    with_policy(cls, policy)(self.assertGlobalPolicy_inner)(cls)

  def assertGlobalPolicy_inner(self, cls):
    """Checks output, compute and param dtypes of a `cls` instance.

    Args:
      cls: module class to instantiate under the name 'inner_module'.
    """
    def f():
      mod = cls(name='inner_module')
      return mod(), mod.w

    params, (ret, w) = transform_and_run_once(f)

    self.assertEqual(ret, jnp.float16)
    self.assertEqual(w, jnp.float32)
    self.assertEqual(params['inner_module'], {'w': jnp.float16})

  @test_utils.transform_and_run
  def test_set_policy_factory(self):
    """Two classes produced by the same factory keep separate policies."""
    def factory():
      class MyModule(module.Module):

        def __call__(self, x):
          return x

      return MyModule

    cls1 = factory()
    cls2 = factory()

    mixed_precision.set_policy(cls1, jmp.get_policy('o=f16'))
    mixed_precision.set_policy(cls2, jmp.get_policy('o=bf16'))
    x = jnp.ones([])
    self.assertEqual(cls1()(x).dtype, jnp.float16)
    self.assertEqual(cls2()(x).dtype, jnp.bfloat16)

  @with_policy(InnerModule, jmp.get_policy('p=f16,c=f32,o=f16'))
  def test_clear_global_policy(self):
    """Once the policy is cleared every reported dtype is bfloat16."""
    def f():
      mod = InnerModule()
      return mod(), mod.w

    # Remove the policy that the decorator installed before running.
    mixed_precision.clear_policy(InnerModule)

    params, (ret, w) = transform_and_run_once(f)

    self.assertEqual(ret, jnp.bfloat16)
    self.assertEqual(w, jnp.bfloat16)
    self.assertEqual(params['inner_module'], {'w': jnp.bfloat16})

  @with_policy(OuterModule, jmp.get_policy('p=f32,c=f16,o=f32'))
  @with_policy(InnerModule, jmp.get_policy('p=f16,c=f32,o=f32'))
  def test_set_global_policy_nested(self):
    """Outer and inner modules each follow their own policy."""
    def f():
      outer = OuterModule()
      outer_ret = outer()
      return outer_ret, outer.inner_ret, outer.w, outer.inner.w

    params, (outer_ret, inner_ret, outer_w, inner_w) = transform_and_run_once(f)

    # The return type of the modules should use the output type of the module.
    self.assertEqual(outer_ret, jnp.float32)
    self.assertEqual(inner_ret, jnp.float32)
    # Inside the module we should use the compute type of the policy.
    self.assertEqual(outer_w, jnp.float16)
    self.assertEqual(inner_w, jnp.float32)
    # The parameters returned from init should use the param type of the policy.
    self.assertEqual(params['outer_module'], {'w': jnp.float32})
    self.assertEqual(params['outer_module/inner_module'], {'w': jnp.float16})

  def test_policy_for_reloaded_class(self):
    """A policy set on `ConvND` still applies after its module is reloaded."""
    conv_local = conv

    policy = jmp.get_policy('p=f16,c=f32,o=f16')
    mixed_precision.set_policy(conv_local.ConvND, policy)
    conv_local = importlib.reload(conv)

    params, y = transform_and_run_once(
        lambda: conv_local.ConvND(2, 1, 1)(jnp.ones([1, 1, 1, 1])))

    # `jax.tree_map` was a deprecated alias that newer JAX releases removed;
    # `jax.tree_util.tree_map` is the supported spelling.
    jax.tree_util.tree_map(
        lambda p: self.assertEqual(p, jnp.float16), params)
    self.assertEqual(y, jnp.float16)

  @test_utils.transform_and_run
  def test_policy_with_interceptor(self):
    """A user interceptor and the mixed precision policy work together."""
    sidechannel = []
    def my_interceptor(next_f, args, kwargs, context):
      sidechannel.append(context)
      return next_f(*args, **kwargs)

    # We need this to make sure that the mixed precision interceptor is
    # installed when we call set_policy (this only happens the first call).
    mixed_precision.reset_thread_local_state_for_test()

    policy = jmp.get_policy('p=f16,c=f32,o=f16')
    with module.intercept_methods(my_interceptor):
      mixed_precision.set_policy(OuterModule, policy)
      x = OuterModule()()
      self.assertEqual(x.dtype, jnp.float16)
    # Outer.init, Outer.call, Inner.init, Inner.call
    self.assertLen(sidechannel, 4)
Пример #9
0
 def assertGlobalPolicy(self, cls):
   """Runs `assertGlobalPolicy_inner` for `cls` wrapped in `with_policy`.

   Args:
     cls: module class that receives the 'p=f16,c=f32,o=f16' policy.
   """
   f16_policy = jmp.get_policy('p=f16,c=f32,o=f16')
   apply_policy = with_policy(cls, f16_policy)
   checked = apply_policy(self.assertGlobalPolicy_inner)
   checked(cls)