Example #1
 def test_forward_dtype(self, backend, dtype):
     with math.use_backend(backend):
         layer = tl.BatchNorm()
         x = np.ones((3, 2, 7)).astype(dtype)
         _, _ = layer.init(shapes.signature(x))
         y = layer(x)
         self.assertEqual(y.dtype, dtype)
Example #2
    def test_no_int32_or_uint32_returned(self):
        """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

    TF pins int32/uint32 tensors to CPU, which will cause XLA-forced-compiled
    computation to copy int32/uint32 outputs to CPU. This test makes sure that
    won't happen.
    """
        if xla_bridge.device_count() > 1:
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with math.use_backend('tf'), self.tmp_dir() as output_dir:
            n_classes = 1001
            model_fn = functools.partial(models.Resnet50,
                                         n_output_classes=n_classes)
            inputs = test_inputs(n_classes, input_shape=(224, 224, 3))
            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=layers.CrossEntropyLoss,
                optimizer=trax_opt.SM3,
                lr_schedule=lr.MultifactorSchedule,
                inputs=inputs,
            )
            trainer.reset(output_dir)
            trainer.train_epoch(1, 0)
            # These are the values returned by Trainer._jit_update_fn.
            arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
                      trainer._model_state, trainer._rngs)
            arrays = tf.nest.flatten(arrays)
            for x in arrays:
                if isinstance(x, np.ndarray) and (x.dtype == np.int32
                                                  or x.dtype == np.uint32):
                    raise ValueError('Found an array of int32 or uint32: %s' %
                                     x)
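The docstring above explains why int32/uint32 outputs are undesirable under the 'tf' backend. The check this test performs can be distilled into a small standalone helper; a minimal sketch (the helper name is invented here for illustration and is not part of Trax):

import numpy as np
import tensorflow as tf

def assert_no_int32_or_uint32(nested):
    """Raises if any array leaf in `nested` has dtype int32 or uint32."""
    for leaf in tf.nest.flatten(nested):
        if isinstance(leaf, np.ndarray) and leaf.dtype in (np.int32, np.uint32):
            raise ValueError('Found an array of int32 or uint32: %s' % leaf)

# Usage mirroring the test: scan everything Trainer._jit_update_fn returns, e.g.
# assert_no_int32_or_uint32((opt_state.weights, opt_state.slots, model_state, rngs))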
Example #3
    def test_batching_lsh_self_attention(self):
        with math.use_backend('jax'):
            common_kwargs = dict(
                n_heads=6,
                d_qk=7,
                d_v=17,
                causal=True,
                chunk_len=5,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=2,
                n_buckets=4,
                attention_dropout=0.2,
                output_dropout=0.1,
                mode='train',
            )
            test_kwargs = []
            for n_parallel_heads in [1, 3, 6, 12]:
                for use_python_loop in [True, False]:
                    test_kwargs.append(
                        dict(n_parallel_heads=n_parallel_heads,
                             use_python_loop=use_python_loop))

            x = jax.random.uniform(jax.random.PRNGKey(0), (2, 10, 13),
                                   dtype=jnp.float32)
            input_signature = shapes.signature(x)
            self._test_equivalence_to_reference_code(
                efficient_attention.LSHSelfAttention, x, input_signature,
                common_kwargs, *test_kwargs)
Example #4
    def test_train_restart(self, backend_name):
        if xla_bridge.device_count() > 1 and backend_name == 'tf':
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with math.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Prepare model and inputs
            n_classes = 4
            steps = 2
            eval_steps = 2
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)
            inputs = test_inputs(n_classes)

            # Train and evaluate
            trainer_lib.train(output_dir,
                              model=model_fn,
                              inputs=inputs,
                              steps=steps,
                              eval_steps=eval_steps)

            # Restart training
            state = trainer_lib.train(output_dir,
                                      model=model_fn,
                                      inputs=inputs,
                                      steps=(2 * steps),
                                      eval_steps=eval_steps)

            # Assert total train steps
            self.assertEqual(state.step, 2 * steps)
Example #5
    def test_batching_self_attention(self):
        with math.use_backend('jax'):
            common_kwargs = dict(
                n_heads=6,
                d_qk=7,
                d_v=17,
                share_qk=False,
                causal=True,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                attention_dropout=0.0,
                mode='train',
            )
            test_kwargs = []
            for n_parallel_heads in [1, 3, 6, 12]:
                for use_python_loop in [True, False]:
                    test_kwargs.append(
                        dict(n_parallel_heads=n_parallel_heads,
                             use_python_loop=use_python_loop))

            inp = jax.random.uniform(jax.random.PRNGKey(0), (2, 16, 64),
                                     dtype=np.float32)
            input_signature = ShapeDtype((2, 16, 64), dtype=np.float32)
            self._test_equivalence_to_reference_code(
                efficient_attention_v2.SelfAttention, inp, input_signature,
                common_kwargs, *test_kwargs)
Example #6
    def test_fast_inference_self_attention(self):
        with math.use_backend('jax'):
            common_kwargs = dict(
                n_heads=6,
                d_qk=7,
                d_v=17,
                share_qk=False,
                causal=True,
                chunk_len=5,
                n_chunks_before=1,
                n_chunks_after=0,
                attention_dropout=0.0,
                output_dropout=0.0,
            )
            test_kwargs = []
            for n_parallel_heads in [1, 3, 6, 12]:
                for use_python_loop in [True, False]:
                    test_kwargs.append(
                        dict(n_parallel_heads=n_parallel_heads,
                             use_python_loop=use_python_loop))

            x = jax.random.uniform(jax.random.PRNGKey(0), (2, 10, 13),
                                   dtype=jnp.float32)
            input_signature = shapes.signature(x)
            self._test_fast_inference(efficient_attention.SelfAttention, x,
                                      input_signature, common_kwargs,
                                      *test_kwargs)
Example #7
  def test_train_eval_predict_sm3(self, backend_name):
    if xla_bridge.device_count() > 1 and backend_name == 'tf':
      self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
    with math.use_backend(backend_name), self.tmp_dir() as output_dir:
      # Prepare model and inputs
      n_classes = 4
      steps = 2
      eval_steps = 2
      model_fn = functools.partial(
          models.MLP, d_hidden=16, n_output_classes=n_classes)
      inputs = test_inputs(n_classes)

      # Train and evaluate
      state = trainer_lib.train(
          output_dir,
          model=model_fn,
          inputs=inputs,
          steps=steps,
          eval_steps=eval_steps,
          optimizer=trax_opt.SM3)

      # Assert total train steps
      self.assertEqual(steps, state.step)

      # Assert 2 evaluations ran
      train_acc = state.history.get('train', 'metrics/accuracy')
      eval_acc = state.history.get('eval', 'metrics/accuracy')
      self.assertEqual(len(train_acc), len(eval_acc))
      self.assertLen(eval_acc, 2)

      # Predict with weights loaded from file.
      inputs = inputs.train_stream(1)
      model = model_fn()
      model.init_from_file(os.path.join(output_dir, 'model.pkl'))
      model(next(inputs)[0])
Example #8
    def _test_lsh_self_attention_deterministic_given_seed(self, causal=False):
        # Once the initialization and call seeds are pinned down, the output is
        # deterministic.
        with math.use_backend('jax'):
            layer = efficient_attention.LSHSelfAttention(
                n_heads=5,
                d_qk=7,
                d_v=17,
                causal=causal,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=2,
                n_buckets=4,
                use_reference_code=True,
                attention_dropout=0.0,
                mode='train')
            x = np.ones((3, 32, 8)).astype(np.float32)

            def get_output():
                _, _ = layer.init(shapes.signature(x), jax.random.PRNGKey(0))
                return layer(x, rng=jax.random.PRNGKey(1))

            ys = [get_output() for _ in range(10)]

            self.assertEqual(ys[0].shape, x.shape)

            for y in ys[1:]:
                np.testing.assert_array_almost_equal(ys[0], y, decimal=6)
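The comment in this test relies on JAX's functional PRNG: a given key always produces the same random stream, so pinning the init key and the call key pins the layer output. A minimal generic sketch of that property (plain JAX, not Trax-specific):

import jax
import jax.numpy as jnp

def noisy_fn(x, rng):
    # All randomness flows through the explicit `rng` key.
    return x + jax.random.normal(rng, x.shape)

x = jnp.ones((3,))
y1 = noisy_fn(x, jax.random.PRNGKey(1))
y2 = noisy_fn(x, jax.random.PRNGKey(1))  # same key -> identical samples
assert bool(jnp.all(y1 == y2))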
Example #9
    def test_reformer_rng_consistency(self):
        with math.use_backend('jax'):
            vocab_size = 16
            batch_size = 1
            input_sd = ShapeDtype((batch_size, 8), np.int32)
            input_signature = (input_sd, input_sd)
            model = reformer.ReformerLM(
                vocab_size,
                d_model=32,
                d_ff=64,
                d_attention_key=16,
                d_attention_value=16,
                n_layers=1,
                n_heads=2,
                max_len=16,
                n_chunks=2,
                n_attention_chunks=1,
                mode='train',
                attention_type=PoisonOnRNGMismatchAttention)

            rng = math.random.get_prng(0)
            weights, state = model.init(input_signature)

            def dummy_loss_fn(weights):
                inputs = (np.zeros(input_sd.shape, dtype=np.int32), ) * 2
                output = model(inputs, weights=weights, state=state, rng=rng)
                dummy_loss = math.numpy.sum(output[0])
                return dummy_loss

            grad_fn = math.grad(dummy_loss_fn)
            grads = grad_fn(weights)
            # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
            for grad in jax.tree_util.tree_leaves(grads):
                assert onp.all(onp.isfinite(grad))
Example #10
    def test_custom_zero_grad(self, backend_name):
        class IdWithZeroGrad(base.Layer):
            def forward(self, x, weights):
                return x

            @property
            def has_backward(self):
                return True

            def backward(self, inputs, output, grad, weights, state, new_state,
                         rng):
                return (jnp.zeros_like(grad), ())

        with math.use_backend(backend_name):
            layer = IdWithZeroGrad()
            rng = math.random.get_prng(0)
            input_signature = shapes.ShapeDtype((9, 17))
            random_input = math.random.uniform(rng,
                                               input_signature.shape,
                                               minval=-1.0,
                                               maxval=1.0)
            layer.init(input_signature)
            f = lambda x: jnp.mean(layer(x))
            grad = math.grad(f)(random_input)
            self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
            self.assertEqual(sum(sum(grad * grad)), 0.0)  # Each one is 0.
Example #11
    def _test_fast_inference(self, attention_type, length):
        with math.use_backend('jax'):
            vocab_size = 16
            model_fn = functools.partial(
                transformer.TransformerLM,
                vocab_size=vocab_size,
                d_model=4,
                d_ff=8,
                n_layers=2,
                n_heads=2,
                attention_type=attention_type,
            )
            model_slow = model_fn(mode='eval')
            model_fast = model_fn(mode='predict')
            rng = math.random.get_prng(0)
            batch_size = 2
            input_signature = ShapeDtype((batch_size, 1), np.int32)
            # Given the same rng, both models initialize with the same parameters.
            model_slow.init(input_signature)
            model_fast.init(input_signature)

            buf = onp.zeros((batch_size, length), dtype=np.int32)
            next_sym = onp.zeros((batch_size, 1), dtype=onp.int32)

            for index in range(length):
                logits_slow = model_slow(buf, rng=rng)
                logits_fast = model_fast(next_sym, rng=rng)
                onp.testing.assert_array_almost_equal(
                    logits_slow[:, index, :],
                    logits_fast[:, 0, :],
                    decimal=5,
                )
                next_sym = onp.random.randint(vocab_size, size=(batch_size, 1))
                buf[:, index] = next_sym[:, 0]
Example #12
 def nontrainable_params(self):
     # TODO(afrozm): Give further thought to this name.
     # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU)
     # in op-by-op mode just to compute the learning rate. However, there
     # should be a cleaner approach than forcibly swapping out the backend.
     with math.use_backend('numpy'):
         return self._lr_fn(self._step)
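The TODO above is about computing the learning-rate scalar on the host instead of the accelerator. A minimal sketch of the backend switch it relies on (assuming the `trax.math` API used throughout these examples; the schedule arithmetic is purely illustrative):

from trax import math

step = 7
with math.use_backend('numpy'):
    # Inside the context, math.numpy dispatches to plain NumPy, so this tiny
    # op-by-op computation runs on the host rather than on a TPU/GPU.
    lr = 0.1 / math.numpy.sqrt(1.0 + step)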
Example #13
    def test_lsh_self_attention_masked_non_causal(self):
        # Test that changing the input inside the masked area does not change the
        # attention output at the un-masked positions, while the output in the
        # masked region does change.
        with math.use_backend('jax'):
            layer = efficient_attention.LSHSelfAttention(
                n_heads=5,
                d_qk=7,
                d_v=17,
                causal=False,
                masked=True,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=2,
                n_buckets=4,
                use_reference_code=True,
                attention_dropout=0.0,
                mode='train')

            batch = 5
            max_len = 32
            hidden = 8

            x = np.random.uniform(size=(batch, max_len, hidden))
            mask = np.ones((batch, max_len)).astype(bool)
            rngs = jax.random.randint(jax.random.PRNGKey(0), (batch, ),
                                      minval=1,
                                      maxval=max_len - 1)

            # Set some suffix of each mask[b] to 0.
            for i in range(batch):
                mask[i, rngs[i]:] = 0

            # Fix rngs and get the output for the LSH layer.
            def get_output(x, mask):
                xs = [x, mask]
                _, _ = layer.init(shapes.signature(xs), jax.random.PRNGKey(0))
                return layer(xs, rng=jax.random.PRNGKey(1))

            # Get the attention output for masked x.
            y = get_output(x, mask)

            # Change x, but only in the masked regions.
            for i in range(batch):
                x[i, rngs[i]:] = np.random.uniform(size=(max_len - rngs[i],
                                                         hidden))

            y2 = get_output(x, mask)

            for i in range(batch):
                # y and y2 should be identical in the non-masked part.
                np.testing.assert_array_almost_equal(y[i, :rngs[i]],
                                                     y2[i, :rngs[i]],
                                                     decimal=6)

                # In the masked out part, they should be different.
                self.assertGreater(
                    np.mean(np.abs(y[i, rngs[i]:] - y2[i, rngs[i]:])), 1e-5)
Example #14
 def test_layer_norm_dtype(self, backend, dtype):
     with use_backend(backend):
         input_shape = (2, 3, 4)
         input_signature = ShapeDtype(input_shape, dtype)
         layer = normalization.LayerNorm()
         layer.init(input_signature)
         out = layer(onp.empty(input_shape, dtype=dtype))
         self.assertEqual(out.dtype, dtype)
Example #15
    def test_takes_new_history(self):
        histories = np.array([[[0, 1, 2]], [[3, 4, 5]]])

        with math.use_backend('numpy'):
            env = self._create_env(  # pylint: disable=no-value-for-parameter
                model=mock.MagicMock(),
                histories=histories,
                trajectory_length=2,
            )
            env.reset()
            observation = env.reset()
            np.testing.assert_array_equal(observation, [5])
Example #16
 def test_fails_to_evaluate_model_with_matrix_observation_space(self):
   with math.use_backend('numpy'):
     env = self._make_env(  # pylint: disable=no-value-for-parameter
         observation_space=gym.spaces.Box(shape=(2, 2), low=0, high=1),
         action_space=gym.spaces.Discrete(n=1),
         max_trajectory_length=2,
         batch_size=1,
     )
     trajectories = [
         self._make_trajectory(np.array([[0, 1], [2, 3]]), np.array([0]))]
     metrics = simple.evaluate_model(env, trajectories, plt)
     self.assertIsNone(metrics)
Example #17
    def test_communicates_with_model(self):
        # Mock model that increases the observation by the action; the reward is
        # the parity of the new observation.
        def mock_transition(inputs, *args, **kwargs):
            del args
            del kwargs
            (observations, actions) = inputs
            new_observations = observations[:, -1] + actions
            rewards = np.array([[int(new_observations % 2 == 0)]])
            return (new_observations, rewards)

        mock_model_fn = mock.MagicMock()
        mock_model_fn.return_value.side_effect = mock_transition
        mock_model_fn.return_value.init.return_value = (None, None)
        mock_model = mock_model_fn.return_value

        actions_to_take = np.array([[1], [3]])
        histories = np.array([[[0, 1, 2, 3]]])
        expected_observations = np.array([[3], [4], [7]])
        expected_rewards = np.array([[1], [0]])
        expected_dones = np.array([[False], [True]])
        expected_histories = np.array([[[0, 1, 2, 3]], [[1, 2, 3, 4]]])
        expected_actions = actions_to_take

        with math.use_backend('numpy'):
            env = self._create_env(  # pylint: disable=no-value-for-parameter
                model=mock_model_fn,
                histories=histories,
                trajectory_length=len(actions_to_take),
            )
            actual_observations = [env.reset()]
            actual_rewards = []
            actual_dones = []
            actual_histories = []
            actual_actions = []
            for action in actions_to_take:
                (observation, reward, done, _) = env.step(action)
                actual_observations.append(observation)
                actual_rewards.append(reward)
                actual_dones.append(done)
                # Mock call is a tuple (args, kwargs). There is one positional argument,
                # which is a tuple (history, action).
                (((history, action), ), _) = mock_model.call_args
                actual_actions.append(action)
                actual_histories.append(history)

        np.testing.assert_array_equal(actual_observations,
                                      expected_observations)
        np.testing.assert_array_equal(actual_rewards, expected_rewards)
        np.testing.assert_array_equal(actual_dones, expected_dones)
        np.testing.assert_array_equal(actual_histories, expected_histories)
        np.testing.assert_array_equal(actual_actions, expected_actions)
Example #18
    def __init__(self,
                 trax_layer,
                 batch_size=None,
                 initializer_rng=None,
                 rng=None,
                 rng_updater=None,
                 dtype=None):
        """Creates a Keras layer wrapping around a Trax layer.

    Args:
      trax_layer: an object of class `trax.layers.Layer`, the trax layer to
        wrap.
      batch_size: (optional) an integer, the batch size that this Keras layer
        will be used on. Keras sometimes needs to generate a TF graph for a
        layer (e.g. for acceleration or checkpointing). The inputs used to trace
        the graph will have `None` as the length of their batch dimensions, so
        as to generate a graph that can handle any batch size. Some Trax layers
        can't handle tensors whose shapes contain `None`. If `batch_size` is set
        to an integer, the graph will be traced with `batch_size` as the batch
        size instead of `None`. Note that in this case the graph (and the Keras
        layer) can only be used on a specific batch size. If you want to use a
        different batch size, you need to create another `TraxKerasLayer` object
        with a different `batch_size`.
      initializer_rng: (optional) an RNG key used to create the weights and
        state if `trax_layer` doesn't have them. If `None`,
        `trax.math.random.get_prng(0)` will be used.
      rng: (optional) an RNG key for the forward function (aka the "forward
        key"). If `None`, `trax.math.random.get_prng(0)` will be used.
      rng_updater: (optional) a function of type rng_key -> rng_key, used to
        update the forward key after each forward pass. If `None`, the function
        `lambda x: trax.math.random.split(x, 1)[0]` will be used, which advances
        the RNG key.
      dtype: (optional) the dtype of the inputs. See the `dtype` argument of
        `tf.keras.layers.Layer.__init__` for details.
    """
        super(TraxKerasLayer, self).__init__(dtype=dtype)
        with math_lib.use_backend("tf"):
            if initializer_rng is None:
                initializer_rng = math_lib.random.get_prng(0)
            if rng is None:
                rng = math_lib.random.get_prng(0)
            if rng_updater is None:
                rng_updater = lambda x: math_lib.random.split(x, 1)[0]
            self._trax_layer = trax_layer
            self._batch_size = batch_size
            self._initializer_rng = initializer_rng
            self._forward_rng_init = rng
            self._rng_updater = rng_updater
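Putting the pieces of this wrapper together (the `__init__` above plus the `build` and `call` methods shown in Examples #26 and #22), here is a minimal usage sketch; the layer choice and sizes are illustrative, and the import paths are assumptions based on the names these tests use:

import tensorflow as tf
from trax import layers as tl
from trax import math as math_lib
from trax import trax2keras  # import path assumed; the tests here use trax2keras.TraxKerasLayer

with math_lib.use_backend("tf"):
    keras_layer = trax2keras.TraxKerasLayer(
        tl.Dense(8),       # bare Trax layer; its weights are created lazily in build()
        batch_size=4,      # pin the batch size so graph tracing never sees None
        dtype=tf.float32)
    x = tf.keras.Input(shape=(16,), dtype=tf.float32)
    y = keras_layer(x)     # triggers build(), which wraps the Trax weights in tf.Variables
    model = tf.keras.Model(inputs=x, outputs=y)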
Example #19
 def test_self_attention(self):
     with math.use_backend('jax'):
         input_signature = ShapeDtype((3, 32, 8))
         layer = efficient_attention.SelfAttention(n_heads=5,
                                                   d_qk=7,
                                                   d_v=17,
                                                   share_qk=False,
                                                   causal=True,
                                                   chunk_len=8,
                                                   n_chunks_before=1,
                                                   n_chunks_after=0,
                                                   use_reference_code=True,
                                                   attention_dropout=0.0,
                                                   mode='train')
         final_shape = base.check_shape_agreement(layer, input_signature)
         self.assertEqual((3, 32, 8), final_shape)
Example #20
    def _test_train_eval_predict(self, backend_name):
        if xla_bridge.device_count() > 1 and backend_name == 'tf':
            self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
        with math.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Prepare model and inputs
            n_classes = 4
            steps = 2
            eval_steps = 2

            # Adds Dropout and BatchNorm to test state handling.
            def model_fn(mode='train'):
                return layers.Serial(
                    layers.Dropout(mode=mode, rate=0.1),
                    layers.BatchNorm(mode=mode),
                    models.MLP(d_hidden=16,
                               n_output_classes=n_classes,
                               mode=mode))

            inputs = test_inputs(n_classes)

            # Train and evaluate
            state = trainer_lib.train(output_dir,
                                      model=model_fn,
                                      inputs=inputs,
                                      steps=steps,
                                      eval_steps=eval_steps)

            # Assert total train steps
            self.assertEqual(steps, state.step)

            # Assert 2 evaluations ran
            train_acc = state.history.get('train', 'metrics/accuracy')
            eval_acc = state.history.get('eval', 'metrics/accuracy')
            self.assertEqual(len(train_acc), len(eval_acc))
            self.assertLen(eval_acc, 2)

            # Predict with final weights
            inputs = inputs.train_stream(1)
            model = model_fn()
            weights = state.opt_state.weights[0]
            state = state.model_state[0]
            if xla_bridge.device_count() > 1:
                unreplicate = lambda x: x[0]
                weights = math.nested_map(unreplicate, weights)
                state = math.nested_map(unreplicate, state)
            model(next(inputs)[0], weights=weights, state=state)
Example #21
 def test_self_attention_tf(self):
     with math.use_backend('tf'):
         layer = efficient_attention.SelfAttention(n_heads=5,
                                                   d_qk=7,
                                                   d_v=17,
                                                   share_qk=False,
                                                   causal=True,
                                                   chunk_len=8,
                                                   n_chunks_before=1,
                                                   n_chunks_after=0,
                                                   use_reference_code=True,
                                                   attention_dropout=0.0,
                                                   mode='train')
         x = np.ones((3, 32, 8)).astype(np.float32)
         _, _ = layer.init(shapes.signature(x))
         y = layer(x)
         self.assertEqual(y.shape, x.shape)
Example #22
 def call(self, inputs):
     with math_lib.use_backend("tf"):
         inputs = math_lib.nested_map(
             functools.partial(_replace_none_batch,
                               batch_size=self._batch_size), inputs)
         weights, state, rng = read_values(
             [self._weights, self._state, self._rng])
         inputs, weights, state, rng = to_arrays(
             [inputs, weights, state, rng])
         outputs, new_state = self._trax_layer.pure_fn(inputs,
                                                       weights=weights,
                                                       state=state,
                                                       rng=rng)
         tf.nest.map_structure(lambda v, t: v.assign(t), self._state,
                               new_state)
         self._rng.assign(self._rng_updater(rng))
         outputs = to_tensors(outputs)
         return outputs
Example #23
 def test_evaluates_model_with_vector_observation_space(self):
   with math.use_backend('numpy'):
     env = self._make_env(  # pylint: disable=no-value-for-parameter
         observation_space=gym.spaces.Box(shape=(2,), low=0, high=1),
         action_space=gym.spaces.Discrete(n=1),
         max_trajectory_length=2,
         batch_size=3,
     )
     trajectories = [
         self._make_trajectory(observations, actions)  # pylint: disable=g-complex-comprehension
         for (observations, actions) in [
             (np.array([[0, 1]]), np.array([0])),
             (np.array([[1, 2], [3, 4]]), np.array([0, 0])),
             (np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 0, 0])),
         ]
     ]
     metrics = simple.evaluate_model(env, trajectories, plt)
     self.assertIsNotNone(metrics)
     self.assertEqual(len(metrics), 2)
Example #24
 def test_reformer_lm_forward_shape_tf(self):
     with math.use_backend('tf'):
         vocab_size = 16
         timebin_attn = self._timebin_self_attention_fn(
             use_reference_code=True)
         model = reformer.ReformerLM(vocab_size,
                                     d_model=32,
                                     d_ff=64,
                                     d_attention_key=16,
                                     d_attention_value=16,
                                     n_layers=1,
                                     n_heads=2,
                                     max_len=64,
                                     attention_type=timebin_attn)
         xs = [
             np.ones((1, 64)).astype(np.int32),
             np.ones((1, 64)).astype(np.int32)
         ]
         _, _ = model.init(shapes.signature(xs))
         ys = model(xs)
         self.assertEqual([y.shape for y in ys], [(1, 64, 16), (1, 64)])
Example #25
  def test_reset_twice(self, backend_name):
    if xla_bridge.device_count() > 1 and backend_name == 'tf':
      self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
    with math.use_backend(backend_name), self.tmp_dir() as output_dir1, \
          self.tmp_dir() as output_dir2:
      n_classes = 4
      model_fn = functools.partial(
          models.MLP, d_hidden=16, n_output_classes=n_classes)
      inputs = test_inputs(n_classes)

      trainer = trainer_lib.Trainer(
          model=model_fn,
          loss_fn=layers.CrossEntropyLoss(),
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          inputs=inputs,
      )

      trainer.reset(output_dir1)
      trainer.evaluate(1)
      trainer.reset(output_dir2)
      trainer.evaluate(1)
Example #26
 def build(self, input_shape):
     with math_lib.use_backend("tf"):
         # Using `is` instead of `==` following Trax's practice
         if self._trax_layer.weights is base.EMPTY_WEIGHTS:
             sanitized_input_shape = math_lib.nested_map(
                 functools.partial(_replace_none_batch,
                                   batch_size=self._batch_size),
                 input_shape)
             weights, state = self._trax_layer.init(
                 tensor_shapes_to_shape_dtypes(sanitized_input_shape,
                                               self.dtype),
                 rng=self._initializer_rng)
         else:
             weights = self._trax_layer.weights
             state = self._trax_layer.state
         # Note: `weights` may contain `EMPTY_WEIGHTS`
         self._weights = math_lib.nested_map(
             functools.partial(tf.Variable, trainable=True), weights)
         self._state = math_lib.nested_map(
             functools.partial(tf.Variable, trainable=False), state)
         self._rng = tf.Variable(self._forward_rng_init, trainable=False)
     super(TraxKerasLayer, self).build(input_shape)
Example #27
    def testTrain(self, layer_id, rng_updater_id, batch_size, trax_has_weights,
                  explicit_build, use_model):
        """Tests training (forward and backward pass) for TraxKerasLayer.

    Args:
      layer_id: an integer, the index into `_LAYERS`.
      rng_updater_id: an integer, the index into `_RNG_UPDATERS`.
      batch_size: an integer or `None`, the value for the `batch_size` argument
        in `TraxKerasLayer.__init__`.
      trax_has_weights: bool, whether to make the trax layer contain weights at
        the time when `TraxKerasLayer.build` is called.
      explicit_build: bool, whether to explicitly call `TraxKerasLayer.build`.
      use_model: bool, whether to build a `tf.keras.Model` out of the
        `TraxKerasLayer` layer and use the model to do the training instead of
        the bare layer. If `True`, we will also test checkpointing and restoring
        using the model.
    """
        with math_lib.use_backend("tf"):
            make_trax_layer, input_shapes_no_batch, dtype, allow_none_batch = (
                _LAYERS[layer_id])
            # We make a fresh trax layer for each test case, so that different test
            # cases won't interfere with each other.
            trax_layer = make_trax_layer()
            if not allow_none_batch and batch_size is None:
                self.skipTest("This Trax layer can't handle None batch size.")
            rng_updater = _RNG_UPDATERS[rng_updater_id]
            input_shapes = math_lib.nested_map(lambda s: [batch_size] + s,
                                               input_shapes_no_batch)
            input_sig = trax2keras.tensor_shapes_to_shape_dtypes(
                input_shapes, dtype)
            initializer_rng = math_lib.random.get_prng(765)
            weights, state = trax_layer.init(input_sig, rng=initializer_rng)
            generator = tf.random.Generator.from_seed(567)

            def get_inputs():
                return dummy_inputs(generator, input_sig)

            if trax_has_weights:
                trax_layer(to_arrays(get_inputs()),
                           weights=weights,
                           state=state)
            rng = math_lib.random.get_prng(1234)
            keras_layer = trax2keras.TraxKerasLayer(
                trax_layer,
                batch_size=batch_size,
                initializer_rng=initializer_rng,
                rng=rng,
                rng_updater=rng_updater)
            if explicit_build:
                keras_layer.build(input_shapes)
            if use_model:
                x = tf.keras.Input(shape=input_shapes_no_batch, dtype=dtype)
                y = keras_layer(x)
                keras_model = tf.keras.Model(inputs=x, outputs=y)
            lr = 0.1  # learning rate
            for _ in range(3):
                inputs = get_inputs()
                with tf.GradientTape() as trax_tape:
                    trax_tape.watch([x.data for x in tf.nest.flatten(weights)])
                    trax_outputs, state = trax_layer.pure_fn(to_arrays(inputs),
                                                             weights=weights,
                                                             state=state,
                                                             rng=rng)
                trax_grads = trax_tape.gradient(
                    *to_tensors([trax_outputs, weights]))
                # `g` may be `tf.IndexedSlices`, so we need to `convert_to_tensor`
                # before multiplication.
                weights = tf.nest.map_structure(
                    lambda w, g: w + np.asarray(
                        lr * tf.convert_to_tensor(g), w.dtype),
                    weights, trax_grads)
                rng = rng_updater(rng)
                with tf.GradientTape() as keras_tape:
                    if use_model:
                        keras_outputs = keras_model(inputs)
                    else:
                        keras_outputs = keras_layer(inputs)
                if isinstance(keras_outputs,
                              tuple) and len(keras_outputs) == 1:
                    keras_outputs = keras_outputs[0]
                self.assertAllClose(to_tensors(trax_outputs), keras_outputs)
                keras_grads = keras_tape.gradient(
                    keras_outputs, keras_layer.trainable_variables)
                tf.nest.map_structure(
                    lambda v, g: v.assign_add(  # pylint: disable=g-long-lambda
                        tf.cast(lr * tf.convert_to_tensor(g), v.dtype)),
                    keras_layer.trainable_variables,
                    keras_grads)
                self.assertAllClose(to_tensors(weights),
                                    read_values(keras_layer._weights),
                                    rtol=2e-6,
                                    atol=5e-5)
                self.assertAllClose(to_tensors(state),
                                    read_values(keras_layer._state))
                self.assertAllClose(to_tensors(rng),
                                    read_values(keras_layer._rng))
            if use_model:
                fname = os.path.join(self.get_temp_dir(), "checkpoint")
                keras_model.save(fname)
                loaded_model = tf.keras.models.load_model(fname)
                for _ in range(2):
                    inputs = get_inputs()
                    self.assertAllClose(keras_model(inputs),
                                        loaded_model(inputs))
Example #28
    def test_communicates_with_model(self, mock_restore_state):
        gin.bind_parameter('BoxSpaceSerializer.precision', 1)
        vocab_size = 16
        # Mock model predicting a fixed sequence of symbols. It is made such that
        # the first two observations are different and the last one is equal to the
        # first.
        symbols = [
            1, 1, 2, 2, 0, 0,  # obs1 act1
            1, 2, 2, 1, 0, 0,  # obs2 act2
            1, 1, 2, 2,  # obs3
        ]

        def make_prediction(symbol):
            one_hot = np.eye(vocab_size)[symbol]
            log_probs = (1 - one_hot) * -100.0  # Virtually deterministic.
            # The mock is consumed one prediction per generated symbol: (4 obs + 2
            # action symbols) for each of the first two steps, plus 4 obs symbols
            # for the last one = 16 predictions.
            return np.array([[log_probs]])

        mock_predict_fn = mock.MagicMock()
        mock_predict_fn.side_effect = map(make_prediction, symbols)

        with math.use_backend('numpy'):
            # (model_params, opt_state)
            mock_restore_state.return_value.params = (None, None)
            env = self._make_env(
                predict_fn=mock_predict_fn,
                reward_fn=(lambda _1, _2: np.array([0.5])),
                done_fn=(lambda _1, _2: np.array([False])),
                vocab_size=vocab_size,
                batch_size=1,
                max_trajectory_length=3,
                observation_space=gym.spaces.Box(low=0, high=5, shape=(4, )),
                action_space=gym.spaces.MultiDiscrete(nvec=[2, 2]),
            )

            def assert_input_suffix(expected_symbols):
                actual_symbols = np.array([
                    symbol.item() for ((symbol, ), _) in
                    mock_predict_fn.call_args_list[-len(expected_symbols):]
                ])
                np.testing.assert_array_equal(actual_symbols, expected_symbols)

            actions = [[0, 1], [1, 0]]

            obs1 = env.reset()
            assert_input_suffix(symbols[:3])

            (obs2, reward, done, _) = env.step(np.array([actions[0]]))
            # Symbols going into the decoder when predicting the next observation are:
            # the last symbol of the previous observation, all action symbols, all
            # symbols but the last one of the next observation.
            assert_input_suffix([symbols[3]] + actions[0] + symbols[6:9])
            self.assertFalse(np.array_equal(obs1, obs2))
            np.testing.assert_array_equal(reward, [0.5])
            np.testing.assert_array_equal(done, [False])

            (obs3, reward, done, _) = env.step(np.array([actions[1]]))
            assert_input_suffix([symbols[9]] + actions[1] + symbols[12:15])
            np.testing.assert_array_equal(obs1, obs3)
            np.testing.assert_array_equal(reward, [0.5])
            np.testing.assert_array_equal(done, [True])
Example #29
 def test_reversible_swap(self, backend_name):
     with math.use_backend(backend_name):
         layer = tl.ReversibleSwap()
         xs = [np.array([1, 2]), np.array([10, 20])]
         ys = layer(xs)
         self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]])