def test_forward_dtype(self, backend, dtype):
  """BatchNorm output must carry the same dtype as its input."""
  with math.use_backend(backend):
    inputs = np.ones((3, 2, 7)).astype(dtype)
    batch_norm = tl.BatchNorm()
    # `init` is expected to hand back a (weights, state) pair; both unused.
    _, _ = batch_norm.init(shapes.signature(inputs))
    outputs = batch_norm(inputs)
    self.assertEqual(outputs.dtype, dtype)
def test_no_int32_or_uint32_returned(self):
  """Checks that Trainer._jit_update_fn hands back no int32/uint32 arrays.

  TF pins int32/uint32 tensors to CPU, which makes an XLA-forced-compiled
  computation copy such outputs to CPU. This test guards against that.
  """
  if xla_bridge.device_count() > 1:
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with math.use_backend('tf'), self.tmp_dir() as output_dir:
    n_classes = 1001
    trainer = trainer_lib.Trainer(
        model=functools.partial(models.Resnet50, n_output_classes=n_classes),
        loss_fn=layers.CrossEntropyLoss,
        optimizer=trax_opt.SM3,
        lr_schedule=lr.MultifactorSchedule,
        inputs=test_inputs(n_classes, input_shape=(224, 224, 3)),
    )
    trainer.reset(output_dir)
    trainer.train_epoch(1, 0)
    # Everything below is what Trainer._jit_update_fn returns.
    returned = (
        trainer._opt_state.weights,
        trainer._opt_state.slots,
        trainer._model_state,
        trainer._rngs,
    )
    for leaf in tf.nest.flatten(returned):
      if not isinstance(leaf, np.ndarray):
        continue
      if leaf.dtype == np.int32 or leaf.dtype == np.uint32:
        raise ValueError('Found an array of int32 or uint32: %s' % leaf)
def test_batching_lsh_self_attention(self):
  """Runs the reference-equivalence check on LSHSelfAttention.

  Every combination of `n_parallel_heads` in (1, 3, 6, 12) and
  `use_python_loop` in (True, False) is passed to the shared helper.
  """
  with math.use_backend('jax'):
    shared_kwargs = {
        'n_heads': 6,
        'd_qk': 7,
        'd_v': 17,
        'causal': True,
        'chunk_len': 5,
        'n_chunks_before': 1,
        'n_chunks_after': 0,
        'n_hashes': 2,
        'n_buckets': 4,
        'attention_dropout': 0.2,
        'output_dropout': 0.1,
        'mode': 'train',
    }
    # Heads vary in the outer loop, the python-loop flag in the inner one.
    variants = [
        {'n_parallel_heads': heads, 'use_python_loop': python_loop}
        for heads in [1, 3, 6, 12]
        for python_loop in [True, False]
    ]
    x = jax.random.uniform(
        jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32)
    self._test_equivalence_to_reference_code(
        efficient_attention.LSHSelfAttention,
        x,
        shapes.signature(x),
        shared_kwargs,
        *variants)
def test_train_restart(self, backend_name):
  """Trains, then calls `train` again on the same dir with twice the steps.

  The state returned by the second call must report `2 * steps` as its step.
  """
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with math.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Model and inputs.
    n_classes = 4
    steps = 2
    eval_steps = 2
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)
    inputs = test_inputs(n_classes)
    shared_kwargs = dict(model=model_fn, inputs=inputs, eval_steps=eval_steps)

    # First run: train and evaluate.
    trainer_lib.train(output_dir, steps=steps, **shared_kwargs)

    # Second run against the same output directory.
    total_steps = 2 * steps
    state = trainer_lib.train(output_dir, steps=total_steps, **shared_kwargs)

    # The step counter must cover both runs.
    self.assertEqual(state.step, 2 * steps)
def test_batching_self_attention(self):
  """Runs the reference-equivalence check on efficient_attention_v2.SelfAttention.

  Every combination of `n_parallel_heads` in (1, 3, 6, 12) and
  `use_python_loop` in (True, False) is passed to the shared helper.
  """
  with math.use_backend('jax'):
    shared_kwargs = {
        'n_heads': 6,
        'd_qk': 7,
        'd_v': 17,
        'share_qk': False,
        'causal': True,
        'chunk_len': 8,
        'n_chunks_before': 1,
        'n_chunks_after': 0,
        'attention_dropout': 0.0,
        'mode': 'train',
    }
    # Heads vary in the outer loop, the python-loop flag in the inner one.
    variants = [
        {'n_parallel_heads': heads, 'use_python_loop': python_loop}
        for heads in [1, 3, 6, 12]
        for python_loop in [True, False]
    ]
    inp = jax.random.uniform(
        jax.random.PRNGKey(0), (2, 16, 64), dtype=np.float32)
    input_signature = ShapeDtype((2, 16, 64), dtype=np.float32)
    self._test_equivalence_to_reference_code(
        efficient_attention_v2.SelfAttention,
        inp,
        input_signature,
        shared_kwargs,
        *variants)
