Example #1
0
 def learning_rate(step):
     """Map a training step to its learning rate.

     The rate starts at 1.0 and is adjusted once per entry of `factors`
     (captured from the enclosing scope), in the order the entries appear:

       'constant'               -> multiply by `constant`.
       'linear_warmup'          -> multiply by min(1, step / warmup_steps).
       'rsqrt_decay'            -> divide by sqrt(max(step, warmup_steps)).
       'rsqrt_normalized_decay' -> multiply by sqrt(warmup_steps), then divide
                                   by sqrt(max(step, warmup_steps)); the net
                                   factor is 1 (up to rounding) while
                                   step <= warmup_steps.
       'decay_every'            -> multiply by
                                   decay_factor ** (step // steps_per_decay).
       'cosine_decay'           -> multiply by a half-cosine that goes from 1
                                   towards 0 over each `steps_per_cycle` steps
                                   after warmup, restarting every cycle.

     Args:
       step: the training step (a Python number or numpy array).

     Returns:
       A dict {'learning_rate': rate}, where rate is a float32 numpy array.

     Raises:
       ValueError: if `factors` contains a name not listed above.
     """
     rate = 1.0
     for factor in factors:
         if factor == 'constant':
             rate *= constant
             continue
         if factor == 'linear_warmup':
             # NOTE(review): divides by warmup_steps, so warmup_steps == 0 is
             # not handled here.
             rate *= np.minimum(1.0, step / warmup_steps)
             continue
         if factor == 'rsqrt_decay':
             rate /= np.sqrt(np.maximum(step, warmup_steps))
             continue
         if factor == 'rsqrt_normalized_decay':
             # Two separate in-place operations (not one combined ratio) so the
             # floating-point rounding follows multiply-then-divide order.
             rate *= np.sqrt(warmup_steps)
             rate /= np.sqrt(np.maximum(step, warmup_steps))
             continue
         if factor == 'decay_every':
             num_decays = step // steps_per_decay
             rate *= decay_factor**num_decays
             continue
         if factor == 'cosine_decay':
             elapsed = (step - warmup_steps) / float(steps_per_cycle)
             progress = np.maximum(0.0, elapsed)
             # `% 1.0` restarts the cosine at the beginning of every cycle.
             rate *= (0.5 * (1.0 + np.cos(np.pi * (progress % 1.0))))
             continue
         raise ValueError('Unknown factor %s.' % factor)
     return {'learning_rate': np.asarray(rate, dtype=np.float32)}
