Esempio n. 1
0
    def test_train_restart(self, backend_name):
        """Runs training twice into the same output directory.

        The first `trainer_lib.train` call uses `train_steps`; the second,
        against the same `output_dir`, asks for twice as many. The state
        returned by the second call must report `2 * train_steps` steps.
        """
        if xla_bridge.device_count() > 1 and backend_name == 'tf':
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            n_classes = 4
            steps_per_run = 2
            n_eval_steps = 2
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)

            def inputs(_):
                return test_inputs(n_classes)

            def run_training(total_steps):
                # Both runs share output_dir, model, inputs and eval budget.
                return trainer_lib.train(output_dir,
                                         model=model_fn,
                                         inputs=inputs,
                                         train_steps=total_steps,
                                         eval_steps=n_eval_steps)

            run_training(steps_per_run)
            final_state = run_training(2 * steps_per_run)

            self.assertEqual(final_state.step, 2 * steps_per_run)
Esempio n. 2
0
    def _test_fast_inference(self, attention_type, length):
        """Checks that 'predict'-mode TransformerLM matches 'eval'-mode logits.

        Two copies of a tiny TransformerLM are built from the same partial: one
        with mode='eval', re-run on the whole symbol buffer each step, and one
        with mode='predict', fed one symbol per step. Both are initialized from
        the same rng. At every position, the fast model's single-step logits
        must be almost equal to the slow model's logits at that position.

        Args:
          attention_type: forwarded as `attention_type` to
            `transformer.TransformerLM`.
          length: number of positions to decode and compare.
        """
        with backend.use_backend('jax'):
            vocab_size = 16
            make_model = functools.partial(
                transformer.TransformerLM,
                vocab_size=vocab_size,
                d_model=4,
                d_ff=8,
                n_layers=2,
                n_heads=2,
                attention_type=attention_type,
            )
            slow = make_model(mode='eval')
            fast = make_model(mode='predict')
            rng = backend.random.get_prng(0)
            batch_size = 2
            # Initializing both models from the same rng gives identical params.
            for model in (slow, fast):
                model.initialize_once((batch_size, 1), np.int32, rng)

            # `prefix` accumulates every symbol fed so far (zero-padded);
            # `step_input` holds only the newest one.
            prefix = onp.zeros((batch_size, length), dtype=np.int32)
            step_input = onp.zeros((batch_size, 1), dtype=onp.int32)

            for pos in range(length):
                slow_logits = slow(prefix, rng=rng)
                fast_logits = fast(step_input, rng=rng)
                onp.testing.assert_array_almost_equal(slow_logits[:, pos, :],
                                                      fast_logits[:, 0, :])
                step_input = onp.random.randint(vocab_size,
                                                size=(batch_size, 1))
                prefix[:, pos] = step_input[:, 0]
Esempio n. 3
0
 def nontrainable_params(self):
     """Returns `self._lr_fn(self._step)`, the learning rate at the current step.

     The schedule is evaluated while the 'numpy' backend is active.
     """
     # TODO(afrozm): Give further thought to this name.
     # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU)
     # in op-by-op mode just to compute the learning rate. However, there
     # should be a cleaner approach than forcibly swapping out the backend.
     with backend.use_backend('numpy'):
         current_lr = self._lr_fn(self._step)
         return current_lr
Esempio n. 4
0
    def test_train_eval_predict_sm3(self, backend_name):
        """Trains an MLP with the SM3 optimizer, then predicts with its params.

        Checks the final step count and that the train and eval accuracy
        histories have equal length, with exactly 2 eval entries. Finally,
        smoke-tests feeding the trained params into a freshly built model.
        """
        if xla_bridge.device_count() > 1 and backend_name == 'tf':
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            n_classes = 4
            n_train_steps = 2
            n_eval_steps = 2
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)

            def inputs_fn(_):
                return test_inputs(n_classes)

            state = trainer_lib.train(output_dir,
                                      model=model_fn,
                                      inputs=inputs_fn,
                                      train_steps=n_train_steps,
                                      eval_steps=n_eval_steps,
                                      optimizer=trax_opt.SM3)

            self.assertEqual(n_train_steps, state.step)

            # Train/eval accuracy histories line up and hold 2 evaluations.
            train_acc = state.history.get('train', 'metrics/accuracy')
            eval_acc = state.history.get('eval', 'metrics/accuracy')
            self.assertEqual(len(train_acc), len(eval_acc))
            self.assertLen(eval_acc, 2)

            # Prediction smoke test with the final parameters.
            stream = inputs_fn(1).train_stream()
            predictor = layers.Serial(model_fn())
            first_batch = next(stream)
            predictor(first_batch[0], params=state.opt_state.params)