def test_fast_inference_self_attention(self):
  """Runs the fast-inference check on efficient_attention.SelfAttention.

  Every combination of `n_parallel_heads` in (1, 3, 6, 12) and
  `use_python_loop` in (True, False) is passed to the shared helper.
  """
  with math.use_backend('jax'):
    shared_kwargs = {
        'n_heads': 6,
        'd_qk': 7,
        'd_v': 17,
        'share_qk': False,
        'causal': True,
        'chunk_len': 5,
        'n_chunks_before': 1,
        'n_chunks_after': 0,
        'attention_dropout': 0.0,
        'output_dropout': 0.0,
    }
    # Heads vary in the outer loop, the python-loop flag in the inner one.
    variants = [
        {'n_parallel_heads': heads, 'use_python_loop': python_loop}
        for heads in [1, 3, 6, 12]
        for python_loop in [True, False]
    ]
    x = jax.random.uniform(
        jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32)
    self._test_fast_inference(
        efficient_attention.SelfAttention,
        x,
        shapes.signature(x),
        shared_kwargs,
        *variants)
def test_train_eval_predict_sm3(self, backend_name):
  """Trains an MLP with the SM3 optimizer, then predicts from the saved file."""
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with math.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Model and inputs.
    n_classes = 4
    steps = 2
    eval_steps = 2
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)
    inputs = test_inputs(n_classes)

    # Train and evaluate.
    state = trainer_lib.train(
        output_dir,
        model=model_fn,
        inputs=inputs,
        steps=steps,
        eval_steps=eval_steps,
        optimizer=trax_opt.SM3)

    # All requested train steps were taken.
    self.assertEqual(steps, state.step)

    # Two evaluations were recorded, for both the train and eval histories.
    history = state.history
    train_acc = history.get('train', 'metrics/accuracy')
    eval_acc = history.get('eval', 'metrics/accuracy')
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # Predict with weights loaded from file.
    train_stream = inputs.train_stream(1)
    model = model_fn()
    model.init_from_file(os.path.join(output_dir, 'model.pkl'))
    first_batch = next(train_stream)
    model(first_batch[0])
def _test_lsh_self_attention_deterministic_given_seed(self, causal=False):
  """Checks LSHSelfAttention gives the same output across repeated runs.

  Both the initialization key and the call key are fixed, so ten
  init-and-call rounds are expected to agree to 6 decimals.

  Args:
    causal: forwarded to the `causal` argument of `LSHSelfAttention`.
  """
  with math.use_backend('jax'):
    layer = efficient_attention.LSHSelfAttention(
        n_heads=5,
        d_qk=7,
        d_v=17,
        causal=causal,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        n_hashes=2,
        n_buckets=4,
        use_reference_code=True,
        attention_dropout=0.0,
        mode='train')
    x = np.ones((3, 32, 8)).astype(np.float32)
    signature = shapes.signature(x)

    def init_and_run():
      # Re-initialize with the same key, then call with the same key.
      _, _ = layer.init(signature, jax.random.PRNGKey(0))
      return layer(x, rng=jax.random.PRNGKey(1))

    outputs = []
    for _ in range(10):
      outputs.append(init_and_run())

    reference = outputs[0]
    self.assertEqual(reference.shape, x.shape)
    for other in outputs[1:]:
      np.testing.assert_array_almost_equal(reference, other, decimal=6)
def test_reformer_rng_consistency(self):
  """Gradients through a ReformerLM with poison-on-mismatch attention are finite.

  The model is built with `PoisonOnRNGMismatchAttention`, which (per the note
  below) signals an rng mismatch with NaNs. The test differentiates a dummy
  loss with respect to the weights and requires every gradient leaf to be
  finite.
  """
  with math.use_backend('jax'):
    vocab_size = 16
    batch_size = 1
    input_sd = ShapeDtype((batch_size, 8), np.int32)
    input_signature = (input_sd, input_sd)
    model = reformer.ReformerLM(
        vocab_size,
        d_model=32,
        d_ff=64,
        d_attention_key=16,
        d_attention_value=16,
        n_layers=1,
        n_heads=2,
        max_len=16,
        n_chunks=2,
        n_attention_chunks=1,
        mode='train',
        attention_type=PoisonOnRNGMismatchAttention)

    rng = math.random.get_prng(0)
    weights, state = model.init(input_signature)

    def dummy_loss_fn(weights):
      # Two identical all-zero inputs; only the first output feeds the loss.
      inputs = (np.zeros(input_sd.shape, dtype=np.int32),) * 2
      output = model(inputs, weights=weights, state=state, rng=rng)
      dummy_loss = math.numpy.sum(output[0])
      return dummy_loss

    grad_fn = math.grad(dummy_loss_fn)
    grads = grad_fn(weights)
    # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
    for grad in jax.tree_util.tree_leaves(grads):
      # Use the TestCase assertion rather than a bare `assert`, so the check
      # survives `python -O` and matches the rest of the test suite.
      self.assertTrue(onp.all(onp.isfinite(grad)))
