コード例 #1
0
 def test_computes_basic_mean(self):
     """Checks that masked_mean of [1, 2, 3] with unit weight is 2 on numpy."""
     values = [np.array([1, 2, 3])]
     zero_targets = [np.zeros(3)]
     unit_weights = [1]
     with backend.use_backend("numpy"):
         result = trax.masked_mean(values, zero_targets, unit_weights)
         np.testing.assert_allclose(result, 2)
コード例 #2
0
 def test_computes_mean_with_weights(self, backend_name):
     """Checks masked_mean with per-element weights under the given backend.

     Values [1, 2, 3] with weights [3, 1, 0] are expected to average to 1.25.
     """
     with backend.use_backend(backend_name):
         weighted_mean = trax.masked_mean(
             [np.array([1, 2, 3])],
             [np.zeros(3)],
             [np.array([3, 1, 0])],
         )
         onp.testing.assert_allclose(weighted_mean, 1.25)
コード例 #3
0
ファイル: trax_test.py プロジェクト: tianhai123/-
    def test_train_with_weights(self, backend_name):
        """Trains a small MLP on weighted inputs and checks the final step.

        The test is skipped for the "tf" backend when more than one device is
        reported. The gin parameter "unpack_batch.has_weights" is bound to True
        before training.
        """
        on_multiple_devices = jax.lib.xla_bridge.device_count() > 1
        if on_multiple_devices and backend_name == "tf":
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            gin.bind_parameter("unpack_batch.has_weights", True)

            # Model factory and weighted input pipeline.
            n_classes = 4
            num_train_steps = 2
            num_eval_steps = 2
            make_model = functools.partial(
                models.MLP, d_hidden=16, n_output_classes=n_classes)

            def make_weighted_inputs(_):
                return test_inputs(n_classes, with_weights=True)

            # Run training with periodic evaluation.
            final_state = trax.train(
                output_dir,
                model=make_model,
                inputs=make_weighted_inputs,
                train_steps=num_train_steps,
                eval_steps=num_eval_steps)

            # Training must have stopped at the requested step.
            self.assertEqual(final_state.step, num_train_steps)
コード例 #4
0
 def test_computes_mean_with_mask(self):
     """Checks masked_mean with mask_id=1 on the numpy backend.

     With targets [1, 0, 0] and mask_id=1, the expected mean of [1, 2, 3]
     is 2.5.
     """
     values = [np.array([1, 2, 3])]
     masked_targets = [np.array([1, 0, 0])]
     unit_weights = [1]
     with backend.use_backend("numpy"):
         result = trax.masked_mean(
             values, masked_targets, unit_weights, mask_id=1)
         np.testing.assert_allclose(result, 2.5)
コード例 #5
0
    def test_reformer_rng_consistency(self):
        """Checks that gradients through a small ReformerLM are all finite.

        The model is built with PoisonOnRNGMismatchAttention, which uses NaNs
        to signal an rng mismatch, so any non-finite gradient fails the test.
        """
        with backend.use_backend('jax'):
            vocab_size = 16
            batch_size = 1
            single_shape = (batch_size, 8)
            input_shape = (single_shape, single_shape)
            reformer_lm = reformer.ReformerLM(
                vocab_size,
                d_model=32,
                d_ff=64,
                d_attention_key=16,
                d_attention_value=16,
                n_layers=1,
                n_heads=2,
                max_len=16,
                n_chunks=2,
                n_attention_chunks=1,
                mode='train',
                attention_type=PoisonOnRNGMismatchAttention)

            rng = backend.random.get_prng(0)
            input_dtypes = (np.int32, np.int32)
            params, state = reformer_lm.initialize_once(
                input_shape, input_dtypes, rng)

            def summed_first_output(model_params):
                # Feed the same all-zero token array as both inputs.
                zero_tokens = np.zeros(input_shape[0], dtype=np.int32)
                outputs = reformer_lm((zero_tokens, zero_tokens),
                                      params=model_params,
                                      state=state,
                                      rng=rng)
                return backend.numpy.sum(outputs[0])

            gradients = backend.grad(summed_first_output)(params)
            # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
            for leaf in jax.tree_util.tree_leaves(gradients):
                assert onp.all(onp.isfinite(leaf))
