def test_computes_basic_mean(self):
  """masked_mean of [1, 2, 3] with a scalar weight of 1 is expected to be 2."""
  batch_inputs = [np.array([1, 2, 3])]
  batch_targets = [np.zeros(3)]
  batch_weights = [1]
  with backend.use_backend("numpy"):
    np.testing.assert_allclose(
        trax.masked_mean(batch_inputs, batch_targets, batch_weights), 2)
def test_computes_mean_with_weights(self, backend_name):
  """Weights [3, 1, 0] over [1, 2, 3] are expected to give 5 / 4 = 1.25."""
  with backend.use_backend(backend_name):
    batch_inputs = [np.array([1, 2, 3])]
    batch_targets = [np.zeros(3)]
    batch_weights = [np.array([3, 1, 0])]
    result = trax.masked_mean(batch_inputs, batch_targets, batch_weights)
    onp.testing.assert_allclose(result, 1.25)
def test_train_with_weights(self, backend_name):
  """Trains a small MLP on weighted inputs and checks the final step count."""
  if jax.lib.xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest(
        "tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    gin.bind_parameter("unpack_batch.has_weights", True)

    # Model and input pipeline setup.
    n_classes = 4
    train_steps = 2
    eval_steps = 2
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)

    def make_inputs(_):
      return test_inputs(n_classes, with_weights=True)

    # Run training (with evaluation) end to end.
    final_state = trax.train(
        output_dir,
        model=model_fn,
        inputs=make_inputs,
        train_steps=train_steps,
        eval_steps=eval_steps)

    # Every requested train step must have been taken.
    self.assertEqual(final_state.step, train_steps)
def test_computes_mean_with_mask(self):
  """With mask_id=1, expects 2.5: the mean of the entries whose target is 0."""
  batch_inputs = [np.array([1, 2, 3])]
  batch_targets = [np.array([1, 0, 0])]
  batch_weights = [1]
  with backend.use_backend("numpy"):
    result = trax.masked_mean(
        batch_inputs, batch_targets, batch_weights, mask_id=1)
    np.testing.assert_allclose(result, 2.5)
def test_reformer_rng_consistency(self):
  """Gradients through a ReformerLM must all be finite.

  PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch, so any
  non-finite gradient leaf indicates that the rngs were inconsistent.
  """
  with backend.use_backend('jax'):
    vocab_size = 16
    batch_size = 1
    input_shape = ((batch_size, 8), (batch_size, 8))
    model = reformer.ReformerLM(
        vocab_size, d_model=32, d_ff=64,
        d_attention_key=16, d_attention_value=16, n_layers=1, n_heads=2,
        max_len=16, n_chunks=2, n_attention_chunks=1, mode='train',
        attention_type=PoisonOnRNGMismatchAttention)

    rng = backend.random.get_prng(0)
    params, state = model.initialize_once(
        input_shape, (np.int32, np.int32), rng)

    def dummy_loss_fn(params):
      inputs = (np.zeros(input_shape[0], dtype=np.int32),) * 2
      output = model(inputs, params=params, state=state, rng=rng)
      dummy_loss = backend.numpy.sum(output[0])
      return dummy_loss

    grad_fn = backend.grad(dummy_loss_fn)
    grads = grad_fn(params)
    # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
    # Use a TestCase assertion rather than a bare `assert`, which is stripped
    # when Python runs with -O and would make this test a silent no-op.
    for grad in jax.tree_util.tree_leaves(grads):
      self.assertTrue(onp.all(onp.isfinite(grad)))
def test_train_eval_predict_sm3(self, backend_name):
  """Trains/evaluates an MLP with the SM3 optimizer, then runs a prediction."""
  if xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest(
        "tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Model and input pipeline setup.
    n_classes = 4
    train_steps = 2
    eval_steps = 2
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)

    def make_inputs(_):
      return test_inputs(n_classes)

    # Run training (with evaluation) using SM3.
    state = trax.train(
        output_dir,
        model=model_fn,
        inputs=make_inputs,
        train_steps=train_steps,
        eval_steps=eval_steps,
        optimizer=trax_opt.SM3)

    # Every requested train step must have been taken.
    self.assertEqual(train_steps, state.step)

    # Train and eval accuracy must both have been recorded twice.
    history = state.history
    train_acc = history.get("train", "metrics/accuracy")
    eval_acc = history.get("eval", "metrics/accuracy")
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # The final params must be usable for a forward pass.
    stream = make_inputs(1).train_stream()
    predictor = layers.Serial(model_fn())
    predictor(next(stream)[0], params=state.opt_state.params)