def test_custom_zero_grad(self, backend_name):
  """A layer whose custom `backward` returns zeros must yield zero gradients."""

  # Identity in the forward pass; the custom backward pass returns zeros.
  class IdWithZeroGrad(base.Layer):

    def forward(self, x, weights):
      return x

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, grad, weights, state, new_state, rng):
      return (jnp.zeros_like(grad), ())

  with math.use_backend(backend_name):
    layer = IdWithZeroGrad()
    rng = math.random.get_prng(0)
    input_signature = shapes.ShapeDtype((9, 17))
    random_input = math.random.uniform(
        rng, input_signature.shape, minval=-1.0, maxval=1.0)
    layer.init(input_signature)

    def mean_output(x):
      return jnp.mean(layer(x))

    grad = math.grad(mean_output)(random_input)
    # One gradient entry per input element ...
    self.assertEqual(grad.shape, (9, 17))
    # ... and every one of them is zero (sum of squares is 0).
    squared = grad * grad
    self.assertEqual(sum(sum(squared)), 0.0)
def _test_fast_inference(self, attention_type, length):
  """Compares 'predict'-mode TransformerLM logits with 'eval'-mode logits.

  At each of `length` positions the 'eval' model sees the whole buffer while
  the 'predict' model sees only the newest symbol; their logits for that
  position must agree to 5 decimals.

  Args:
    attention_type: forwarded to `transformer.TransformerLM`.
    length: number of positions to decode and compare.
  """
  with math.use_backend('jax'):
    vocab_size = 16
    batch_size = 2
    make_model = functools.partial(
        transformer.TransformerLM,
        vocab_size=vocab_size,
        d_model=4,
        d_ff=8,
        n_layers=2,
        n_heads=2,
        attention_type=attention_type,
    )
    model_slow = make_model(mode='eval')
    model_fast = make_model(mode='predict')
    rng = math.random.get_prng(0)
    input_signature = ShapeDtype((batch_size, 1), np.int32)
    # Both models are initialized the same way; the comparison below relies
    # on them ending up with the same parameters.
    model_slow.init(input_signature)
    model_fast.init(input_signature)

    buf = onp.zeros((batch_size, length), dtype=np.int32)
    next_sym = onp.zeros((batch_size, 1), dtype=onp.int32)

    for position in range(length):
      logits_slow = model_slow(buf, rng=rng)
      logits_fast = model_fast(next_sym, rng=rng)
      onp.testing.assert_array_almost_equal(
          logits_slow[:, position, :],
          logits_fast[:, 0, :],
          decimal=5,
      )
      # Draw the next symbol at random and record it in the full buffer.
      next_sym = onp.random.randint(vocab_size, size=(batch_size, 1))
      buf[:, position] = next_sym[:, 0]
def nontrainable_params(self):
  """Returns `self._lr_fn(self._step)`, evaluated under the 'numpy' backend."""
  # TODO(afrozm): Give further thought to this name.
  # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU)
  # in op-by-op mode just to compute the learning rate. However, there
  # should be a cleaner approach than forcibly swapping out the backend.
  with math.use_backend('numpy'):
    params = self._lr_fn(self._step)
    return params
def test_lsh_self_attention_masked_non_causal(self):
  """Masked inputs must not influence the un-masked outputs.

  When the input changes only inside the masked area, the attention output
  for the un-masked positions must stay the same (to 6 decimals), while the
  output in the masked region must change.
  """
  with math.use_backend('jax'):
    layer = efficient_attention.LSHSelfAttention(
        n_heads=5,
        d_qk=7,
        d_v=17,
        causal=False,
        masked=True,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        n_hashes=2,
        n_buckets=4,
        use_reference_code=True,
        attention_dropout=0.0,
        mode='train')

    batch = 5
    max_len = 32
    hidden = 8

    x = np.random.uniform(size=(batch, max_len, hidden))
    # Builtin `bool` replaces the `np.bool` alias, which NumPy deprecated in
    # 1.20 and removed in 1.24.
    mask = np.ones((batch, max_len)).astype(bool)
    rngs = jax.random.randint(
        jax.random.PRNGKey(0), (batch,), minval=1, maxval=max_len - 1)

    # Set some suffix of each mask[b] to 0.
    for i in range(batch):
      mask[i, rngs[i]:] = 0

    # Fix rngs and get the output for the LSH layer.
    def get_output(x, mask):
      xs = [x, mask]
      _, _ = layer.init(shapes.signature(xs), jax.random.PRNGKey(0))
      return layer(xs, rng=jax.random.PRNGKey(1))

    # Get the attention output for masked x.
    y = get_output(x, mask)

    # Change x, but only in the masked regions.
    for i in range(batch):
      x[i, rngs[i]:] = np.random.uniform(size=(max_len - rngs[i], hidden))

    y2 = get_output(x, mask)

    for i in range(batch):
      # y and y2 should be identical in the non-masked part.
      np.testing.assert_array_almost_equal(
          y[i, :rngs[i]], y2[i, :rngs[i]], decimal=6)

      # In the masked out part, they should be different.
      self.assertGreater(
          np.mean(np.abs(y[i, rngs[i]:] - y2[i, rngs[i]:])), 1e-5)