Example #2
0
    def testTrain(self, layer_id, rng_updater_id, batch_size, trax_has_weights,
                  explicit_build, use_model):
        """Tests training (forward and backward pass) for TraxKerasLayer.

        Runs a few training steps side by side on the bare trax layer (driven
        by hand through `pure_fn`) and on a `TraxKerasLayer` wrapping it. The
        same gradient update is applied to both. After each step the test
        checks that outputs, weights, state and RNG agree.

        Args:
          layer_id: an integer, the index into `_LAYERS`.
          rng_updater_id: an integer, the index into `_RNG_UPDATERS`.
          batch_size: an integer or `None`, the value for the `batch_size`
            argument in `TraxKerasLayer.__init__`.
          trax_has_weights: bool, whether to make the trax layer contain
            weights at the time when `TraxKerasLayer.build` is called.
          explicit_build: bool, whether to explicitly call
            `TraxKerasLayer.build`.
          use_model: bool, whether to build a `tf.keras.Model` out of the
            `TraxKerasLayer` layer and use the model to do the training
            instead of the bare layer. If `True`, we will also test
            checkpointing and restoring using the model.
        """
        # Run the whole test under `math_lib`'s "tf" backend.
        with math_lib.use_backend("tf"):
            # Each `_LAYERS` entry is (layer factory, input shapes without the
            # batch dimension, input dtype, whether a `None` batch size is
            # supported).
            make_trax_layer, input_shapes_no_batch, dtype, allow_none_batch = (
                _LAYERS[layer_id])
            # We make a fresh trax layer for each test case, so that different test
            # cases won't interfere with each other.
            trax_layer = make_trax_layer()
            # NOTE(review): this skip check could run before `make_trax_layer()`
            # to avoid constructing a layer for a case that is skipped anyway.
            if not allow_none_batch and batch_size is None:
                self.skipTest("This Trax layer can't handle None batch size.")
            rng_updater = _RNG_UPDATERS[rng_updater_id]
            # Prepend the (possibly `None`) batch dimension to every input shape.
            input_shapes = math_lib.nested_map(lambda s: [batch_size] + s,
                                               input_shapes_no_batch)
            input_sig = trax2keras.tensor_shapes_to_shape_dtypes(
                input_shapes, dtype)
            # Fixed-seed initializer RNG. The same RNG is also handed to
            # `TraxKerasLayer` below, presumably so that both sides start from
            # identical weights.
            initializer_rng = math_lib.random.get_prng(765)
            weights, state = trax_layer.init(input_sig, rng=initializer_rng)
            # Seeded, stateful TF generator. `dummy_inputs` presumably draws
            # from it, so the order of `get_inputs()` calls below matters.
            generator = tf.random.Generator.from_seed(567)

            def get_inputs():
                return dummy_inputs(generator, input_sig)

            if trax_has_weights:
                # Call the trax layer once with explicit weights/state so it
                # already holds weights when the Keras layer gets built.
                trax_layer(to_arrays(get_inputs()),
                           weights=weights,
                           state=state)
            # Fixed-seed forward-pass RNG, shared by the trax reference loop and
            # the Keras layer.
            rng = math_lib.random.get_prng(1234)
            keras_layer = trax2keras.TraxKerasLayer(
                trax_layer,
                batch_size=batch_size,
                initializer_rng=initializer_rng,
                rng=rng,
                rng_updater=rng_updater)
            if explicit_build:
                keras_layer.build(input_shapes)
            if use_model:
                # Wrap the layer in a functional Keras model. `shape` here
                # excludes the batch dimension.
                x = tf.keras.Input(shape=input_shapes_no_batch, dtype=dtype)
                y = keras_layer(x)
                keras_model = tf.keras.Model(inputs=x, outputs=y)
            lr = 0.1  # learning rate
            # A few training steps. After each one, the Keras side must match
            # the hand-driven trax reference.
            for _ in range(3):
                inputs = get_inputs()
                # Reference forward pass through the pure trax layer.
                with tf.GradientTape() as trax_tape:
                    # NOTE(review): `x` here is the comprehension variable. It
                    # does not overwrite the Keras `Input` `x` above (comprehension
                    # scope), but a distinct name would read more clearly.
                    trax_tape.watch([x.data for x in tf.nest.flatten(weights)])
                    trax_outputs, state = trax_layer.pure_fn(to_arrays(inputs),
                                                             weights=weights,
                                                             state=state,
                                                             rng=rng)
                trax_grads = trax_tape.gradient(
                    *to_tensors([trax_outputs, weights]))
                # `g` may be `tf.IndexedSlices`, so we need to `convert_to_tensor`
                # before multiplication.
                # The update adds `lr * g` rather than subtracting it. The Keras
                # update below uses the same sign (`assign_add`), so the two
                # sides remain directly comparable.
                weights = tf.nest.map_structure(
                    lambda w, g: w + np.asarray(lr * tf.convert_to_tensor(g), w
                                                .dtype), weights, trax_grads)
                # Advance the reference RNG; the Keras layer's RNG is checked
                # against this value below.
                rng = rng_updater(rng)
                with tf.GradientTape() as keras_tape:
                    if use_model:
                        keras_outputs = keras_model(inputs)
                    else:
                        keras_outputs = keras_layer(inputs)
                # Unwrap a single-element tuple so it can be compared against
                # the trax output.
                if isinstance(keras_outputs,
                              tuple) and len(keras_outputs) == 1:
                    keras_outputs = keras_outputs[0]
                self.assertAllClose(to_tensors(trax_outputs), keras_outputs)
                keras_grads = keras_tape.gradient(
                    keras_outputs, keras_layer.trainable_variables)
                # Apply the same `w + lr * g` update to the Keras variables.
                tf.nest.map_structure(
                    lambda v, g: v.assign_add(  # pylint: disable=g-long-lambda
                        tf.cast(lr * tf.convert_to_tensor(g), v.dtype)),
                    keras_layer.trainable_variables,
                    keras_grads)
                # Weights, state and RNG must all track the trax reference.
                self.assertAllClose(to_tensors(weights),
                                    read_values(keras_layer._weights),
                                    rtol=2e-6,
                                    atol=5e-5)
                self.assertAllClose(to_tensors(state),
                                    read_values(keras_layer._state))
                self.assertAllClose(to_tensors(rng),
                                    read_values(keras_layer._rng))
            if use_model:
                # Save the model, reload it, and check that the reloaded copy
                # gives the same outputs on fresh inputs.
                fname = os.path.join(self.get_temp_dir(), "checkpoint")
                keras_model.save(fname)
                loaded_model = tf.keras.models.load_model(fname)
                for _ in range(2):
                    inputs = get_inputs()
                    self.assertAllClose(keras_model(inputs),
                                        loaded_model(inputs))