def test_train_restart(self, backend_name):
  """A second trax.train on the same output_dir should end at 2*train_steps."""
  if xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest(
        "tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Model and input pipeline setup.
    n_classes = 4
    train_steps = 2
    eval_steps = 2
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)

    def make_inputs(_):
      return test_inputs(n_classes)

    shared_kwargs = dict(
        model=model_fn, inputs=make_inputs, eval_steps=eval_steps)
    total_steps = 2 * train_steps

    # First run, then restart training from the same directory.
    trax.train(output_dir, train_steps=train_steps, **shared_kwargs)
    state = trax.train(output_dir, train_steps=total_steps, **shared_kwargs)

    self.assertEqual(state.step, total_steps)
def pseudo_call(self, pseudo_inputs, params):
  """Computes shapes and types this layer would produce for the given inputs.

  The layer's `call` is traced under the 'jax' backend by `_eval_on_shapes`,
  with abstract ShapeType stand-ins for the params and the rng.

  Args:
    pseudo_inputs: A ShapeType instance (input data minus the actual values)
        or a tuple of ShapeType instances, following the same conventions as
        Layer.call's input arg.
    params: Parameters for this layer.

  Returns:
    A ShapeType instance representing the shape and type of the output (if
    this layer has one output) or a tuple of ShapeType instances (if this layer
    has more than one output).

  Raises:
    LayerError: wrapping any exception raised while evaluating shapes.
  """
  try:
    with backend.use_backend('jax'):
      # Beware: using an actual RNG (as opposed to this ShapeType stub) would
      # cause a large number of dropout masks to be computed and permanently
      # stored in global memory.
      rng = ShapeType(shape=(2, ), dtype=onp.uint32)
      # rng is an explicit argument so that _eval_on_shapes receives it as an
      # abstract input alongside the inputs and params.
      def call_on_input(x, params, rng):
        return self.call(x, params=params, rng=rng)
      # Replace each parameter array with a ShapeType carrying only its
      # shape and dtype.
      params_shapes = nested_map(
          params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
      s = _eval_on_shapes(call_on_input, pseudo_inputs, params_shapes, rng)
    return s
  except Exception:
    # NOTE(review): skip=3 presumably trims framework frames from the reported
    # traceback; changing the call depth of this method may change which
    # frames are shown.
    name, trace = self.__class__.__name__, _short_traceback(skip=3)
    raise LayerError(name, 'pseudo_call', self._caller, pseudo_inputs, trace)
def output_shape(self, input_shape_and_type, params):
  """Output shape and type for this layer given input shape and type.

  Note that all arguments and return values can be tuples or dictionaries
  or arbitrary nested structures composed of tuples and dictionaries.

  Args:
    input_shape_and_type: a ShapeType with shape and type of the input.
    params: parameters for this layer.

  Returns:
    The shape and type of the output.

  Raises:
    LayerError: wrapping any exception raised while evaluating shapes.
  """
  try:
    with backend.use_backend('jax'):
      # A concrete PRNG key built under the 'jax' backend.
      rng = backend.random.get_prng(0)
      def call_on_input(x, params):
        f = lambda y: self.call(y, params=params, rng=rng)
        # For list/tuple inputs, only the first stack_items_to_pass() items
        # are handed to self.call; otherwise n = 0.
        n = self.stack_items_to_pass() if isinstance(x, (list, tuple)) else 0
        return _apply_to_first_n(f, x, n)
      # NOTE(review): this passes `tp=` while sibling shape helpers construct
      # ShapeType with `dtype=`; confirm which keyword this ShapeType accepts.
      params_shapes = nested_map(
          params, lambda x: ShapeType(shape=x.shape, tp=x.dtype))
      s = _eval_on_shapes(call_on_input, input_shape_and_type, params_shapes)
    return s
  except Exception:
    # NOTE(review): skip=3 presumably trims framework frames from the reported
    # traceback; changing the call depth of this method may change which
    # frames are shown.
    name, trace = self.__class__.__name__, _short_traceback(skip=3)
    raise LayerError(name, 'output_shape', self._caller,
                     input_shape_and_type, trace)
def test_transformer_lm_fast_inference(self):
  """Predict-mode decoding must match eval-mode logits position by position."""
  with backend.use_backend('jax'):
    vocab_size = 16
    make_model = functools.partial(
        transformer.TransformerLM, vocab_size=vocab_size, d_model=4, d_ff=8,
        n_layers=2, n_heads=2)
    eval_model = make_model(mode='eval')
    predict_model = make_model(mode='predict')
    rng = backend.random.get_prng(0)
    batch_size = 2
    init_shape = (batch_size, 1)
    _unused_params, _unused_state = eval_model.initialize_once(
        init_shape, np.int32, rng)
    _unused_params, _unused_state = predict_model.initialize_once(
        init_shape, np.int32, rng)

    max_length = 5
    tokens = onp.zeros((batch_size, max_length), dtype=np.int32)
    last_token = onp.zeros((batch_size, 1), dtype=onp.int32)

    for position in range(max_length):
      full_logits = eval_model(tokens, rng=rng)
      step_logits = predict_model(last_token, rng=rng)
      onp.testing.assert_array_almost_equal(full_logits[:, position, :],
                                            step_logits[:, 0, :])
      # Feed a random next token to both decoding paths.
      last_token = onp.random.randint(vocab_size, size=(batch_size, 1))
      tokens[:, position] = last_token[:, 0]
def test_computes_mean_with_weights_and_mask(self, backend_name):
  """Expects (4*2 + 1*4) / (4 + 1) = 2.4 with the target-1 entry masked."""
  with backend.use_backend(backend_name):
    batch_inputs = [np.array([1, 2, 4])]
    batch_targets = [np.array([1, 0, 0])]
    batch_weights = [np.array([10, 4, 1])]
    result = trax.masked_mean(
        batch_inputs, batch_targets, batch_weights, mask_id=1)
    onp.testing.assert_allclose(result, 2.4)
def _test_fast_inference(self, attention_type, length):
  """Checks predict-mode logits against eval-mode logits for `length` steps."""
  with backend.use_backend('jax'):
    vocab_size = 16
    make_model = functools.partial(
        transformer.TransformerLM,
        vocab_size=vocab_size,
        d_model=4,
        d_ff=8,
        n_layers=2,
        n_heads=2,
        attention_type=attention_type,
    )
    eval_model = make_model(mode='eval')
    predict_model = make_model(mode='predict')
    rng = backend.random.get_prng(0)
    batch_size = 2
    init_shape = (batch_size, 1)
    # Given the same rng, both models initialize with the same parameters.
    eval_model.initialize_once(init_shape, np.int32, rng)
    predict_model.initialize_once(init_shape, np.int32, rng)

    tokens = onp.zeros((batch_size, length), dtype=np.int32)
    last_token = onp.zeros((batch_size, 1), dtype=onp.int32)

    for position in range(length):
      full_logits = eval_model(tokens, rng=rng)
      step_logits = predict_model(last_token, rng=rng)
      onp.testing.assert_array_almost_equal(full_logits[:, position, :],
                                            step_logits[:, 0, :])
      # Feed a random next token to both decoding paths.
      last_token = onp.random.randint(vocab_size, size=(batch_size, 1))
      tokens[:, position] = last_token[:, 0]
def pseudo_call(self, pseudo_input, params):
  """Computes what shapes and types this layer would produce for given input.

  The layer's `call` is evaluated on shapes under the 'jax' backend by
  `_eval_on_shapes`, with params replaced by ShapeType stand-ins.

  Args:
    pseudo_input: A ShapeType instance (input data minus the actual values)
        or a tuple of ShapeType instances.
    params: Parameters for this layer.

  Returns:
    A ShapeType instance representing the shape and type of the output (if
    this layer has one output) or a tuple of ShapeType instances (if this layer
    has more than one output).

  Raises:
    LayerError: wrapping any exception raised while evaluating shapes.
  """
  try:
    with backend.use_backend('jax'):
      # A concrete PRNG key built under the 'jax' backend.
      rng = backend.random.get_prng(0)
      def call_on_input(x, params):
        f = lambda y: self.call(y, params=params, rng=rng)
        # For list/tuple inputs, only the first stack_items_to_pass() items
        # are handed to self.call; otherwise n = 0.
        n = self.stack_items_to_pass() if isinstance(
            x, (list, tuple)) else 0
        return _apply_to_first_n(f, x, n)
      # Replace each parameter array with a ShapeType carrying only its
      # shape and dtype.
      params_shapes = nested_map(
          params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
      s = _eval_on_shapes(call_on_input, pseudo_input, params_shapes)
    return s
  except Exception:
    # NOTE(review): skip=3 presumably trims framework frames from the reported
    # traceback; changing the call depth of this method may change which
    # frames are shown.
    name, trace = self.__class__.__name__, _short_traceback(skip=3)
    raise LayerError(name, 'pseudo_call', self._caller, pseudo_input, trace)
def pseudo_call(self, pseudo_inputs, params):
  """Computes shapes and types this layer would produce for the given inputs.

  The layer's `call` is evaluated on shapes under the 'jax' backend by
  `_eval_on_shapes`, with params replaced by ShapeType stand-ins.

  Args:
    pseudo_inputs: A ShapeType instance (input data minus the actual values)
        or a tuple of ShapeType instances, following the same conventions as
        Layer.call's input arg.
    params: Parameters for this layer.

  Returns:
    A ShapeType instance representing the shape and type of the output (if
    this layer has one output) or a tuple of ShapeType instances (if this layer
    has more than one output).

  Raises:
    LayerError: wrapping any exception raised while evaluating shapes.
  """
  try:
    with backend.use_backend('jax'):
      # Same as backend.random.get_prng(0), but no op-by-op execution.
      rng = onp.zeros(2, onp.uint32)
      def call_on_input(x, params):
        return self.call(x, params=params, rng=rng)
      # Replace each parameter array with a ShapeType carrying only its
      # shape and dtype.
      params_shapes = nested_map(
          params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
      s = _eval_on_shapes(call_on_input, pseudo_inputs, params_shapes)
    return s
  except Exception:
    # NOTE(review): skip=3 presumably trims framework frames from the reported
    # traceback; changing the call depth of this method may change which
    # frames are shown.
    name, trace = self.__class__.__name__, _short_traceback(skip=3)
    raise LayerError(name, 'pseudo_call', self._caller, pseudo_inputs, trace)
def test_communicates_with_model(self, mock_restore_state):
  """Checks the symbols fed to the mocked model and the decoded observations."""
  gin.bind_parameter("BoxSpaceSerializer.precision", 1)
  vocab_size = 16
  # The mocked model emits a fixed symbol sequence, chosen so that the first
  # two decoded observations are equal and the last one differs.
  symbols = [
      1, 1, 2, 2,  # obs1
      1, 1, 2, 2,  # obs2
      1, 2, 2, 1,  # obs3
  ]

  def make_prediction(symbol):
    # -100 log-prob everywhere except `symbol`: virtually deterministic.
    log_probs = (1 - np.eye(vocab_size)[symbol]) * -100.0
    # (4 obs symbols + 1 action symbol) * 3 timesteps = 15.
    return np.array([[log_probs] * 15])

  def reward_fn(_1, _2):
    return np.array([0.5])

  def done_fn(_1, _2):
    return np.array([False])

  mock_model_fn = mock.MagicMock()
  mock_model = mock_model_fn.return_value
  mock_model.side_effect = map(make_prediction, symbols)

  def last_model_input():
    # call_args is (args, kwargs); the model receives one positional arg.
    ((model_input,), _) = mock_model.call_args
    return model_input

  with backend.use_backend("numpy"):
    # (model_params, opt_state)
    mock_restore_state.return_value.params = (None, None)
    env = simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
        model=mock_model_fn,
        reward_fn=reward_fn,
        done_fn=done_fn,
        vocab_size=vocab_size,
        max_trajectory_length=3,
        batch_size=1,
        observation_space=gym.spaces.Box(low=0, high=5, shape=(4,)),
        action_space=gym.spaces.Discrete(2),
        reward_range=(-1, 1),
        discrete_rewards=False,
        history_stream=itertools.repeat(None),
        output_dir=None,
    )

    obs1 = env.reset()
    # Unpacking fails if the model was never called during reset.
    last_model_input()

    act1 = 0
    (obs2, reward, done, _) = env.step(np.array([act1]))
    model_input = last_model_input()
    self.assertEqual(model_input[0, 4], act1)
    np.testing.assert_array_equal(model_input[0, :4], symbols[:4])
    np.testing.assert_array_equal(obs1, obs2)
    np.testing.assert_array_equal(reward, [0.5])
    np.testing.assert_array_equal(done, [False])

    act2 = 1
    (obs3, reward, done, _) = env.step(np.array([act2]))
    model_input = last_model_input()
    self.assertEqual(model_input[0, 9], act2)
    np.testing.assert_array_equal(model_input[0, 5:9], symbols[4:8])
    self.assertFalse(np.array_equal(obs2, obs3))
    np.testing.assert_array_equal(reward, [0.5])
    np.testing.assert_array_equal(done, [False])
def test_takes_new_history(self):
  """The second reset should start from the second history ([3, 4, 5])."""
  histories = np.array([[[0, 1, 2]], [[3, 4, 5]]])
  with backend.use_backend("numpy"):
    env = self._create_env(  # pylint: disable=no-value-for-parameter
        model=mock.MagicMock(),
        histories=histories,
        trajectory_length=2,
    )
    env.reset()  # Consumes the first history.
    second_observation = env.reset()
    np.testing.assert_array_equal(second_observation, [5])
def train_epoch(self, epoch_steps, eval_steps): """Train for one epoch.""" # Log separator print() # Timer start_time = time.time() for _ in range(epoch_steps): # Train next_train_batch = next(self._train_stream) if self._n_devices > 1: # TODO(lukaszkaiser): use everywhere if possible. next_train_batch = reshape_by_device(next_train_batch, self._n_devices) self._opt_state, self._rngs = self._jit_update_fn( self._step, self._opt_state, next_train_batch, self._rngs) self._step += 1 if self._step in self._save_steps: _save_replicated(self._opt_state, self._step, self._history, self._n_devices, self._output_dir, True) # LR log if self._step == 1 or self._step % 10 == 0: # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU) # in op-by-op mode just to compute the learning rate. However, there # should be a cleaner approach that forceably swapping out the backend. with backend.use_backend("numpy"): self._train_sw.scalar("training/learning rate", self._lr_fn(self._step), step=self._step) # Timer epoch_time = time.time() - start_time step_log( self._step, "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time)) if epoch_steps > 1: self._train_sw.scalar("training/steps per second", epoch_steps / epoch_time, step=self._step) # Evaluate in parallel self.evaluate(eval_steps) # Save state _save_replicated(self._opt_state, self._step, self._history, self._n_devices, self._output_dir, False) # Flush summary writers self._train_sw.flush() self._eval_sw.flush()
def test_communicates_with_model(self):
  """Checks observations, rewards, dones and model inputs over two steps."""

  # The mocked model adds the action to the last observation; the reward is
  # 1 when the new observation is even and 0 otherwise.
  def mock_transition(inputs, *args, **kwargs):
    del args, kwargs
    (observations, actions) = inputs
    new_observations = observations[:, -1] + actions
    rewards = np.array([[int(new_observations % 2 == 0)]])
    return (new_observations, rewards)

  mock_model_fn = mock.MagicMock()
  mock_model_fn.return_value.side_effect = mock_transition
  mock_model = mock_model_fn.return_value

  actions_to_take = np.array([[1], [3]])
  initial_observations = np.array([[[0, 1, 2, 3]]])
  expected_observations = np.array([[3], [4], [7]])
  expected_rewards = np.array([[1], [0]])
  expected_dones = np.array([[False], [True]])
  expected_histories = np.array([[[0, 1, 2, 3]], [[1, 2, 3, 4]]])
  expected_actions = actions_to_take

  with backend.use_backend("numpy"):
    env = self._create_env(  # pylint: disable=no-value-for-parameter
        model=mock_model_fn,
        initial_observations=initial_observations,
        trajectory_length=len(actions_to_take),
    )
    seen_observations = [env.reset()]
    seen_rewards, seen_dones, seen_histories, seen_actions = [], [], [], []
    for step_action in actions_to_take:
      (observation, reward, done, _) = env.step(step_action)
      seen_observations.append(observation)
      seen_rewards.append(reward)
      seen_dones.append(done)
      # call_args is (args, kwargs) with a single positional argument: the
      # tuple (history, action) that the model was given.
      (((history, passed_action),), _) = mock_model.call_args
      seen_actions.append(passed_action)
      seen_histories.append(history)

    np.testing.assert_array_equal(seen_observations, expected_observations)
    np.testing.assert_array_equal(seen_rewards, expected_rewards)
    np.testing.assert_array_equal(seen_dones, expected_dones)
    np.testing.assert_array_equal(seen_histories, expected_histories)
    np.testing.assert_array_equal(seen_actions, expected_actions)
def test_fails_to_evaluate_model_with_matrix_observation_space(self):
  """evaluate_model is expected to return None for a 2-D observation space."""
  with backend.use_backend("numpy"):
    matrix_space = gym.spaces.Box(shape=(2, 2), low=0, high=1)
    env = self._make_env(  # pylint: disable=no-value-for-parameter
        observation_space=matrix_space,
        action_space=gym.spaces.Discrete(n=1),
        max_trajectory_length=2,
        batch_size=1,
    )
    only_trajectory = self._make_trajectory(
        np.array([[0, 1], [2, 3]]), np.array([0]))
    self.assertIsNone(
        simple.evaluate_model(env, [only_trajectory], plt))
def test_evaluates_model_with_vector_observation_space(self):
  """evaluate_model returns two metrics for a 1-D observation space."""
  with backend.use_backend("numpy"):
    env = self._make_env(  # pylint: disable=no-value-for-parameter
        observation_space=gym.spaces.Box(shape=(2,), low=0, high=1),
        action_space=gym.spaces.Discrete(n=1),
        max_trajectory_length=2,
        batch_size=3,
    )
    # Trajectories of length 1, 2 and 3.
    observation_action_pairs = [
        (np.array([[0, 1]]), np.array([0])),
        (np.array([[1, 2], [3, 4]]), np.array([0, 0])),
        (np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 0, 0])),
    ]
    trajectories = [
        self._make_trajectory(obs, acts)
        for (obs, acts) in observation_action_pairs
    ]
    metrics = simple.evaluate_model(env, trajectories, plt)
    self.assertIsNotNone(metrics)
    self.assertEqual(len(metrics), 2)