def test_layer_norm_dtype(self, backend, dtype):
  """LayerNorm output must carry the same dtype as its input."""
  with use_backend(backend):
    input_shape = (2, 3, 4)
    layer_norm = normalization.LayerNorm()
    layer_norm.init(ShapeDtype(input_shape, dtype))
    # Contents are irrelevant here; only the dtype of the result is checked.
    uninitialized_input = onp.empty(input_shape, dtype=dtype)
    result = layer_norm(uninitialized_input)
    self.assertEqual(result.dtype, dtype)
def test_takes_new_history(self):
  """After a second `reset`, the observation must equal `[5]`.

  `[5]` is the last element of the second of the two histories handed to the
  environment.
  """
  two_histories = np.array([[[0, 1, 2]], [[3, 4, 5]]])
  with math.use_backend('numpy'):
    env = self._create_env(  # pylint: disable=no-value-for-parameter
        model=mock.MagicMock(),
        histories=two_histories,
        trajectory_length=2,
    )
    # The first reset's observation is discarded on purpose.
    env.reset()
    second_observation = env.reset()
    np.testing.assert_array_equal(second_observation, [5])
def test_fails_to_evaluate_model_with_matrix_observation_space(self):
  """`simple.evaluate_model` must return None for a (2, 2) observation space."""
  with math.use_backend('numpy'):
    matrix_space = gym.spaces.Box(shape=(2, 2), low=0, high=1)
    env = self._make_env(  # pylint: disable=no-value-for-parameter
        observation_space=matrix_space,
        action_space=gym.spaces.Discrete(n=1),
        max_trajectory_length=2,
        batch_size=1,
    )
    observations = np.array([[0, 1], [2, 3]])
    actions = np.array([0])
    trajectories = [self._make_trajectory(observations, actions)]
    metrics = simple.evaluate_model(env, trajectories, plt)
    self.assertIsNone(metrics)
def test_communicates_with_model(self):
  """Steps the env with a mock model and checks everything exchanged with it.

  Observations, rewards, dones, and the (history, action) pairs the mock
  model was called with are compared against hand-written expectations.
  """

  # Mock model increasing the observation by action, reward is the parity of
  # the new observation.
  def mock_transition(inputs, *args, **kwargs):
    del args
    del kwargs
    (observations, actions) = inputs
    new_observations = observations[:, -1] + actions
    rewards = np.array([[int(new_observations % 2 == 0)]])
    return (new_observations, rewards)

  mock_model_fn = mock.MagicMock()
  # The object the factory returns stands in for the model itself.
  mock_model = mock_model_fn.return_value
  mock_model.side_effect = mock_transition
  mock_model.init.return_value = (None, None)

  actions_to_take = np.array([[1], [3]])
  histories = np.array([[[0, 1, 2, 3]]])
  expected_observations = np.array([[3], [4], [7]])
  expected_rewards = np.array([[1], [0]])
  expected_dones = np.array([[False], [True]])
  expected_histories = np.array([[[0, 1, 2, 3]], [[1, 2, 3, 4]]])
  expected_actions = actions_to_take

  with math.use_backend('numpy'):
    env = self._create_env(  # pylint: disable=no-value-for-parameter
        model=mock_model_fn,
        histories=histories,
        trajectory_length=len(actions_to_take),
    )
    actual_observations = [env.reset()]
    actual_rewards = []
    actual_dones = []
    actual_histories = []
    actual_actions = []
    for requested_action in actions_to_take:
      (observation, reward, done, _) = env.step(requested_action)
      actual_observations.append(observation)
      actual_rewards.append(reward)
      actual_dones.append(done)
      # Mock call is a tuple (args, kwargs). There is one positional argument,
      # which is a tuple (history, action).
      (positional, _) = mock_model.call_args
      ((seen_history, seen_action),) = positional
      actual_actions.append(seen_action)
      actual_histories.append(seen_history)

    comparisons = (
        (actual_observations, expected_observations),
        (actual_rewards, expected_rewards),
        (actual_dones, expected_dones),
        (actual_histories, expected_histories),
        (actual_actions, expected_actions),
    )
    for (actual, expected) in comparisons:
      np.testing.assert_array_equal(actual, expected)
