def test_train_restart(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs
    n_classes = 4
    train_steps = 2
    eval_steps = 2
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)
    inputs = lambda _: test_inputs(n_classes)

    # Train and evaluate
    trainer_lib.train(output_dir,
                      model=model_fn,
                      inputs=inputs,
                      train_steps=train_steps,
                      eval_steps=eval_steps)

    # Restart training
    state = trainer_lib.train(output_dir,
                              model=model_fn,
                              inputs=inputs,
                              train_steps=(2 * train_steps),
                              eval_steps=eval_steps)

    # Assert total train steps
    self.assertEqual(state.step, 2 * train_steps)

def _test_fast_inference(self, attention_type, length):
  with backend.use_backend('jax'):
    vocab_size = 16
    model_fn = functools.partial(
        transformer.TransformerLM,
        vocab_size=vocab_size, d_model=4, d_ff=8,
        n_layers=2, n_heads=2, attention_type=attention_type,
    )
    model_slow = model_fn(mode='eval')
    model_fast = model_fn(mode='predict')
    rng = backend.random.get_prng(0)
    batch_size = 2
    # Given the same rng, both models initialize with the same parameters.
    model_slow.initialize_once((batch_size, 1), np.int32, rng)
    model_fast.initialize_once((batch_size, 1), np.int32, rng)

    buf = onp.zeros((batch_size, length), dtype=np.int32)
    next_sym = onp.zeros((batch_size, 1), dtype=onp.int32)

    for index in range(length):
      logits_slow = model_slow(buf, rng=rng)
      logits_fast = model_fast(next_sym, rng=rng)
      onp.testing.assert_array_almost_equal(logits_slow[:, index, :],
                                            logits_fast[:, 0, :])
      next_sym = onp.random.randint(vocab_size, size=(batch_size, 1))
      buf[:, index] = next_sym[:, 0]

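# Usage sketch (an assumption, not part of the original file): the helper above
# is parameterized by an attention layer and a decode length, so a concrete test
# would look roughly like this. `layers.DotProductCausalAttention` stands in for
# whatever causal-attention class this codebase exposes.
def test_dot_product_causal_attention_fast_inference(self):
  self._test_fast_inference(
      attention_type=layers.DotProductCausalAttention, length=5)
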
def nontrainable_params(self):
  # TODO(afrozm): Give further thought to this name.
  # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU)
  # in op-by-op mode just to compute the learning rate. However, there
  # should be a cleaner approach than forcibly swapping out the backend.
  with backend.use_backend('numpy'):
    return self._lr_fn(self._step)

def test_train_eval_predict_sm3(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs
    n_classes = 4
    train_steps = 2
    eval_steps = 2
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)
    inputs = lambda _: test_inputs(n_classes)

    # Train and evaluate
    state = trainer_lib.train(output_dir,
                              model=model_fn,
                              inputs=inputs,
                              train_steps=train_steps,
                              eval_steps=eval_steps,
                              optimizer=trax_opt.SM3)

    # Assert total train steps
    self.assertEqual(train_steps, state.step)

    # Assert 2 evaluations ran
    train_acc = state.history.get('train', 'metrics/accuracy')
    eval_acc = state.history.get('eval', 'metrics/accuracy')
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # Predict with final params
    inputs = inputs(1).train_stream()
    model = layers.Serial(model_fn())
    model(next(inputs)[0], params=state.opt_state.params)

def test_reformer_rng_consistency(self):
  with backend.use_backend('jax'):
    vocab_size = 16
    batch_size = 1
    input_sd = ShapeDtype((batch_size, 8), np.int32)
    input_signature = (input_sd, input_sd)
    model = reformer.ReformerLM(
        vocab_size, d_model=32, d_ff=64, d_attention_key=16,
        d_attention_value=16, n_layers=1, n_heads=2, max_len=16,
        n_chunks=2, n_attention_chunks=1, mode='train',
        attention_type=PoisonOnRNGMismatchAttention)
    rng = backend.random.get_prng(0)
    weights, state = model.initialize_once(input_signature)

    def dummy_loss_fn(weights):
      inputs = (np.zeros(input_sd.shape, dtype=np.int32),) * 2
      output = model(inputs, weights=weights, state=state, rng=rng)
      dummy_loss = backend.numpy.sum(output[0])
      return dummy_loss

    grad_fn = backend.grad(dummy_loss_fn)
    grads = grad_fn(weights)

    # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
    for grad in jax.tree_util.tree_leaves(grads):
      assert onp.all(onp.isfinite(grad))

def test_no_int32_or_uint32_returned(self):
  """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

  TF pins int32/uint32 tensors to CPU, which will cause XLA-forced-compiled
  computation to copy int32/uint32 outputs to CPU. This test makes sure that
  won't happen.
  """
  if xla_bridge.device_count() > 1:
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend('tf'), self.tmp_dir() as output_dir:
    n_classes = 1001
    model_fn = functools.partial(
        models.Resnet50, n_output_classes=n_classes)
    inputs = lambda _: test_inputs(n_classes, input_shape=(224, 224, 3))
    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=layers.CrossEntropyLossScalar,
        optimizer=trax_opt.SM3,
        lr_schedule=lr.MultifactorSchedule,
        inputs=inputs,
    )
    trainer.reset(output_dir)
    trainer.train_epoch(1, 0)
    # Those are the things returned by Trainer._jit_update_fn
    arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
              trainer._model_state, trainer._rngs)
    arrays = tf.nest.flatten(arrays)
    for x in arrays:
      if isinstance(x, np.ndarray) and (x.dtype == np.int32 or
                                        x.dtype == np.uint32):
        raise ValueError('Found an array of int32 or uint32: %s' % x)

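# Refactoring sketch (not in the original test): the dtype check above could be
# factored into a small helper that scans any nested structure of arrays for
# 32-bit integer leaves before they reach an XLA-compiled TF computation.
def _assert_no_int32_or_uint32(self, arrays):
  for x in tf.nest.flatten(arrays):
    if isinstance(x, np.ndarray) and x.dtype in (np.int32, np.uint32):
      raise ValueError('Found an array of int32 or uint32: %s' % x)
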
def test_fails_to_evaluate_model_with_matrix_observation_space(self):
  with backend.use_backend('numpy'):
    env = self._make_env(  # pylint: disable=no-value-for-parameter
        observation_space=gym.spaces.Box(shape=(2, 2), low=0, high=1),
        action_space=gym.spaces.Discrete(n=1),
        max_trajectory_length=2,
        batch_size=1,
    )
    trajectories = [
        self._make_trajectory(np.array([[0, 1], [2, 3]]), np.array([0]))]
    metrics = simple.evaluate_model(env, trajectories, plt)
    self.assertIsNone(metrics)

def test_takes_new_history(self):
  histories = np.array([[[0, 1, 2]], [[3, 4, 5]]])

  with backend.use_backend('numpy'):
    env = self._create_env(  # pylint: disable=no-value-for-parameter
        model=mock.MagicMock(),
        histories=histories,
        trajectory_length=2,
    )
    env.reset()
    observation = env.reset()
    np.testing.assert_array_equal(observation, [5])

def test_communicates_with_model(self):
  # Mock model increasing the observation by action, reward is the parity of
  # the new observation.
  def mock_transition(inputs, *args, **kwargs):
    del args
    del kwargs
    (observations, actions) = inputs
    new_observations = observations[:, -1] + actions
    rewards = np.array([[int(new_observations % 2 == 0)]])
    return (new_observations, rewards)

  mock_model_fn = mock.MagicMock()
  mock_model_fn.return_value.side_effect = mock_transition
  mock_model = mock_model_fn.return_value

  actions_to_take = np.array([[1], [3]])
  histories = np.array([[[0, 1, 2, 3]]])
  expected_observations = np.array([[3], [4], [7]])
  expected_rewards = np.array([[1], [0]])
  expected_dones = np.array([[False], [True]])
  expected_histories = np.array([[[0, 1, 2, 3]], [[1, 2, 3, 4]]])
  expected_actions = actions_to_take

  with backend.use_backend('numpy'):
    env = self._create_env(  # pylint: disable=no-value-for-parameter
        model=mock_model_fn,
        histories=histories,
        trajectory_length=len(actions_to_take),
    )
    actual_observations = [env.reset()]
    actual_rewards = []
    actual_dones = []
    actual_histories = []
    actual_actions = []
    for action in actions_to_take:
      (observation, reward, done, _) = env.step(action)
      actual_observations.append(observation)
      actual_rewards.append(reward)
      actual_dones.append(done)
      # Mock call is a tuple (args, kwargs). There is one positional
      # argument, which is a tuple (history, action).
      (((history, action),), _) = mock_model.call_args
      actual_actions.append(action)
      actual_histories.append(history)

  np.testing.assert_array_equal(actual_observations, expected_observations)
  np.testing.assert_array_equal(actual_rewards, expected_rewards)
  np.testing.assert_array_equal(actual_dones, expected_dones)
  np.testing.assert_array_equal(actual_histories, expected_histories)
  np.testing.assert_array_equal(actual_actions, expected_actions)

def test_evaluates_model_with_vector_observation_space(self):
  with backend.use_backend('numpy'):
    env = self._make_env(  # pylint: disable=no-value-for-parameter
        observation_space=gym.spaces.Box(shape=(2,), low=0, high=1),
        action_space=gym.spaces.Discrete(n=1),
        max_trajectory_length=2,
        batch_size=3,
    )
    trajectories = [
        self._make_trajectory(observations, actions)  # pylint: disable=g-complex-comprehension
        for (observations, actions) in [
            (np.array([[0, 1]]), np.array([0])),
            (np.array([[1, 2], [3, 4]]), np.array([0, 0])),
            (np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 0, 0])),
        ]
    ]
    metrics = simple.evaluate_model(env, trajectories, plt)
    self.assertIsNotNone(metrics)
    self.assertEqual(len(metrics), 2)

def _test_train_eval_predict(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs
    n_classes = 4
    train_steps = 2
    eval_steps = 2

    # Adds Dropout and BatchNorm to test state handling.
    def model_fn(mode='train'):
      return layers.Serial(
          layers.Dropout(mode=mode, rate=0.1),
          layers.BatchNorm(mode=mode),
          models.MLP(d_hidden=16, n_output_classes=n_classes, mode=mode))

    inputs = lambda _: test_inputs(n_classes)

    # Train and evaluate
    state = trainer_lib.train(output_dir,
                              model=model_fn,
                              inputs=inputs,
                              train_steps=train_steps,
                              eval_steps=eval_steps)

    # Assert total train steps
    self.assertEqual(train_steps, state.step)

    # Assert 2 evaluations ran
    train_acc = state.history.get('train', 'metrics/accuracy')
    eval_acc = state.history.get('eval', 'metrics/accuracy')
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # Predict with final params
    inputs = inputs(1).train_stream()
    model = layers.Serial(model_fn())
    model(next(inputs)[0], params=state.opt_state.params)

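# Usage sketch (an assumption, not from the original file): since
# `_test_train_eval_predict` is a private helper, per-backend test methods
# would typically wrap it, e.g. one for the JAX backend and one for tf-numpy.
def test_train_eval_predict(self):
  self._test_train_eval_predict('jax')

def test_train_eval_predict_tf(self):
  self._test_train_eval_predict('tf')
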
def test_reset_twice(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir1, \
      self.tmp_dir() as output_dir2:
    n_classes = 4
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)
    inputs = lambda _: test_inputs(n_classes)

    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=layers.CrossEntropyLossScalar,
        optimizer=trax_opt.SM3,
        lr_schedule=lr.MultifactorSchedule,
        inputs=inputs,
    )

    trainer.reset(output_dir1)
    trainer.evaluate(1)
    trainer.reset(output_dir2)
    trainer.evaluate(1)

def test_communicates_with_model(self, mock_restore_state):
  gin.bind_parameter('BoxSpaceSerializer.precision', 1)
  vocab_size = 16
  # Mock model predicting a fixed sequence of symbols. It is made such that
  # the first two observations are different and the last one is equal to
  # the first.
  symbols = [
      1, 1, 2, 2, 0, 0,  # obs1 act1
      1, 2, 2, 1, 0, 0,  # obs2 act2
      1, 1, 2, 2,  # obs3
  ]

  def make_prediction(symbol):
    one_hot = np.eye(vocab_size)[symbol]
    log_probs = (1 - one_hot) * -100.0  # Virtually deterministic.
    # (4 obs symbols + 1 action symbol) * 3 timesteps = 15.
    return np.array([[log_probs]])

  mock_predict_fn = mock.MagicMock()
  mock_predict_fn.side_effect = map(make_prediction, symbols)

  with backend.use_backend('numpy'):
    # (model_params, opt_state)
    mock_restore_state.return_value.params = (None, None)
    env = self._make_env(
        predict_fn=mock_predict_fn,
        reward_fn=(lambda _1, _2: np.array([0.5])),
        done_fn=(lambda _1, _2: np.array([False])),
        vocab_size=vocab_size,
        batch_size=1,
        max_trajectory_length=3,
        observation_space=gym.spaces.Box(low=0, high=5, shape=(4,)),
        action_space=gym.spaces.MultiDiscrete(nvec=[2, 2]),
    )

    def assert_input_suffix(expected_symbols):
      actual_symbols = np.array([
          symbol.item() for ((symbol,), _) in
          mock_predict_fn.call_args_list[-len(expected_symbols):]
      ])
      np.testing.assert_array_equal(actual_symbols, expected_symbols)

    actions = [[0, 1], [1, 0]]

    obs1 = env.reset()
    assert_input_suffix(symbols[:3])

    (obs2, reward, done, _) = env.step(np.array([actions[0]]))
    # Symbols going into the decoder when predicting the next observation
    # are: the last symbol of the previous observation, all action symbols,
    # all symbols but the last one of the next observation.
    assert_input_suffix([symbols[3]] + actions[0] + symbols[6:9])
    self.assertFalse(np.array_equal(obs1, obs2))
    np.testing.assert_array_equal(reward, [0.5])
    np.testing.assert_array_equal(done, [False])

    (obs3, reward, done, _) = env.step(np.array([actions[1]]))
    assert_input_suffix([symbols[9]] + actions[1] + symbols[12:15])
    np.testing.assert_array_equal(obs1, obs3)
    np.testing.assert_array_equal(reward, [0.5])
    np.testing.assert_array_equal(done, [True])