def test_train_eval_predict(self, backend_name):
  """Trains/evaluates a stateful model (Dropout + BatchNorm), then predicts."""
  if xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest(
        "tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Model and input pipeline setup.
    n_classes = 4
    train_steps = 2
    eval_steps = 2

    # Dropout and BatchNorm are included to exercise state handling.
    def model_fn(mode="train"):
      dropout = layers.Dropout(mode=mode, rate=0.1)
      batch_norm = layers.BatchNorm(mode=mode)
      mlp = models.MLP(d_hidden=16, n_output_classes=n_classes, mode=mode)
      return layers.Model(dropout, batch_norm, mlp)

    def make_inputs(_):
      return test_inputs(n_classes)

    # Run training (with evaluation).
    state = trax.train(
        output_dir,
        model=model_fn,
        inputs=make_inputs,
        train_steps=train_steps,
        eval_steps=eval_steps)

    # Every requested train step must have been taken.
    self.assertEqual(train_steps, state.step)

    # Train and eval accuracy must both have been recorded twice.
    history = state.history
    train_acc = history.get("train", "metrics/accuracy")
    eval_acc = history.get("eval", "metrics/accuracy")
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # The final params must be usable for a forward pass.
    stream = make_inputs(1).train_stream()
    predictor = layers.Serial(model_fn())
    predictor(next(stream)[0], params=state.opt_state.params)
def test_reset_twice(self, backend_name):
  """A Trainer can be reset into two output dirs, evaluating after each."""
  if xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name):
    with self.tmp_dir() as output_dir1, self.tmp_dir() as output_dir2:
      n_classes = 4
      model_fn = functools.partial(
          models.MLP, d_hidden=16, n_output_classes=n_classes)

      def make_inputs(_):
        return test_inputs(n_classes)

      trainer = trax.Trainer(
          model=model_fn,
          loss_fn=trax.loss,
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          inputs=make_inputs,
      )

      for output_dir in (output_dir1, output_dir2):
        trainer.reset(output_dir)
        trainer.evaluate(1)