def __init__(self, trax_layer, batch_size=None, initializer_rng=None,
             rng=None, rng_updater=None, dtype=None):
  """Wraps a Trax layer so that it can be used as a Keras layer.

  Args:
    trax_layer: the `trax.layers.Layer` object to wrap.
    batch_size: (optional) an integer batch size for this Keras layer. Keras
      sometimes has to generate a TF graph for a layer (e.g. for acceleration
      or checkpointing), and the inputs used for tracing have `None` as their
      batch dimension so the graph works for any batch size. Some Trax layers
      can't handle shapes that contain `None`. When `batch_size` is an
      integer, tracing uses it in place of `None`; the graph (and this Keras
      layer) is then tied to that one batch size, and a different batch size
      requires a new `TraxKerasLayer` created with that `batch_size`.
    initializer_rng: (optional) RNG key used to create the weights and state
      when `trax_layer` doesn't have them. Defaults to
      `trax.math.random.get_prng(0)`.
    rng: (optional) RNG key for the forward function (the "forward key").
      Defaults to `trax.math.random.get_prng(0)`.
    rng_updater: (optional) a function rng_key -> rng_key applied to the
      forward key after each forward pass. Defaults to
      `lambda x: trax.math.random.split(x, 1)[0]`, which advances the key.
    dtype: (optional) dtype of the inputs; see the `dtype` argument of
      `tf.keras.layers.Layer.__init__`.
  """
  super(TraxKerasLayer, self).__init__(dtype=dtype)
  # Defaults are resolved under the "tf" backend.
  with math_lib.use_backend("tf"):
    if initializer_rng is None:
      initializer_rng = math_lib.random.get_prng(0)
    if rng is None:
      rng = math_lib.random.get_prng(0)
    if rng_updater is None:
      rng_updater = lambda x: math_lib.random.split(x, 1)[0]
    self._trax_layer = trax_layer
    self._batch_size = batch_size
    self._initializer_rng = initializer_rng
    self._forward_rng_init = rng
    self._rng_updater = rng_updater
def test_self_attention(self):
  """SelfAttention (reference code) must map a (3, 32, 8) input to (3, 32, 8)."""
  with math.use_backend('jax'):
    attention_kwargs = dict(
        n_heads=5,
        d_qk=7,
        d_v=17,
        share_qk=False,
        causal=True,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        use_reference_code=True,
        attention_dropout=0.0,
        mode='train',
    )
    input_signature = ShapeDtype((3, 32, 8))
    layer = efficient_attention.SelfAttention(**attention_kwargs)
    final_shape = base.check_shape_agreement(layer, input_signature)
    self.assertEqual((3, 32, 8), final_shape)
def _test_train_eval_predict(self, backend_name):
  """Trains a Dropout + BatchNorm + MLP model, then predicts with its weights.

  Checks the number of train steps taken and the number of recorded
  evaluations, then runs the model once on a training batch using the final
  weights and model state. When more than one device is present, element 0
  of every weight/state leaf is used.

  Args:
    backend_name: name of the backend to run under.
  """
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    # Message spelled the same way as in the sibling tests.
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with math.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs
    n_classes = 4
    steps = 2
    eval_steps = 2

    # Adds Dropout and BatchNorm to test state handling.
    def model_fn(mode='train'):
      return layers.Serial(
          layers.Dropout(mode=mode, rate=0.1),
          layers.BatchNorm(mode=mode),
          models.MLP(d_hidden=16, n_output_classes=n_classes, mode=mode))

    inputs = test_inputs(n_classes)

    # Train and evaluate
    state = trainer_lib.train(
        output_dir,
        model=model_fn,
        inputs=inputs,
        steps=steps,
        eval_steps=eval_steps)

    # Assert total train steps
    self.assertEqual(steps, state.step)

    # Assert 2 evaluations ran
    train_acc = state.history.get('train', 'metrics/accuracy')
    eval_acc = state.history.get('eval', 'metrics/accuracy')
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # Predict with final weights
    inputs = inputs.train_stream(1)
    model = model_fn()
    weights = state.opt_state.weights[0]
    # Kept under its own name so the trainer state is not shadowed.
    model_state = state.model_state[0]
    if xla_bridge.device_count() > 1:
      unreplicate = lambda x: x[0]
      weights = math.nested_map(unreplicate, weights)
      model_state = math.nested_map(unreplicate, model_state)
    model(next(inputs)[0], weights=weights, state=model_state)