コード例 #6
0
    def test_train_eval_predict_sm3(self, backend_name):
        """Runs train/eval with the SM3 optimizer, then one forward pass.

        The test is skipped for the "tf" backend when more than one device is
        reported.
        """
        on_multiple_devices = xla_bridge.device_count() > 1
        if on_multiple_devices and backend_name == "tf":
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Model factory and input pipeline.
            n_classes = 4
            num_train_steps = 2
            num_eval_steps = 2
            make_model = functools.partial(
                models.MLP, d_hidden=16, n_output_classes=n_classes)

            def make_inputs(_):
                return test_inputs(n_classes)

            # Run training with evaluation, using SM3 as the optimizer.
            final_state = trax.train(
                output_dir,
                model=make_model,
                inputs=make_inputs,
                train_steps=num_train_steps,
                eval_steps=num_eval_steps,
                optimizer=trax_opt.SM3)

            # Training must have stopped at the requested step.
            self.assertEqual(num_train_steps, final_state.step)

            # Train and eval accuracy histories match in length (2 entries).
            train_accuracy = final_state.history.get(
                "train", "metrics/accuracy")
            eval_accuracy = final_state.history.get(
                "eval", "metrics/accuracy")
            self.assertEqual(len(train_accuracy), len(eval_accuracy))
            self.assertLen(eval_accuracy, 2)

            # Forward pass on one training batch with the trained params.
            train_stream = make_inputs(1).train_stream()
            predictor = layers.Serial(make_model())
            first_batch = next(train_stream)
            predictor(first_batch[0], params=final_state.opt_state.params)
コード例 #7
0
    def test_train_restart(self, backend_name):
        """Trains, then calls train again on the same output directory.

        The second call asks for twice as many train steps; the returned state
        is expected to report that doubled step count. Skipped for the "tf"
        backend when more than one device is reported.
        """
        on_multiple_devices = xla_bridge.device_count() > 1
        if on_multiple_devices and backend_name == "tf":
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Model factory and input pipeline.
            n_classes = 4
            num_train_steps = 2
            num_eval_steps = 2
            make_model = functools.partial(
                models.MLP, d_hidden=16, n_output_classes=n_classes)

            def make_inputs(_):
                return test_inputs(n_classes)

            # First run.
            trax.train(
                output_dir,
                model=make_model,
                inputs=make_inputs,
                train_steps=num_train_steps,
                eval_steps=num_eval_steps)

            # Second run against the same directory, with a doubled target.
            restarted_state = trax.train(
                output_dir,
                model=make_model,
                inputs=make_inputs,
                train_steps=(2 * num_train_steps),
                eval_steps=num_eval_steps)

            # The final step equals the doubled target.
            self.assertEqual(restarted_state.step, 2 * num_train_steps)