def optimizer_params(self):
  """Returns `self._lr_fn(self._step)`, evaluated under the numpy backend."""
  # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU)
  # in op-by-op mode just to compute the learning rate. However, there
  # should be a cleaner approach than forcibly swapping out the backend.
  with backend.use_backend("numpy"):
    learning_rate = self._lr_fn(self._step)
  return learning_rate
def test_communicates_with_model(self, mock_restore_state):
  """Checks the symbols fed to predict_fn and the decoded observations."""
  gin.bind_parameter("BoxSpaceSerializer.precision", 1)
  vocab_size = 16
  # predict_fn is mocked to emit a fixed symbol sequence, chosen so that the
  # first two decoded observations differ and the last one equals the first.
  symbols = [
      1, 1, 2, 2, 0, 0,  # obs1 act1
      1, 2, 2, 1, 0, 0,  # obs2 act2
      1, 1, 2, 2,  # obs3
  ]

  def make_prediction(symbol):
    # -100 log-prob everywhere except `symbol`: virtually deterministic.
    log_probs = (1 - np.eye(vocab_size)[symbol]) * -100.0
    # One prediction of shape (1, 1, vocab_size) per predict_fn call.
    return np.array([[log_probs]])

  mock_predict_fn = mock.MagicMock()
  mock_predict_fn.side_effect = map(make_prediction, symbols)

  with backend.use_backend("numpy"):
    # (model_params, opt_state)
    mock_restore_state.return_value.params = (None, None)
    env = self._make_env(
        predict_fn=mock_predict_fn,
        reward_fn=(lambda _1, _2: np.array([0.5])),
        done_fn=(lambda _1, _2: np.array([False])),
        vocab_size=vocab_size,
        batch_size=1,
        max_trajectory_length=3,
        observation_space=gym.spaces.Box(low=0, high=5, shape=(4,)),
        action_space=gym.spaces.MultiDiscrete(nvec=[2, 2]),
    )

    def assert_recent_inputs(expected_symbols):
      # Each predict_fn call gets one positional arg holding a single symbol.
      recent_calls = mock_predict_fn.call_args_list[-len(expected_symbols):]
      fed_symbols = np.array(
          [call_symbol.item() for ((call_symbol,), _) in recent_calls])
      np.testing.assert_array_equal(fed_symbols, expected_symbols)

    first_action = [0, 1]
    second_action = [1, 0]

    obs1 = env.reset()
    assert_recent_inputs(symbols[:3])

    (obs2, reward, done, _) = env.step(np.array([first_action]))
    # Symbols going into the decoder when predicting the next observation are:
    # the last symbol of the previous observation, all action symbols, all
    # symbols but the last one of the next observation.
    assert_recent_inputs([symbols[3]] + first_action + symbols[6:9])
    self.assertFalse(np.array_equal(obs1, obs2))
    np.testing.assert_array_equal(reward, [0.5])
    np.testing.assert_array_equal(done, [False])

    (obs3, reward, done, _) = env.step(np.array([second_action]))
    assert_recent_inputs([symbols[9]] + second_action + symbols[12:15])
    np.testing.assert_array_equal(obs1, obs3)
    np.testing.assert_array_equal(reward, [0.5])
    np.testing.assert_array_equal(done, [True])