def test_self_attention_tf(self):
  """Under the 'tf' backend, SelfAttention output shape equals input shape."""
  with math.use_backend('tf'):
    attention_kwargs = dict(
        n_heads=5,
        d_qk=7,
        d_v=17,
        share_qk=False,
        causal=True,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        use_reference_code=True,
        attention_dropout=0.0,
        mode='train',
    )
    layer = efficient_attention.SelfAttention(**attention_kwargs)
    inputs = np.ones((3, 32, 8)).astype(np.float32)
    _, _ = layer.init(shapes.signature(inputs))
    outputs = layer(inputs)
    self.assertEqual(outputs.shape, inputs.shape)
def call(self, inputs):
  """Runs the wrapped Trax layer's `pure_fn` on `inputs` under the "tf" backend.

  Besides computing the outputs, this assigns the new layer state to
  `self._state` and stores the updated forward RNG key in `self._rng`.

  Args:
    inputs: the inputs to the layer.

  Returns:
    The layer outputs, passed through `to_tensors`.
  """
  with math_lib.use_backend("tf"):
    fill_in_batch = functools.partial(
        _replace_none_batch, batch_size=self._batch_size)
    inputs = math_lib.nested_map(fill_in_batch, inputs)
    variables = [self._weights, self._state, self._rng]
    weights, state, rng = read_values(variables)
    inputs, weights, state, rng = to_arrays([inputs, weights, state, rng])
    outputs, new_state = self._trax_layer.pure_fn(
        inputs, weights=weights, state=state, rng=rng)
    # Write the new state back into the state variables, leaf by leaf.
    tf.nest.map_structure(lambda v, t: v.assign(t), self._state, new_state)
    next_rng = self._rng_updater(rng)
    self._rng.assign(next_rng)
    return to_tensors(outputs)
def test_evaluates_model_with_vector_observation_space(self):
  """`simple.evaluate_model` must return two metrics for a vector space.

  Three trajectories of lengths 1, 2 and 3 are evaluated against an
  environment whose observation space has shape (2,).
  """
  with math.use_backend('numpy'):
    env = self._make_env(  # pylint: disable=no-value-for-parameter
        observation_space=gym.spaces.Box(shape=(2,), low=0, high=1),
        action_space=gym.spaces.Discrete(n=1),
        max_trajectory_length=2,
        batch_size=3,
    )
    observation_action_pairs = [
        (np.array([[0, 1]]), np.array([0])),
        (np.array([[1, 2], [3, 4]]), np.array([0, 0])),
        (np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 0, 0])),
    ]
    trajectories = []
    for (observations, actions) in observation_action_pairs:
      trajectories.append(self._make_trajectory(observations, actions))
    metrics = simple.evaluate_model(env, trajectories, plt)
    self.assertIsNotNone(metrics)
    self.assertEqual(len(metrics), 2)
def test_reformer_lm_forward_shape_tf(self):
  """Under 'tf', ReformerLM maps two (1, 64) inputs to (1, 64, 16) and (1, 64)."""
  with math.use_backend('tf'):
    vocab_size = 16
    timebin_attn = self._timebin_self_attention_fn(use_reference_code=True)
    model = reformer.ReformerLM(
        vocab_size,
        d_model=32,
        d_ff=64,
        d_attention_key=16,
        d_attention_value=16,
        n_layers=1,
        n_heads=2,
        max_len=64,
        attention_type=timebin_attn)
    # Two separately created, identical int32 inputs.
    xs = [np.ones((1, 64)).astype(np.int32) for _ in range(2)]
    _, _ = model.init(shapes.signature(xs))
    ys = model(xs)
    output_shapes = [y.shape for y in ys]
    self.assertEqual(output_shapes, [(1, 64, 16), (1, 64)])
def test_reset_twice(self, backend_name):
  """A Trainer can be reset to a second output dir and still evaluate."""
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with math.use_backend(backend_name):
    with self.tmp_dir() as output_dir1, self.tmp_dir() as output_dir2:
      n_classes = 4
      model_fn = functools.partial(
          models.MLP, d_hidden=16, n_output_classes=n_classes)
      inputs = test_inputs(n_classes)

      trainer = trainer_lib.Trainer(
          model=model_fn,
          loss_fn=layers.CrossEntropyLoss(),
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          inputs=inputs,
      )

      # Reset + evaluate once per directory, in order.
      for output_dir in (output_dir1, output_dir2):
        trainer.reset(output_dir)
        trainer.evaluate(1)