Esempio n. 5
0
    def test_reformer_rng_consistency(self):
        """Gradients through ReformerLM must be finite.

        The model is built with `PoisonOnRNGMismatchAttention`, which uses NaNs
        to signal an rng mismatch. A single non-finite gradient leaf therefore
        indicates such a mismatch.
        """
        with backend.use_backend('jax'):
            vocab_size = 16
            batch_size = 1
            input_sd = ShapeDtype((batch_size, 8), np.int32)
            input_signature = (input_sd, input_sd)
            model = reformer.ReformerLM(
                vocab_size,
                d_model=32,
                d_ff=64,
                d_attention_key=16,
                d_attention_value=16,
                n_layers=1,
                n_heads=2,
                max_len=16,
                n_chunks=2,
                n_attention_chunks=1,
                mode='train',
                attention_type=PoisonOnRNGMismatchAttention)

            rng = backend.random.get_prng(0)
            weights, state = model.initialize_once(input_signature)

            def dummy_loss_fn(weights):
                # Sum of the first model output on an all-zeros input pair.
                inputs = (np.zeros(input_sd.shape, dtype=np.int32), ) * 2
                output = model(inputs, weights=weights, state=state, rng=rng)
                dummy_loss = backend.numpy.sum(output[0])
                return dummy_loss

            grad_fn = backend.grad(dummy_loss_fn)
            grads = grad_fn(weights)
            # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
            # Use a unittest assertion: a bare `assert` is stripped under
            # `python -O`, which would make this check pass vacuously.
            for grad in jax.tree_util.tree_leaves(grads):
                self.assertTrue(
                    onp.all(onp.isfinite(grad)),
                    'Non-finite gradient (NaNs signal an rng mismatch).')
Esempio n. 6
0
  def test_no_int32_or_uint32_returned(self):
    """Checks that Trainer._jit_update_fn returns no int32/uint32 arrays.

    TF pins int32/uint32 tensors to CPU. Any such output would force an
    XLA-forced-compiled computation to copy it back to CPU, so none of the
    values produced by the update function may have those dtypes.
    """
    if xla_bridge.device_count() > 1:
      self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
    with backend.use_backend('tf'), self.tmp_dir() as output_dir:
      n_classes = 1001
      model_fn = functools.partial(models.Resnet50, n_output_classes=n_classes)

      def inputs(_):
        return test_inputs(n_classes, input_shape=(224, 224, 3))

      trainer = trainer_lib.Trainer(
          model=model_fn,
          loss_fn=layers.CrossEntropyLossScalar,
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          inputs=inputs,
      )
      trainer.reset(output_dir)
      trainer.train_epoch(1, 0)
      # The values Trainer._jit_update_fn hands back, flattened to leaves.
      jit_outputs = tf.nest.flatten((trainer._opt_state.weights,
                                     trainer._opt_state.slots,
                                     trainer._model_state,
                                     trainer._rngs))
      banned_dtypes = (np.int32, np.uint32)
      for leaf in jit_outputs:
        if not isinstance(leaf, np.ndarray):
          continue
        if any(leaf.dtype == dtype for dtype in banned_dtypes):
          raise ValueError('Found an array of int32 or uint32: %s' % leaf)
