Example #1
def f(x):
  # `n_devices`, `fastmath`, `jax`, and `jnp` come from the enclosing
  # trax scope.
  if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX):
    # On JAX, put one replica of `x` on each local device.
    return jax.device_put_replicated(x, jax.local_devices())
  elif n_devices > 1:
    # On other backends, replicate by prepending a device axis.
    return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
  else:
    return x
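A minimal, runnable sketch of how such a replication helper is typically wired up. The `n_devices` binding and the `put_to_devices` name are illustrative assumptions; in the example above they come from the enclosing trax scope.

import jax
import jax.numpy as jnp

# Illustrative stand-in for the closure variable in the example above.
n_devices = jax.local_device_count()

def put_to_devices(x):  # hypothetical name
  if n_devices > 1:
    # One replica of `x` per local device; the leading axis is the
    # device axis.
    return jax.device_put_replicated(x, jax.local_devices())
  return x

params = jnp.ones((4, 4))
replicated = put_to_devices(params)  # (n_devices, 4, 4) if n_devices > 1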
Example #2
def f(x):
  # `n_devices` and `fastmath` come from the enclosing trax scope;
  # `_multi_device_put` is a private trax helper.
  if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX):
    # On JAX, replicate `x` across all local devices.
    return _multi_device_put(x)
  elif n_devices > 1:
    # On other backends, replicate by prepending a device axis.
    return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
  else:
    return x
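Example #2 swaps in trax's private helper `_multi_device_put`, whose implementation isn't shown here. Functionally it replicates `x` across the local devices, so under that assumption a rough stand-in is:

import jax

def multi_device_put(x, devices=None):
  # Sketch only: the real trax `_multi_device_put` may be more
  # memory-efficient, but the observable result is one replica of
  # `x` per device, as with `jax.device_put_replicated` above.
  devices = devices or jax.local_devices()
  return jax.device_put_replicated(x, devices)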
Example #3
def learning_rate(step):
  """Maps a training step to a learning-rate value.

  `factors`, `constant`, `warmup_steps`, `decay_factor`,
  `steps_per_decay`, and `steps_per_cycle` are captured from the
  enclosing scope.
  """
  # The named factors multiply together into a single rate.
  ret = 1.0
  for name in factors:
    if name == 'constant':
      ret *= constant
    elif name == 'linear_warmup':
      ret *= jnp.minimum(1.0, step / warmup_steps)
    elif name == 'rsqrt_decay':
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
    elif name == 'rsqrt_normalized_decay':
      # Like rsqrt_decay, but normalized to 1.0 at step == warmup_steps.
      ret *= jnp.sqrt(warmup_steps)
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
    elif name == 'decay_every':
      ret *= (decay_factor ** (step // steps_per_decay))
    elif name == 'cosine_decay':
      progress = jnp.maximum(
          0.0, (step - warmup_steps) / float(steps_per_cycle))
      ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
    else:
      raise ValueError('Unknown factor %s.' % name)
  ret = jnp.asarray(ret, dtype=jnp.float32)
  return {'learning_rate': ret}
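To see how the factors compose, here is a self-contained way to exercise a schedule like this one. The factor list and hyperparameter values below are illustrative, and the sketch returns the rate directly rather than the {'learning_rate': ...} dict used above.

import jax.numpy as jnp

# Illustrative stand-ins for the closure variables in the example.
factors = ['constant', 'linear_warmup', 'rsqrt_decay']
constant = 0.1
warmup_steps = 400

def learning_rate(step):
  ret = 1.0
  for name in factors:
    if name == 'constant':
      ret *= constant
    elif name == 'linear_warmup':
      ret *= jnp.minimum(1.0, step / warmup_steps)
    elif name == 'rsqrt_decay':
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
  return jnp.asarray(ret, dtype=jnp.float32)

for step in (1, 100, 400, 10000):
  print(step, learning_rate(step))
# The rate ramps up linearly over the first 400 steps, peaks at
# 0.1 / sqrt(400) = 0.005, then decays as 1 / sqrt(step).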
Example #4
  def testTrain(self, layer_id, rng_updater_id, batch_size, trax_has_weights,
                explicit_build, use_model):
    """Tests training (forward and backward pass) for AsKeras.

    Args:
      layer_id: an integer, the index into `_LAYERS`.
      rng_updater_id: an integer, the index into `_RNG_UPDATERS`.
      batch_size: an integer or `None`, the value for the `batch_size` argument
        in `AsKeras.__init__`.
      trax_has_weights: bool, whether to make the trax layer contain weights at
        the time when `AsKeras.build` is called.
      explicit_build: bool, whether to explicitly call `AsKeras.build`.
      use_model: bool, whether to build a `tf.keras.Model` out of the
        `AsKeras` layer and use the model to do the training instead of
        the bare layer. If `True`, we will also test checkpointing and restoring
        using the model.
    """
    with trax.fastmath.use_backend("tensorflow-numpy"):
      make_trax_layer, input_shapes_no_batch, dtype, allow_none_batch = (
          _LAYERS[layer_id])
      # We make a fresh trax layer for each test case, so that different test
      # cases won't interfere with each other.
      trax_layer = make_trax_layer()
      if not allow_none_batch and batch_size is None:
        self.skipTest("This Trax layer can't handle None batch size.")
      rng_updater = _RNG_UPDATERS[rng_updater_id]
      input_shapes = math_lib.nested_map(
          lambda s: [batch_size] + s, input_shapes_no_batch)
      input_sig = trax2keras.tensor_shapes_to_shape_dtypes(input_shapes, dtype)
      initializer_rng = math_lib.random.get_prng(765)
      weights, state = trax_layer.init(input_sig, rng=initializer_rng)
      generator = tf.random.Generator.from_seed(567)
      def get_inputs():
        return dummy_inputs(generator, input_sig)
      if trax_has_weights:
        trax_layer(to_arrays(get_inputs()), weights=weights, state=state)
      rng = math_lib.random.get_prng(1234)
      keras_layer = trax2keras.AsKeras(
          trax_layer, batch_size=batch_size, initializer_rng=initializer_rng,
          rng=rng, rng_updater=rng_updater)
      if explicit_build:
        keras_layer.build(input_shapes)
      if use_model:
        x = tf.keras.Input(shape=input_shapes_no_batch, dtype=dtype)
        y = keras_layer(x)
        keras_model = tf.keras.Model(inputs=x, outputs=y)
      lr = 0.1  # learning rate
      # Run a few training steps, applying the same update rule to the
      # trax weights and the Keras variables so the two stay comparable.
      for _ in range(3):
        inputs = get_inputs()
        with tf.GradientTape() as trax_tape:
          # `w`, not `x`, to avoid shadowing the Keras input above.
          trax_tape.watch([w.data for w in tf.nest.flatten(weights)])
          trax_outputs, state = trax_layer.pure_fn(
              to_arrays(inputs), weights=weights, state=state, rng=rng)
        trax_grads = trax_tape.gradient(*to_tensors([trax_outputs, weights]))
        # `g` may be `tf.IndexedSlices`, so we need to `convert_to_tensor`
        # before multiplication.
        weights = tf.nest.map_structure(
            lambda w, g: w + jnp.asarray(lr * tf.convert_to_tensor(g), w.dtype),
            weights, trax_grads)
        rng = rng_updater(rng)
        with tf.GradientTape() as keras_tape:
          if use_model:
            keras_outputs = keras_model(inputs)
          else:
            keras_outputs = keras_layer(inputs)
        if isinstance(keras_outputs, tuple) and len(keras_outputs) == 1:
          keras_outputs = keras_outputs[0]
        self.assertAllClose(to_tensors(trax_outputs), keras_outputs, atol=1e-5)
        keras_grads = keras_tape.gradient(keras_outputs,
                                          keras_layer.trainable_variables)
        tf.nest.map_structure(
            lambda v, g: v.assign_add(  # pylint: disable=g-long-lambda
                tf.cast(lr * tf.convert_to_tensor(g), v.dtype)),
            keras_layer.trainable_variables, keras_grads)
        self.assertAllClose(
            to_tensors(weights), read_values(keras_layer._weights),
            rtol=2e-6, atol=3.5e-4 if has_gpu() else 1e-6)
        self.assertAllClose(to_tensors(state), read_values(keras_layer._state))
        self.assertAllClose(to_tensors(rng), read_values(keras_layer._rng))
      if use_model:
        fname = os.path.join(self.get_temp_dir(), "checkpoint")
        keras_model.save(fname)
        loaded_model = tf.keras.models.load_model(fname)
        for _ in range(2):
          inputs = get_inputs()
          self.assertAllClose(keras_model(inputs), loaded_model(inputs))
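Stripped of the test harness, the pattern this test exercises is: wrap a trax layer with `trax2keras.AsKeras`, use it inside a `tf.keras.Model`, run it, then save and reload. A minimal sketch, with the layer choice and shapes as illustrative assumptions:

import tensorflow as tf
import trax
from trax import trax2keras

with trax.fastmath.use_backend("tensorflow-numpy"):
  # Any trax layer can be wrapped; Relu is just a simple choice.
  keras_layer = trax2keras.AsKeras(trax.layers.Relu(), batch_size=8)
  x = tf.keras.Input(shape=(16,), dtype=tf.float32)
  model = tf.keras.Model(inputs=x, outputs=keras_layer(x))
  out = model(tf.ones([8, 16]))  # forward pass through the trax layer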