def build(self, input_shape):
  """Creates the TF variables that hold the weights, state and forward RNG.

  If the wrapped Trax layer has no weights yet, it is initialized from
  `input_shape` (with `None` batch dimensions replaced via
  `_replace_none_batch`) using the initializer RNG; otherwise its existing
  weights and state are used.

  Args:
    input_shape: the input shape(s) passed in by Keras.
  """
  with math_lib.use_backend("tf"):
    trax_layer = self._trax_layer
    # Identity check (`is`, not `==`), following Trax's practice.
    if trax_layer.weights is not base.EMPTY_WEIGHTS:
      weights = trax_layer.weights
      state = trax_layer.state
    else:
      fill_in_batch = functools.partial(
          _replace_none_batch, batch_size=self._batch_size)
      sanitized_input_shape = math_lib.nested_map(fill_in_batch, input_shape)
      input_signature = tensor_shapes_to_shape_dtypes(
          sanitized_input_shape, self.dtype)
      weights, state = trax_layer.init(
          input_signature, rng=self._initializer_rng)
    # Note: `weights` may contain `EMPTY_WEIGHTS`
    make_trainable = functools.partial(tf.Variable, trainable=True)
    make_frozen = functools.partial(tf.Variable, trainable=False)
    self._weights = math_lib.nested_map(make_trainable, weights)
    self._state = math_lib.nested_map(make_frozen, state)
    self._rng = tf.Variable(self._forward_rng_init, trainable=False)
    super(TraxKerasLayer, self).build(input_shape)
def testTrain(self, layer_id, rng_updater_id, batch_size, trax_has_weights,
              explicit_build, use_model):
  """Checks that TraxKerasLayer trains in lockstep with the bare Trax layer.

  Three SGD steps are run twice -- once directly on the Trax layer and once
  through the Keras wrapper -- and outputs, weights, state and RNG are
  compared after every step.

  Args:
    layer_id: an integer, the index into `_LAYERS`.
    rng_updater_id: an integer, the index into `_RNG_UPDATERS`.
    batch_size: an integer or `None`, the value for the `batch_size` argument
      in `TraxKerasLayer.__init__`.
    trax_has_weights: bool, whether to make the trax layer contain weights at
      the time when `TraxKerasLayer.build` is called.
    explicit_build: bool, whether to explicitly call `TraxKerasLayer.build`.
    use_model: bool, whether to build a `tf.keras.Model` out of the
      `TraxKerasLayer` layer and use the model to do the training instead of
      the bare layer. If `True`, checkpointing and restoring through the
      model is tested as well.
  """
  with math_lib.use_backend("tf"):
    layer_spec = _LAYERS[layer_id]
    make_trax_layer, input_shapes_no_batch, dtype, allow_none_batch = (
        layer_spec)
    # A fresh trax layer per test case keeps the cases from interfering with
    # each other.
    trax_layer = make_trax_layer()
    if not allow_none_batch and batch_size is None:
      self.skipTest("This Trax layer can't handle None batch size.")
    rng_updater = _RNG_UPDATERS[rng_updater_id]
    input_shapes = math_lib.nested_map(
        lambda s: [batch_size] + s, input_shapes_no_batch)
    input_sig = trax2keras.tensor_shapes_to_shape_dtypes(input_shapes, dtype)
    initializer_rng = math_lib.random.get_prng(765)
    weights, state = trax_layer.init(input_sig, rng=initializer_rng)
    input_generator = tf.random.Generator.from_seed(567)

    def sample_inputs():
      return dummy_inputs(input_generator, input_sig)

    if trax_has_weights:
      trax_layer(to_arrays(sample_inputs()), weights=weights, state=state)
    rng = math_lib.random.get_prng(1234)
    keras_layer = trax2keras.TraxKerasLayer(
        trax_layer,
        batch_size=batch_size,
        initializer_rng=initializer_rng,
        rng=rng,
        rng_updater=rng_updater)
    if explicit_build:
      keras_layer.build(input_shapes)
    if use_model:
      keras_input = tf.keras.Input(shape=input_shapes_no_batch, dtype=dtype)
      keras_output = keras_layer(keras_input)
      keras_model = tf.keras.Model(inputs=keras_input, outputs=keras_output)
    # The Keras-side callable is either the model or the bare layer.
    keras_forward = keras_model if use_model else keras_layer
    learning_rate = 0.1

    # `g` may be `tf.IndexedSlices`, so it is passed through
    # `convert_to_tensor` before the multiplication.
    def stepped_trax_weight(w, g):
      return w + np.asarray(
          learning_rate * tf.convert_to_tensor(g), w.dtype)

    def step_keras_variable(v, g):
      return v.assign_add(
          tf.cast(learning_rate * tf.convert_to_tensor(g), v.dtype))

    for _ in range(3):
      inputs = sample_inputs()

      # Trax side: forward pass, gradients, manual SGD step, RNG update.
      with tf.GradientTape() as trax_tape:
        trax_tape.watch([w.data for w in tf.nest.flatten(weights)])
        trax_outputs, state = trax_layer.pure_fn(
            to_arrays(inputs), weights=weights, state=state, rng=rng)
      trax_grads = trax_tape.gradient(*to_tensors([trax_outputs, weights]))
      weights = tf.nest.map_structure(
          stepped_trax_weight, weights, trax_grads)
      rng = rng_updater(rng)

      # Keras side: forward pass, then the same SGD step on its variables.
      with tf.GradientTape() as keras_tape:
        keras_outputs = keras_forward(inputs)
      if isinstance(keras_outputs, tuple) and len(keras_outputs) == 1:
        keras_outputs = keras_outputs[0]
      self.assertAllClose(to_tensors(trax_outputs), keras_outputs)
      keras_grads = keras_tape.gradient(
          keras_outputs, keras_layer.trainable_variables)
      tf.nest.map_structure(
          step_keras_variable, keras_layer.trainable_variables, keras_grads)

      # Both sides must agree after the step.
      self.assertAllClose(
          to_tensors(weights),
          read_values(keras_layer._weights),
          rtol=2e-6,
          atol=5e-5)
      self.assertAllClose(to_tensors(state), read_values(keras_layer._state))
      self.assertAllClose(to_tensors(rng), read_values(keras_layer._rng))

    if use_model:
      # Save, reload, and compare the two models on fresh inputs.
      fname = os.path.join(self.get_temp_dir(), "checkpoint")
      keras_model.save(fname)
      loaded_model = tf.keras.models.load_model(fname)
      for _ in range(2):
        inputs = sample_inputs()
        self.assertAllClose(keras_model(inputs), loaded_model(inputs))