Esempio n. 7
0
 def test_fails_to_evaluate_model_with_matrix_observation_space(self):
   """`simple.evaluate_model` returns None for a (2, 2) Box observation space."""
   with backend.use_backend('numpy'):
     obs_space = gym.spaces.Box(shape=(2, 2), low=0, high=1)
     act_space = gym.spaces.Discrete(n=1)
     env = self._make_env(  # pylint: disable=no-value-for-parameter
         observation_space=obs_space,
         action_space=act_space,
         max_trajectory_length=2,
         batch_size=1,
     )
     only_trajectory = self._make_trajectory(
         np.array([[0, 1], [2, 3]]), np.array([0]))
     metrics = simple.evaluate_model(env, [only_trajectory], plt)
     self.assertIsNone(metrics)
    def test_takes_new_history(self):
        """A second reset() yields an observation from the second history.

        The expected observation is [5], the last element of [[3, 4, 5]].
        """
        histories = np.array([[[0, 1, 2]], [[3, 4, 5]]])

        with backend.use_backend('numpy'):
            env = self._create_env(  # pylint: disable=no-value-for-parameter
                model=mock.MagicMock(),
                histories=histories,
                trajectory_length=2,
            )
            env.reset()  # First reset; its observation is not checked.
            second_observation = env.reset()
            np.testing.assert_array_equal(second_observation, [5])
    def test_communicates_with_model(self):
        """The env calls the model with (history, action) and relays its output.

        The mock model adds the action to the last observation of the history
        and gives reward 1 when the result is even, 0 otherwise. The test
        compares the observations, rewards and dones seen by the caller, and
        the (history, action) pairs the model received, against expectations.
        """
        def mock_transition(inputs, *args, **kwargs):
            del args, kwargs
            observations, actions = inputs
            next_obs = observations[:, -1] + actions
            is_even = int(next_obs % 2 == 0)
            return (next_obs, np.array([[is_even]]))

        mock_model_fn = mock.MagicMock()
        mock_model = mock_model_fn.return_value
        mock_model.side_effect = mock_transition

        actions_to_take = np.array([[1], [3]])
        histories = np.array([[[0, 1, 2, 3]]])
        expected_observations = np.array([[3], [4], [7]])
        expected_rewards = np.array([[1], [0]])
        expected_dones = np.array([[False], [True]])
        expected_histories = np.array([[[0, 1, 2, 3]], [[1, 2, 3, 4]]])
        expected_actions = actions_to_take

        with backend.use_backend('numpy'):
            env = self._create_env(  # pylint: disable=no-value-for-parameter
                model=mock_model_fn,
                histories=histories,
                trajectory_length=len(actions_to_take),
            )
            seen_observations = [env.reset()]
            seen_rewards, seen_dones = [], []
            seen_histories, seen_actions = [], []
            for action in actions_to_take:
                observation, reward, done, _ = env.step(action)
                seen_observations.append(observation)
                seen_rewards.append(reward)
                seen_dones.append(done)
                # call_args is (args, kwargs); the single positional argument
                # is the (history, action) tuple the model was called with.
                (((called_history, called_action), ), _) = mock_model.call_args
                seen_actions.append(called_action)
                seen_histories.append(called_history)

        np.testing.assert_array_equal(seen_observations,
                                      expected_observations)
        np.testing.assert_array_equal(seen_rewards, expected_rewards)
        np.testing.assert_array_equal(seen_dones, expected_dones)
        np.testing.assert_array_equal(seen_histories, expected_histories)
        np.testing.assert_array_equal(seen_actions, expected_actions)
Esempio n. 10
0
 def test_evaluates_model_with_vector_observation_space(self):
   """`simple.evaluate_model` yields 2 metrics for a (2,) Box observation space."""
   with backend.use_backend('numpy'):
     env = self._make_env(  # pylint: disable=no-value-for-parameter
         observation_space=gym.spaces.Box(shape=(2,), low=0, high=1),
         action_space=gym.spaces.Discrete(n=1),
         max_trajectory_length=2,
         batch_size=3,
     )
     # Three trajectories of lengths 1, 2 and 3.
     raw_trajectories = (
         (np.array([[0, 1]]), np.array([0])),
         (np.array([[1, 2], [3, 4]]), np.array([0, 0])),
         (np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 0, 0])),
     )
     trajectories = []
     for observations, actions in raw_trajectories:
       trajectories.append(self._make_trajectory(observations, actions))
     metrics = simple.evaluate_model(env, trajectories, plt)
     self.assertIsNotNone(metrics)
     self.assertEqual(len(metrics), 2)
Esempio n. 11
0
    def _test_train_eval_predict(self, backend_name):
        """Trains a Dropout/BatchNorm/MLP stack, then predicts with its params.

        Checks the final step count and that the train and eval accuracy
        histories have equal length, with exactly 2 eval entries. Finally,
        smoke-tests feeding the trained params into a freshly built model.

        Args:
          backend_name: name passed to `backend.use_backend`; 'tf' is skipped
            when more than one device is present.
        """
        if xla_bridge.device_count() > 1 and backend_name == 'tf':
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Prepare model and inputs
            n_classes = 4
            train_steps = 2
            eval_steps = 2

            # Adds Dropout and BatchNorm to test state handling.
            def model_fn(mode='train'):
                return layers.Serial(
                    layers.Dropout(mode=mode, rate=0.1),
                    layers.BatchNorm(mode=mode),
                    models.MLP(d_hidden=16,
                               n_output_classes=n_classes,
                               mode=mode))

            inputs = lambda _: test_inputs(n_classes)

            # Train and evaluate
            state = trainer_lib.train(output_dir,
                                      model=model_fn,
                                      inputs=inputs,
                                      train_steps=train_steps,
                                      eval_steps=eval_steps)

            # Assert total train steps
            self.assertEqual(train_steps, state.step)

            # Assert 2 evaluations ran
            train_acc = state.history.get('train', 'metrics/accuracy')
            eval_acc = state.history.get('eval', 'metrics/accuracy')
            self.assertEqual(len(train_acc), len(eval_acc))
            self.assertLen(eval_acc, 2)

            # Predict with final params
            inputs = inputs(1).train_stream()
            model = layers.Serial(model_fn())
            model(next(inputs)[0], params=state.opt_state.params)
