def learning_rate(step):
    """Map a training step to its learning rate.

    The rate starts at 1.0 and is scaled once per entry of `factors`, in
    order. `factors` and the schedule parameters (`constant`,
    `warmup_steps`, `decay_factor`, `steps_per_decay`, `steps_per_cycle`)
    are not arguments; they are looked up in the surrounding scope.

    Args:
      step: the current training step.

    Returns:
      A dict with the single key 'learning_rate' whose value is the rate
      converted to a float32 NumPy array.

    Raises:
      ValueError: if `factors` contains a name that is not handled below.
    """
    rate = 1.0
    for factor in factors:
        if factor == 'constant':
            rate *= constant
            continue
        if factor == 'linear_warmup':
            # Ramps from 0 to 1 over the warmup steps, then stays at 1.
            warmup_fraction = step / warmup_steps
            rate *= np.minimum(1.0, warmup_fraction)
            continue
        if factor == 'rsqrt_decay':
            # The step is floored at `warmup_steps` before the square root.
            floored_step = np.maximum(step, warmup_steps)
            rate /= np.sqrt(floored_step)
            continue
        if factor == 'rsqrt_normalized_decay':
            # Same as 'rsqrt_decay' but pre-scaled by sqrt(warmup_steps), so
            # the factor equals 1 for every step up to `warmup_steps`.
            rate *= np.sqrt(warmup_steps)
            floored_step = np.maximum(step, warmup_steps)
            rate /= np.sqrt(floored_step)
            continue
        if factor == 'decay_every':
            completed_decays = step // steps_per_decay
            rate *= (decay_factor**completed_decays)
            continue
        if factor == 'cosine_decay':
            cycles_elapsed = (step - warmup_steps) / float(steps_per_cycle)
            progress = np.maximum(0.0, cycles_elapsed)
            # Only the fractional part of `progress` is used, so the cosine
            # restarts at the beginning of every cycle.
            phase = progress % 1.0
            rate *= (0.5 * (1.0 + np.cos(np.pi * phase)))
            continue
        raise ValueError('Unknown factor %s.' % factor)
    return {'learning_rate': np.asarray(rate, dtype=np.float32)}
def testTrain(self, layer_id, rng_updater_id, batch_size, trax_has_weights,
              explicit_build, use_model):
    """Checks that `TraxKerasLayer` trains in lockstep with its Trax layer.

    The bare Trax layer and its Keras wrapper are run through the same three
    forward/backward/update steps on identical inputs. After every step the
    outputs, weights, state and rng of the two are compared.

    Args:
      layer_id: an integer, the index into `_LAYERS`.
      rng_updater_id: an integer, the index into `_RNG_UPDATERS`.
      batch_size: an integer or `None`, passed as the `batch_size` argument
        of `TraxKerasLayer.__init__`.
      trax_has_weights: bool, whether the Trax layer should already contain
        weights by the time `TraxKerasLayer.build` is called.
      explicit_build: bool, whether to call `TraxKerasLayer.build`
        explicitly.
      use_model: bool, whether to wrap the `TraxKerasLayer` in a
        `tf.keras.Model` and train through the model instead of the bare
        layer. When `True`, saving and re-loading the model is tested too.
    """
    with math_lib.use_backend("tf"):
        layer_spec = _LAYERS[layer_id]
        (make_trax_layer, input_shapes_no_batch, dtype,
         allow_none_batch) = layer_spec
        # Every test case builds its own Trax layer so that cases cannot
        # influence one another.
        trax_layer = make_trax_layer()
        if batch_size is None and not allow_none_batch:
            self.skipTest("This Trax layer can't handle None batch size.")
        rng_updater = _RNG_UPDATERS[rng_updater_id]

        def with_batch_dim(shape):
            return [batch_size] + shape

        input_shapes = math_lib.nested_map(with_batch_dim,
                                           input_shapes_no_batch)
        input_sig = trax2keras.tensor_shapes_to_shape_dtypes(
            input_shapes, dtype)
        initializer_rng = math_lib.random.get_prng(765)
        weights, state = trax_layer.init(input_sig, rng=initializer_rng)
        generator = tf.random.Generator.from_seed(567)

        def draw_inputs():
            return dummy_inputs(generator, input_sig)

        if trax_has_weights:
            # Calling the layer with explicit weights/state before wrapping
            # it is how this test makes the Trax layer "contain weights".
            trax_layer(to_arrays(draw_inputs()), weights=weights, state=state)

        rng = math_lib.random.get_prng(1234)
        keras_layer = trax2keras.TraxKerasLayer(
            trax_layer,
            batch_size=batch_size,
            initializer_rng=initializer_rng,
            rng=rng,
            rng_updater=rng_updater)
        if explicit_build:
            keras_layer.build(input_shapes)

        if use_model:
            symbolic_in = tf.keras.Input(shape=input_shapes_no_batch,
                                         dtype=dtype)
            symbolic_out = keras_layer(symbolic_in)
            keras_model = tf.keras.Model(inputs=symbolic_in,
                                         outputs=symbolic_out)
            keras_forward = keras_model
        else:
            keras_forward = keras_layer

        step_size = 0.1  # learning rate

        # In both helpers the gradient may be a `tf.IndexedSlices`, so it is
        # converted to a tensor before being scaled.
        def shift_array(weight, grad):
            delta = step_size * tf.convert_to_tensor(grad)
            return weight + np.asarray(delta, weight.dtype)

        def shift_variable(variable, grad):
            delta = step_size * tf.convert_to_tensor(grad)
            return variable.assign_add(tf.cast(delta, variable.dtype))

        for _ in range(3):
            inputs = draw_inputs()

            # Reference step on the bare Trax layer.
            with tf.GradientTape() as trax_tape:
                trax_tape.watch(
                    [leaf.data for leaf in tf.nest.flatten(weights)])
                trax_outputs, state = trax_layer.pure_fn(
                    to_arrays(inputs), weights=weights, state=state, rng=rng)
            grad_target, grad_sources = to_tensors([trax_outputs, weights])
            trax_grads = trax_tape.gradient(grad_target, grad_sources)
            weights = tf.nest.map_structure(shift_array, weights, trax_grads)
            rng = rng_updater(rng)

            # Same step through the Keras wrapper (or the model around it).
            with tf.GradientTape() as keras_tape:
                keras_outputs = keras_forward(inputs)
            if isinstance(keras_outputs, tuple) and len(keras_outputs) == 1:
                keras_outputs = keras_outputs[0]
            self.assertAllClose(to_tensors(trax_outputs), keras_outputs)

            trainable = keras_layer.trainable_variables
            keras_grads = keras_tape.gradient(keras_outputs, trainable)
            tf.nest.map_structure(shift_variable,
                                  keras_layer.trainable_variables,
                                  keras_grads)

            # Both sides must agree after the update.
            self.assertAllClose(to_tensors(weights),
                                read_values(keras_layer._weights),
                                rtol=2e-6,
                                atol=5e-5)
            self.assertAllClose(to_tensors(state),
                                read_values(keras_layer._state))
            self.assertAllClose(to_tensors(rng),
                                read_values(keras_layer._rng))

        if use_model:
            # Save and re-load the model, then check that the restored model
            # reproduces the original model's outputs.
            checkpoint_path = os.path.join(self.get_temp_dir(), "checkpoint")
            keras_model.save(checkpoint_path)
            restored_model = tf.keras.models.load_model(checkpoint_path)
            for _ in range(2):
                inputs = draw_inputs()
                self.assertAllClose(keras_model(inputs),
                                    restored_model(inputs))