コード例 #8
0
    def pseudo_call(self, pseudo_inputs, params):
        """Computes shapes and types this layer would produce for the given inputs.

        The computation always runs under the 'jax' backend, and both the
        parameters and the rng are replaced by ShapeType stubs before being
        handed to _eval_on_shapes.

        Args:
          pseudo_inputs: A ShapeType instance (input data minus the actual
              values) or a tuple of ShapeType instances, following the same
              conventions as Layer.call's input arg.
          params: Parameters for this layer.

        Returns:
          A ShapeType instance representing the shape and type of the output
          (if this layer has one output) or a tuple of ShapeType instances (if
          this layer has more than one output).

        Raises:
          LayerError: If anything in the shape evaluation raises; the error
              carries this layer's class name, 'pseudo_call', self._caller,
              the pseudo inputs and a shortened traceback.
        """
        try:
            with backend.use_backend('jax'):
                # Beware: using an actual RNG (as opposed to this ShapeType stub) would
                # cause a large number of dropout masks to be computed and permanently
                # stored in global memory.
                rng = ShapeType(shape=(2, ), dtype=onp.uint32)

                # Adapter giving self.call a purely positional signature.
                def call_on_input(x, params, rng):
                    return self.call(x, params=params, rng=rng)

                # Replace every parameter leaf by its shape/dtype description.
                params_shapes = nested_map(
                    params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
                s = _eval_on_shapes(call_on_input, pseudo_inputs,
                                    params_shapes, rng)
            return s
        except Exception:
            # NOTE(review): skip=3 presumably drops this method's own wrapper
            # frames from the reported trace -- confirm in _short_traceback.
            name, trace = self.__class__.__name__, _short_traceback(skip=3)
            raise LayerError(name, 'pseudo_call', self._caller, pseudo_inputs,
                             trace)
コード例 #9
0
ファイル: base.py プロジェクト: samprasgit/tensor2tensor
  def output_shape(self, input_shape_and_type, params):
    """Output shape and type for this layer given input shape and type.

    Note that all arguments and return values can be tuples or dictionaries
    or arbitrary nested structures composed of tuples and dictionaries.

    The computation always runs under the 'jax' backend, with a PRNG key
    obtained from backend.random.get_prng(0).

    Args:
      input_shape_and_type: a ShapeType with shape and type of the input.
      params: parameters for this layer.

    Returns:
      The shape and type of the output.

    Raises:
      LayerError: if anything in the shape evaluation raises; the error
        carries this layer's class name, 'output_shape', self._caller, the
        input shape/type and a shortened traceback.
    """
    try:
      with backend.use_backend('jax'):
        rng = backend.random.get_prng(0)
        def call_on_input(x, params):
          f = lambda y: self.call(y, params=params, rng=rng)
          # For list/tuple inputs, self.stack_items_to_pass() decides how many
          # leading items f is applied to; for any other input n is 0.
          n = self.stack_items_to_pass() if isinstance(x, (list, tuple)) else 0
          return _apply_to_first_n(f, x, n)
        # Replace every parameter leaf by its shape/type description.
        params_shapes = nested_map(
            params, lambda x: ShapeType(shape=x.shape, tp=x.dtype))
        s = _eval_on_shapes(call_on_input, input_shape_and_type, params_shapes)
      return s
    except Exception:
      # NOTE(review): skip=3 presumably drops this method's own wrapper frames
      # from the reported trace -- confirm in _short_traceback.
      name, trace = self.__class__.__name__, _short_traceback(skip=3)
      raise LayerError(name, 'output_shape', self._caller,
                       input_shape_and_type, trace)
コード例 #10
0
    def test_transformer_lm_fast_inference(self):
        """Compares one-symbol 'predict' logits with full-buffer 'eval' logits.

        Two TransformerLM instances are initialized with the same rng. At each
        position the 'eval' model sees the whole buffer and the 'predict' model
        sees only the latest symbol; their logits for that position must agree.
        """
        with backend.use_backend('jax'):
            vocab_size = 16
            build_model = functools.partial(
                transformer.TransformerLM,
                vocab_size=vocab_size,
                d_model=4,
                d_ff=8,
                n_layers=2,
                n_heads=2)
            slow_model = build_model(mode='eval')
            fast_model = build_model(mode='predict')
            rng = backend.random.get_prng(0)
            batch_size = 2
            _, _ = slow_model.initialize_once((batch_size, 1), np.int32, rng)
            _, _ = fast_model.initialize_once((batch_size, 1), np.int32, rng)

            max_length = 5
            token_buffer = onp.zeros(
                (batch_size, max_length), dtype=np.int32)
            current_symbol = onp.zeros((batch_size, 1), dtype=onp.int32)

            for position in range(max_length):
                slow_logits = slow_model(token_buffer, rng=rng)
                fast_logits = fast_model(current_symbol, rng=rng)
                onp.testing.assert_array_almost_equal(
                    slow_logits[:, position, :], fast_logits[:, 0, :])
                # Draw the next symbol at random and record it in the buffer.
                current_symbol = onp.random.randint(
                    vocab_size, size=(batch_size, 1))
                token_buffer[:, position] = current_symbol[:, 0]
コード例 #11
0
 def test_computes_mean_with_weights_and_mask(self, backend_name):
     """Checks masked_mean with both weights and mask_id=1.

     Values [1, 2, 4], targets [1, 0, 0] and weights [10, 4, 1] are expected
     to give 2.4 under the given backend.
     """
     with backend.use_backend(backend_name):
         result = trax.masked_mean(
             [np.array([1, 2, 4])],
             [np.array([1, 0, 0])],
             [np.array([10, 4, 1])],
             mask_id=1,
         )
         onp.testing.assert_allclose(result, 2.4)
コード例 #12
0
    def _test_fast_inference(self, attention_type, length):
        """Compares one-symbol 'predict' logits with full-buffer 'eval' logits.

        Args:
          attention_type: Forwarded to TransformerLM as `attention_type`.
          length: Number of positions to decode and compare; also the width of
              the token buffer fed to the 'eval' model.
        """
        with backend.use_backend('jax'):
            vocab_size = 16
            build_model = functools.partial(
                transformer.TransformerLM,
                vocab_size=vocab_size,
                d_model=4,
                d_ff=8,
                n_layers=2,
                n_heads=2,
                attention_type=attention_type,
            )
            slow_model = build_model(mode='eval')
            fast_model = build_model(mode='predict')
            rng = backend.random.get_prng(0)
            batch_size = 2
            # Given the same rng, both models initialize with the same parameters.
            slow_model.initialize_once((batch_size, 1), np.int32, rng)
            fast_model.initialize_once((batch_size, 1), np.int32, rng)

            token_buffer = onp.zeros((batch_size, length), dtype=np.int32)
            current_symbol = onp.zeros((batch_size, 1), dtype=onp.int32)

            for position in range(length):
                slow_logits = slow_model(token_buffer, rng=rng)
                fast_logits = fast_model(current_symbol, rng=rng)
                onp.testing.assert_array_almost_equal(
                    slow_logits[:, position, :], fast_logits[:, 0, :])
                # Draw the next symbol at random and record it in the buffer.
                current_symbol = onp.random.randint(
                    vocab_size, size=(batch_size, 1))
                token_buffer[:, position] = current_symbol[:, 0]
コード例 #13
0
ファイル: base.py プロジェクト: dbs700/tensor2tensor
    def pseudo_call(self, pseudo_input, params):
        """Computes what shapes and types this layer would produce for given input.

        The computation always runs under the 'jax' backend, with a PRNG key
        obtained from backend.random.get_prng(0).

        Args:
          pseudo_input: A ShapeType instance (input data minus the actual
              values) or a tuple of ShapeType instances.
          params: Parameters for this layer.

        Returns:
          A ShapeType instance representing the shape and type of the output
          (if this layer has one output) or a tuple of ShapeType instances (if
          this layer has more than one output).

        Raises:
          LayerError: If anything in the shape evaluation raises; the error
              carries this layer's class name, 'pseudo_call', self._caller,
              the pseudo input and a shortened traceback.
        """
        try:
            with backend.use_backend('jax'):
                rng = backend.random.get_prng(0)

                def call_on_input(x, params):
                    f = lambda y: self.call(y, params=params, rng=rng)
                    # For list/tuple inputs, self.stack_items_to_pass() decides
                    # how many leading items f is applied to; otherwise n is 0.
                    n = self.stack_items_to_pass() if isinstance(
                        x, (list, tuple)) else 0
                    return _apply_to_first_n(f, x, n)

                # Replace every parameter leaf by its shape/dtype description.
                params_shapes = nested_map(
                    params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
                s = _eval_on_shapes(call_on_input, pseudo_input, params_shapes)
            return s
        except Exception:
            # NOTE(review): skip=3 presumably drops this method's own wrapper
            # frames from the reported trace -- confirm in _short_traceback.
            name, trace = self.__class__.__name__, _short_traceback(skip=3)
            raise LayerError(name, 'pseudo_call', self._caller, pseudo_input,
                             trace)
コード例 #14
0
ファイル: base.py プロジェクト: silasjeon2/tensor2tensor
    def pseudo_call(self, pseudo_inputs, params):
        """Computes shapes and types this layer would produce for the given inputs.

        The computation always runs under the 'jax' backend, with a fixed
        all-zero uint32 array of length 2 used as the rng.

        Args:
          pseudo_inputs: A ShapeType instance (input data minus the actual
              values) or a tuple of ShapeType instances, following the same
              conventions as Layer.call's input arg.
          params: Parameters for this layer.

        Returns:
          A ShapeType instance representing the shape and type of the output
          (if this layer has one output) or a tuple of ShapeType instances (if
          this layer has more than one output).

        Raises:
          LayerError: If anything in the shape evaluation raises; the error
              carries this layer's class name, 'pseudo_call', self._caller,
              the pseudo inputs and a shortened traceback.
        """
        try:
            with backend.use_backend('jax'):
                # Same as backend.random.get_prng(0), but no op-by-op execution.
                rng = onp.zeros(2, onp.uint32)

                # The rng is captured by closure rather than passed through
                # _eval_on_shapes.
                def call_on_input(x, params):
                    return self.call(x, params=params, rng=rng)

                # Replace every parameter leaf by its shape/dtype description.
                params_shapes = nested_map(
                    params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
                s = _eval_on_shapes(call_on_input, pseudo_inputs,
                                    params_shapes)
            return s
        except Exception:
            # NOTE(review): skip=3 presumably drops this method's own wrapper
            # frames from the reported trace -- confirm in _short_traceback.
            name, trace = self.__class__.__name__, _short_traceback(skip=3)
            raise LayerError(name, 'pseudo_call', self._caller, pseudo_inputs,
                             trace)
コード例 #15
0
  def test_communicates_with_model(self, mock_restore_state):
    """Checks what the serialized env feeds to the model and returns from it.

    A mocked model emits a fixed symbol sequence. The test verifies the symbols
    and actions found in the model inputs after each step, and the
    observations, rewards and dones returned by the environment.

    Args:
      mock_restore_state: Mock whose return_value.params is set below;
        presumably injected by a patch decorator not shown here.
    """
    gin.bind_parameter("BoxSpaceSerializer.precision", 1)
    vocab_size = 16
    # Mock model predicting a fixed sequence of symbols. It is made such that
    # the first two observations are equal and the last one is different.
    symbols = [
        1, 1, 2, 2,  # obs1
        1, 1, 2, 2,  # obs2
        1, 2, 2, 1,  # obs3
    ]
    def make_prediction(symbol):
      # Log-probabilities of 0 for `symbol` and -100 for every other symbol.
      one_hot = np.eye(vocab_size)[symbol]
      log_probs = (1 - one_hot) * -100.0  # Virtually deterministic.
      # (4 obs symbols + 1 action symbol) * 3 timesteps = 15.
      return np.array([[log_probs] * 15])

    mock_model_fn = mock.MagicMock()
    mock_model = mock_model_fn.return_value
    # Each call to the mocked model returns the prediction for the next symbol.
    mock_model.side_effect = map(make_prediction, symbols)

    with backend.use_backend("numpy"):
      # (model_params, opt_state)
      mock_restore_state.return_value.params = (None, None)
      env = simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
          model=mock_model_fn,
          reward_fn=(lambda _1, _2: np.array([0.5])),
          done_fn=(lambda _1, _2: np.array([False])),
          vocab_size=vocab_size,
          max_trajectory_length=3,
          batch_size=1,
          observation_space=gym.spaces.Box(low=0, high=5, shape=(4,)),
          action_space=gym.spaces.Discrete(2),
          reward_range=(-1, 1),
          discrete_rewards=False,
          history_stream=itertools.repeat(None),
          output_dir=None,
      )
      obs1 = env.reset()
      # The value unpacked here is overwritten before use; the unpacking itself
      # only succeeds if the last model call had exactly one positional arg.
      ((inputs,), _) = mock_model.call_args

      act1 = 0
      (obs2, reward, done, _) = env.step(np.array([act1]))
      ((inputs,), _) = mock_model.call_args
      # Position 4 holds the first action; positions 0-3 hold obs1's symbols.
      self.assertEqual(inputs[0, 4], act1)
      np.testing.assert_array_equal(inputs[0, :4], symbols[:4])
      np.testing.assert_array_equal(obs1, obs2)
      np.testing.assert_array_equal(reward, [0.5])
      np.testing.assert_array_equal(done, [False])

      act2 = 1
      (obs3, reward, done, _) = env.step(np.array([act2]))
      ((inputs,), _) = mock_model.call_args
      # Position 9 holds the second action; positions 5-8 hold obs2's symbols.
      self.assertEqual(inputs[0, 9], act2)
      np.testing.assert_array_equal(inputs[0, 5:9], symbols[4:8])
      self.assertFalse(np.array_equal(obs2, obs3))
      np.testing.assert_array_equal(reward, [0.5])
      np.testing.assert_array_equal(done, [False])
コード例 #16
0
  def test_takes_new_history(self):
    """Checks the observation returned by a second reset of the environment.

    Two histories are supplied; after resetting twice, the observation is
    expected to equal [5].
    """
    histories = np.array([[[0, 1, 2]], [[3, 4, 5]]])

    with backend.use_backend("numpy"):
      model_mock = mock.MagicMock()
      env = self._create_env(  # pylint: disable=no-value-for-parameter
          model=model_mock,
          histories=histories,
          trajectory_length=2,
      )
      # The first reset's result is discarded; only the second is checked.
      env.reset()
      second_observation = env.reset()
      np.testing.assert_array_equal(second_observation, [5])
コード例 #17
0
    def train_epoch(self, epoch_steps, eval_steps):
        """Train for one epoch.

        Runs `epoch_steps` update steps on batches drawn from the training
        stream, logs the learning rate and timing, evaluates, saves the state
        and flushes both summary writers.

        Args:
          epoch_steps: Number of training steps to run in this epoch.
          eval_steps: Passed unchanged to self.evaluate after training.
        """
        # Log separator
        print()

        # Timer
        start_time = time.time()

        for _ in range(epoch_steps):
            # Train
            next_train_batch = next(self._train_stream)
            if self._n_devices > 1:  # TODO(lukaszkaiser): use everywhere if possible.
                next_train_batch = reshape_by_device(next_train_batch,
                                                     self._n_devices)
            # The update receives the step number from before the increment.
            self._opt_state, self._rngs = self._jit_update_fn(
                self._step, self._opt_state, next_train_batch, self._rngs)
            self._step += 1

            # Extra save at explicitly requested steps.
            # NOTE(review): the final True/False argument of _save_replicated
            # presumably selects a step-specific checkpoint -- confirm.
            if self._step in self._save_steps:
                _save_replicated(self._opt_state, self._step, self._history,
                                 self._n_devices, self._output_dir, True)

            # LR log
            if self._step == 1 or self._step % 10 == 0:
                # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU)
                # in op-by-op mode just to compute the learning rate. However, there
                # should be a cleaner approach than forcibly swapping out the backend.
                with backend.use_backend("numpy"):
                    self._train_sw.scalar("training/learning rate",
                                          self._lr_fn(self._step),
                                          step=self._step)

        # Timer
        epoch_time = time.time() - start_time
        step_log(
            self._step,
            "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time))
        # Throughput is only reported when more than one step was run.
        if epoch_steps > 1:
            self._train_sw.scalar("training/steps per second",
                                  epoch_steps / epoch_time,
                                  step=self._step)

        # Evaluate in parallel
        self.evaluate(eval_steps)

        # Save state
        _save_replicated(self._opt_state, self._step, self._history,
                         self._n_devices, self._output_dir, False)

        # Flush summary writers
        self._train_sw.flush()
        self._eval_sw.flush()
コード例 #18
0
    def test_communicates_with_model(self):
        """Checks observations, rewards, dones and model inputs over two steps.

        The mocked model adds the action to the last observation of the history
        and returns a reward of 1 when the new observation is even, else 0.
        """
        def fake_transition(inputs, *args, **kwargs):
            del args
            del kwargs
            (observations, actions) = inputs
            next_observations = observations[:, -1] + actions
            parity_reward = int(next_observations % 2 == 0)
            return (next_observations, np.array([[parity_reward]]))

        mock_model_fn = mock.MagicMock()
        mock_model = mock_model_fn.return_value
        mock_model.side_effect = fake_transition

        actions_to_take = np.array([[1], [3]])
        initial_observations = np.array([[[0, 1, 2, 3]]])
        expected_observations = np.array([[3], [4], [7]])
        expected_rewards = np.array([[1], [0]])
        expected_dones = np.array([[False], [True]])
        expected_histories = np.array([[[0, 1, 2, 3]], [[1, 2, 3, 4]]])
        expected_actions = actions_to_take

        with backend.use_backend("numpy"):
            env = self._create_env(  # pylint: disable=no-value-for-parameter
                model=mock_model_fn,
                initial_observations=initial_observations,
                trajectory_length=len(actions_to_take),
            )
            seen_observations = [env.reset()]
            seen_rewards = []
            seen_dones = []
            seen_histories = []
            seen_actions = []
            for action in actions_to_take:
                (observation, reward, done, _) = env.step(action)
                seen_observations.append(observation)
                seen_rewards.append(reward)
                seen_dones.append(done)
                # Mock call is a tuple (args, kwargs). There is one positional
                # argument, which is a tuple (history, action).
                (call_args, _) = mock_model.call_args
                ((model_history, model_action), ) = call_args
                seen_actions.append(model_action)
                seen_histories.append(model_history)

        np.testing.assert_array_equal(seen_observations,
                                      expected_observations)
        np.testing.assert_array_equal(seen_rewards, expected_rewards)
        np.testing.assert_array_equal(seen_dones, expected_dones)
        np.testing.assert_array_equal(seen_histories, expected_histories)
        np.testing.assert_array_equal(seen_actions, expected_actions)
コード例 #19
0
 def test_fails_to_evaluate_model_with_matrix_observation_space(self):
     """Expects evaluate_model to return None for a (2, 2) Box observation space."""
     with backend.use_backend("numpy"):
         matrix_space = gym.spaces.Box(shape=(2, 2), low=0, high=1)
         single_action_space = gym.spaces.Discrete(n=1)
         env = self._make_env(  # pylint: disable=no-value-for-parameter
             observation_space=matrix_space,
             action_space=single_action_space,
             max_trajectory_length=2,
             batch_size=1,
         )
         observations = np.array([[0, 1], [2, 3]])
         actions = np.array([0])
         trajectories = [self._make_trajectory(observations, actions)]
         result = simple.evaluate_model(env, trajectories, plt)
         self.assertIsNone(result)
コード例 #20
0
 def test_evaluates_model_with_vector_observation_space(self):
     """Expects evaluate_model to return two metrics for a (2,) Box space.

     Three trajectories of lengths 1, 2 and 3 are evaluated with a batch size
     of 3.
     """
     with backend.use_backend("numpy"):
         vector_space = gym.spaces.Box(shape=(2, ), low=0, high=1)
         single_action_space = gym.spaces.Discrete(n=1)
         env = self._make_env(  # pylint: disable=no-value-for-parameter
             observation_space=vector_space,
             action_space=single_action_space,
             max_trajectory_length=2,
             batch_size=3,
         )
         observation_action_pairs = [
             (np.array([[0, 1]]), np.array([0])),
             (np.array([[1, 2], [3, 4]]), np.array([0, 0])),
             (np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 0, 0])),
         ]
         trajectories = [
             self._make_trajectory(observations, actions)
             for (observations, actions) in observation_action_pairs
         ]
         result = simple.evaluate_model(env, trajectories, plt)
         self.assertIsNotNone(result)
         self.assertEqual(len(result), 2)
コード例 #21
0
    def test_train_eval_predict(self, backend_name):
        """Runs train/eval on a Dropout + BatchNorm + MLP model, then predicts.

        The test is skipped for the "tf" backend when more than one device is
        reported.
        """
        on_multiple_devices = xla_bridge.device_count() > 1
        if on_multiple_devices and backend_name == "tf":
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            n_classes = 4
            num_train_steps = 2
            num_eval_steps = 2

            # Adds Dropout and BatchNorm to test state handling.
            def make_model(mode="train"):
                dropout = layers.Dropout(mode=mode, rate=0.1)
                batch_norm = layers.BatchNorm(mode=mode)
                mlp = models.MLP(
                    d_hidden=16, n_output_classes=n_classes, mode=mode)
                return layers.Model(dropout, batch_norm, mlp)

            def make_inputs(_):
                return test_inputs(n_classes)

            # Run training with periodic evaluation.
            final_state = trax.train(
                output_dir,
                model=make_model,
                inputs=make_inputs,
                train_steps=num_train_steps,
                eval_steps=num_eval_steps)

            # Training must have stopped at the requested step.
            self.assertEqual(num_train_steps, final_state.step)

            # Train and eval accuracy histories match in length (2 entries).
            train_accuracy = final_state.history.get(
                "train", "metrics/accuracy")
            eval_accuracy = final_state.history.get(
                "eval", "metrics/accuracy")
            self.assertEqual(len(train_accuracy), len(eval_accuracy))
            self.assertLen(eval_accuracy, 2)

            # Forward pass on one training batch with the trained params.
            train_stream = make_inputs(1).train_stream()
            predictor = layers.Serial(make_model())
            first_batch = next(train_stream)
            predictor(first_batch[0], params=final_state.opt_state.params)
コード例 #22
0
ファイル: trax_test.py プロジェクト: TomNong/tensor2tensor
  def test_reset_twice(self, backend_name):
    """Resets one Trainer onto two output directories, evaluating after each.

    The test is skipped for the "tf" backend when more than one device is
    reported.
    """
    on_multiple_devices = xla_bridge.device_count() > 1
    if on_multiple_devices and backend_name == "tf":
      self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
    with backend.use_backend(backend_name):
      with self.tmp_dir() as first_dir, self.tmp_dir() as second_dir:
        n_classes = 4
        make_model = functools.partial(
            models.MLP, d_hidden=16, n_output_classes=n_classes)

        def make_inputs(_):
          return test_inputs(n_classes)

        trainer = trax.Trainer(
            model=make_model,
            loss_fn=trax.loss,
            optimizer=trax_opt.SM3,
            lr_schedule=lr.MultifactorSchedule,
            inputs=make_inputs,
        )

        # Reset and evaluate once per directory, in order.
        for target_dir in (first_dir, second_dir):
          trainer.reset(target_dir)
          trainer.evaluate(1)
コード例 #23
0
 def optimizer_params(self):
     """Returns self._lr_fn evaluated at the current step, on the numpy backend."""
     # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU)
     # in op-by-op mode just to compute the learning rate. However, there
     # should be a cleaner approach than forcibly swapping out the backend.
     with backend.use_backend("numpy"):
         learning_rate = self._lr_fn(self._step)
     return learning_rate
コード例 #24
0
  def test_communicates_with_model(self, mock_restore_state):
    """Checks the symbols fed to the predict function and the env's outputs.

    A mocked predict function emits a fixed symbol sequence. After reset and
    after each step the test inspects the most recent symbols passed to it, and
    checks the observations, rewards and dones returned by the environment.

    Args:
      mock_restore_state: Mock whose return_value.params is set below;
        presumably injected by a patch decorator not shown here.
    """
    gin.bind_parameter("BoxSpaceSerializer.precision", 1)
    vocab_size = 16
    # Mock model predicting a fixed sequence of symbols. It is made such that
    # the first two observations are different and the last one is equal to the
    # first.
    symbols = [
        1, 1, 2, 2, 0, 0,  # obs1 act1
        1, 2, 2, 1, 0, 0,  # obs2 act2
        1, 1, 2, 2,        # obs3
    ]
    def make_prediction(symbol):
      # Log-probabilities of 0 for `symbol` and -100 for every other symbol.
      one_hot = np.eye(vocab_size)[symbol]
      log_probs = (1 - one_hot) * -100.0  # Virtually deterministic.
      # NOTE(review): the comment previously here read "(4 obs symbols + 1
      # action symbol) * 3 timesteps = 15", which does not match the single
      # prediction returned below -- confirm and drop if stale.
      return np.array([[log_probs]])

    mock_predict_fn = mock.MagicMock()
    # Each call to the mocked predict function yields the next symbol.
    mock_predict_fn.side_effect = map(make_prediction, symbols)

    with backend.use_backend("numpy"):
      # (model_params, opt_state)
      mock_restore_state.return_value.params = (None, None)
      env = self._make_env(
          predict_fn=mock_predict_fn,
          reward_fn=(lambda _1, _2: np.array([0.5])),
          done_fn=(lambda _1, _2: np.array([False])),
          vocab_size=vocab_size,
          batch_size=1,
          max_trajectory_length=3,
          observation_space=gym.spaces.Box(low=0, high=5, shape=(4,)),
          action_space=gym.spaces.MultiDiscrete(nvec=[2, 2]),
      )

      def assert_input_suffix(expected_symbols):
        # Compares the first positional argument of the most recent
        # len(expected_symbols) predict calls against expected_symbols.
        actual_symbols = np.array([
            symbol.item() for ((symbol,), _) in mock_predict_fn.call_args_list[
                -len(expected_symbols):
            ]
        ])
        np.testing.assert_array_equal(actual_symbols, expected_symbols)

      actions = [[0, 1], [1, 0]]

      obs1 = env.reset()
      assert_input_suffix(symbols[:3])

      (obs2, reward, done, _) = env.step(np.array([actions[0]]))
      # Symbols going into the decoder when predicting the next observation are:
      # the last symbol of the previous observation, all action symbols, all
      # symbols but the last one of the next observation.
      assert_input_suffix([symbols[3]] + actions[0] + symbols[6:9])
      self.assertFalse(np.array_equal(obs1, obs2))
      np.testing.assert_array_equal(reward, [0.5])
      np.testing.assert_array_equal(done, [False])

      (obs3, reward, done, _) = env.step(np.array([actions[1]]))
      assert_input_suffix([symbols[9]] + actions[1] + symbols[12:15])
      np.testing.assert_array_equal(obs1, obs3)
      np.testing.assert_array_equal(reward, [0.5])
      # NOTE(review): done is expected to be True here although done_fn always
      # returns False; presumably the trajectory length limit is reached --
      # confirm against the environment implementation.
      np.testing.assert_array_equal(done, [True])