def test_communicates_with_model(self, mock_restore_state):
  """Steps the env with a scripted predict fn and checks what it was fed.

  The mock predict function emits a fixed symbol sequence. After `reset` and
  after each `step`, the most recent symbols passed to the predict function
  are compared against the expected ones, together with the returned
  rewards, dones and observations.

  Args:
    mock_restore_state: mock supplied from outside this function; its
      `return_value.params` is set to `(None, None)` here.
  """
  gin.bind_parameter('BoxSpaceSerializer.precision', 1)
  vocab_size = 16
  # Mock model predicting a fixed sequence of symbols. It is made such that
  # the first two observations are different and the last one is equal to the
  # first.
  symbols = [
      1, 1, 2, 2, 0, 0,  # obs1 act1
      1, 2, 2, 1, 0, 0,  # obs2 act2
      1, 1, 2, 2,  # obs3
  ]

  def make_prediction(symbol):
    one_hot = np.eye(vocab_size)[symbol]
    # -100 everywhere except the chosen symbol: virtually deterministic.
    log_probs = (1 - one_hot) * -100.0
    return np.array([[log_probs]])

  mock_predict_fn = mock.MagicMock()
  # Predictions are produced lazily, one per call to the mock.
  mock_predict_fn.side_effect = (make_prediction(s) for s in symbols)

  with math.use_backend('numpy'):
    # (model_params, opt_state)
    mock_restore_state.return_value.params = (None, None)
    env = self._make_env(
        predict_fn=mock_predict_fn,
        reward_fn=(lambda _1, _2: np.array([0.5])),
        done_fn=(lambda _1, _2: np.array([False])),
        vocab_size=vocab_size,
        batch_size=1,
        max_trajectory_length=3,
        observation_space=gym.spaces.Box(low=0, high=5, shape=(4,)),
        action_space=gym.spaces.MultiDiscrete(nvec=[2, 2]),
    )

    def assert_input_suffix(expected_symbols):
      # Look only at the last len(expected_symbols) calls to the mock.
      recent_calls = mock_predict_fn.call_args_list[-len(expected_symbols):]
      actual_symbols = np.array(
          [symbol.item() for ((symbol,), _) in recent_calls])
      np.testing.assert_array_equal(actual_symbols, expected_symbols)

    actions = [[0, 1], [1, 0]]
    first_action, second_action = actions

    obs1 = env.reset()
    assert_input_suffix(symbols[:3])

    (obs2, reward, done, _) = env.step(np.array([first_action]))
    # Symbols going into the decoder when predicting the next observation are:
    # the last symbol of the previous observation, all action symbols, all
    # symbols but the last one of the next observation.
    assert_input_suffix([symbols[3]] + first_action + symbols[6:9])
    self.assertFalse(np.array_equal(obs1, obs2))
    np.testing.assert_array_equal(reward, [0.5])
    np.testing.assert_array_equal(done, [False])

    (obs3, reward, done, _) = env.step(np.array([second_action]))
    assert_input_suffix([symbols[9]] + second_action + symbols[12:15])
    np.testing.assert_array_equal(obs1, obs3)
    np.testing.assert_array_equal(reward, [0.5])
    np.testing.assert_array_equal(done, [True])
def test_reversible_swap(self, backend_name):
  """ReversibleSwap must return its two inputs in swapped order."""
  with math.use_backend(backend_name):
    first = np.array([1, 2])
    second = np.array([10, 20])
    swap = tl.ReversibleSwap()
    swapped = swap([first, second])
    self.assertEqual(tl.to_list(swapped), [[10, 20], [1, 2]])