Esempio n. 12
0
  def test_reset_twice(self, backend_name):
    """One Trainer can be reset into two output dirs, evaluating after each."""
    if xla_bridge.device_count() > 1 and backend_name == 'tf':
      self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
    with backend.use_backend(backend_name):
      with self.tmp_dir() as output_dir1, self.tmp_dir() as output_dir2:
        n_classes = 4

        def inputs(_):
          return test_inputs(n_classes)

        trainer = trainer_lib.Trainer(
            model=functools.partial(
                models.MLP, d_hidden=16, n_output_classes=n_classes),
            loss_fn=layers.CrossEntropyLossScalar,
            optimizer=trax_opt.SM3,
            lr_schedule=lr.MultifactorSchedule,
            inputs=inputs,
        )

        for output_dir in (output_dir1, output_dir2):
          trainer.reset(output_dir)
          trainer.evaluate(1)
    def test_communicates_with_model(self, mock_restore_state):
        """The env decodes observations from a mocked predict function.

        `predict_fn` is mocked to emit a fixed symbol sequence. The test checks
        which symbols are fed back into it on reset/step, and that the decoded
        observations, rewards and dones match expectations.

        Args:
          mock_restore_state: presumably injected by a `mock.patch` decorator
            (not visible here) — confirm. Its `return_value.params` is set
            below.
        """
        gin.bind_parameter('BoxSpaceSerializer.precision', 1)
        vocab_size = 16
        # Mock model predicting a fixed sequence of symbols. It is made such that
        # the first two observations are different and the last one is equal to
        # the first.
        # Layout: 3 observations x 4 symbols + 2 actions x 2 symbols = 16
        # symbols (the action space is MultiDiscrete with 2 entries). The
        # assertions below expect the actions passed to env.step, not these
        # action-slot values, to be fed back into predict_fn.
        symbols = [
            1,
            1,
            2,
            2,
            0,
            0,  # obs1 act1
            1,
            2,
            2,
            1,
            0,
            0,  # obs2 act2
            1,
            1,
            2,
            2,  # obs3
        ]

        def make_prediction(symbol):
            # Log-probabilities putting essentially all mass on `symbol`.
            one_hot = np.eye(vocab_size)[symbol]
            log_probs = (1 - one_hot) * -100.0  # Virtually deterministic.
            return np.array([[log_probs]])

        mock_predict_fn = mock.MagicMock()
        mock_predict_fn.side_effect = map(make_prediction, symbols)

        with backend.use_backend('numpy'):
            # (model_params, opt_state)
            mock_restore_state.return_value.params = (None, None)
            env = self._make_env(
                predict_fn=mock_predict_fn,
                reward_fn=(lambda _1, _2: np.array([0.5])),
                done_fn=(lambda _1, _2: np.array([False])),
                vocab_size=vocab_size,
                batch_size=1,
                max_trajectory_length=3,
                observation_space=gym.spaces.Box(low=0, high=5, shape=(4, )),
                action_space=gym.spaces.MultiDiscrete(nvec=[2, 2]),
            )

            def assert_input_suffix(expected_symbols):
                # Compares the single positional argument of the last
                # len(expected_symbols) predict_fn calls.
                actual_symbols = np.array([
                    symbol.item() for ((symbol, ), _) in
                    mock_predict_fn.call_args_list[-len(expected_symbols):]
                ])
                np.testing.assert_array_equal(actual_symbols, expected_symbols)

            actions = [[0, 1], [1, 0]]

            obs1 = env.reset()
            assert_input_suffix(symbols[:3])

            (obs2, reward, done, _) = env.step(np.array([actions[0]]))
            # Symbols going into the decoder when predicting the next observation are:
            # the last symbol of the previous observation, all action symbols, all
            # symbols but the last one of the next observation.
            assert_input_suffix([symbols[3]] + actions[0] + symbols[6:9])
            self.assertFalse(np.array_equal(obs1, obs2))
            np.testing.assert_array_equal(reward, [0.5])
            np.testing.assert_array_equal(done, [False])

            (obs3, reward, done, _) = env.step(np.array([actions[1]]))
            assert_input_suffix([symbols[9]] + actions[1] + symbols[12:15])
            np.testing.assert_array_equal(obs1, obs3)
            np.testing.assert_array_equal(reward, [0.5])
            # done_fn always returns False, so this True presumably comes from
            # reaching max_trajectory_length=3 — confirm against the env.
            np.testing.assert_array_equal(done, [True])