def test_autoregressive_sample_transformerlm(self):
  model = models.TransformerLM(10, d_model=32, d_ff=64, n_layers=1,
                               n_heads=2, mode='predict')
  model.init(shapes.ShapeDtype((1, 1), dtype=jnp.int32))
  s1 = trainer_lib.autoregressive_sample(model, batch_size=1, eos_id=-1,
                                         max_length=10)
  self.assertEqual(s1.shape[0], 1)
  self.assertEqual(s1.shape[1], 10)
  batch_per_device = 2 // fastmath.device_count()
  model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=jnp.int32))
  s2 = trainer_lib.autoregressive_sample(model, batch_size=2, max_length=10)
  self.assertEqual(s2.shape[0], 2)
  self.assertLess(s2.shape[1], 11)
  model.init(shapes.ShapeDtype((1, 1), dtype=jnp.int32))
  prefix = jnp.array([[1, 2, 3]])
  s3 = trainer_lib.autoregressive_sample(model, eos_id=-1, max_length=10,
                                         batch_size=1, prefix=prefix)
  self.assertEqual(s3.shape[0], 1)
  # With a prefix, the returned sequence starts with the prefix tokens.
  self.assertEqual(int(s3[0][0]), 1)
  self.assertEqual(int(s3[0][1]), 2)
  self.assertEqual(int(s3[0][2]), 3)
def test_autoregressive_sample_transformerlm_tfnp(self):
  # The same sampling checks, run under the TF-NumPy (TFNP) backend.
  with fastmath.use_backend(fastmath.Backend.TFNP):
    model = models.TransformerLM(10, d_model=32, d_ff=64, n_layers=1,
                                 n_heads=2, mode='predict')
    model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
    s1 = decoding.autoregressive_sample(model, batch_size=1, eos_id=-1,
                                        max_length=10)
    self.assertEqual(s1.shape[0], 1)
    self.assertEqual(s1.shape[1], 10)
    batch_per_device = 2 // fastmath.device_count()
    model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=np.int32))
    s2 = decoding.autoregressive_sample(model, batch_size=2, max_length=10)
    self.assertEqual(s2.shape[0], 2)
    self.assertLess(s2.shape[1], 11)
    model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
    prefix = np.array([[1, 2, 3]])
    s3 = decoding.autoregressive_sample(model, prefix, eos_id=-1,
                                        max_length=10, batch_size=1)
    self.assertEqual(s3.shape[0], 1)
    self.assertEqual(s3.shape[1], 10)
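# Illustrative sketch (not part of the original test suite, and not Trax's
# actual implementation of autoregressive_sample): what greedy autoregressive
# sampling does step by step. A 'predict'-mode model is fed one token per
# call and keeps an internal attention cache, so each call returns
# next-token log-probabilities for a single position; with temperature 0 the
# next token is simply the argmax. The helper name and start_id default are
# assumptions for this sketch.
def _greedy_sample_sketch(model, start_id=0, max_length=10):
  # `model` is assumed to be a 'predict'-mode TransformerLM already
  # initialized with a (1, 1) int32 signature, as in the tests above.
  cur = np.array([[start_id]], dtype=np.int32)
  sampled = []
  for _ in range(max_length):
    logits = model(cur)  # Shape (1, 1, vocab_size); the cache advances per call.
    cur = np.array([[int(np.argmax(logits[0, -1]))]], dtype=np.int32)
    sampled.append(int(cur[0, 0]))
  return sampled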
def test_training_loop_cartpole_serialized_init_from_world_model(
    self, two_towers):
  gin.bind_parameter('BoxSpaceSerializer.precision', 1)
  transformer_kwargs = {
      'd_model': 1,
      'd_ff': 1,
      'n_layers': 1,
      'n_heads': 1,
      'max_len': 128,
  }
  obs_serializer = space_serializer.create(
      gym.spaces.MultiDiscrete([2, 2]), vocab_size=4)
  act_serializer = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=4)
  model_fn = lambda mode: serialization_utils.SerializedModel(  # pylint: disable=g-long-lambda
      seq_model=models.TransformerLM(
          mode=mode, vocab_size=4, **transformer_kwargs),
      observation_serializer=obs_serializer,
      action_serializer=act_serializer,
      significance_decay=0.9,
  )

  with self.tmp_dir() as output_dir:
    model_dir = os.path.join(output_dir, 'model')

    def dummy_stream(_):
      while True:
        obs = np.zeros((1, 2, 2), dtype=np.int32)
        act = np.zeros((1, 1), dtype=np.int32)
        mask = np.ones_like(obs)
        yield (obs, act, obs, mask)

    inputs = trax_inputs.Inputs(
        train_stream=dummy_stream, eval_stream=dummy_stream)
    inputs._input_shape = ((2, 2), (1,))  # pylint: disable=protected-access
    inputs._input_dtype = (np.int32, np.int32)  # pylint: disable=protected-access

    # Initialize a world model checkpoint by running the trainer.
    trainer_lib.train(
        model_dir,
        model=model_fn,
        inputs=inputs,
        steps=1,
        eval_steps=1,
        has_weights=True,
    )

    policy_dir = os.path.join(output_dir, 'policy')
    trainer = self._make_trainer(
        train_env=self.get_wrapped_env('CartPole-v0', 2),
        eval_env=self.get_wrapped_env('CartPole-v0', 2),
        output_dir=policy_dir,
        model=functools.partial(
            models.TransformerDecoder, **transformer_kwargs),
        policy_and_value_vocab_size=4,
        init_policy_from_world_model_output_dir=model_dir,
        policy_and_value_two_towers=two_towers,
    )
    trainer.training_loop(n_epochs=2)
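# Quick sanity check (illustrative, not part of the original suite; the
# helper name is an assumption): the dummy stream above yields
# (observations, actions, targets, mask) tuples whose shapes match
# inputs._input_shape with a leading batch axis of 1.
def _check_dummy_stream_shapes():
  def dummy_stream(_):
    while True:
      obs = np.zeros((1, 2, 2), dtype=np.int32)
      act = np.zeros((1, 1), dtype=np.int32)
      yield (obs, act, obs, np.ones_like(obs))
  obs, act, target, mask = next(dummy_stream(None))
  assert obs.shape == (1, 2, 2) and act.shape == (1, 1)
  assert target.shape == obs.shape and mask.shape == obs.shape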
def inner_model(mode):
  return models.TransformerLM(
      mode=mode,
      vocab_size=vocab_size,
      d_model=2,
      d_ff=4,
      n_layers=1,
      n_heads=1,
  )
def test_autoregressive_sample_transformerlm_quality(self):
  pred_model = models.TransformerLM(
      d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2,
      n_layers=2, vocab_size=13, mode='predict')
  shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
  model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz')
  # In 'predict' mode, init_from_file needs an input signature to build the
  # decoding cache state alongside the restored weights.
  pred_model.init_from_file(model_path, weights_only=True,
                            input_signature=(shape11, shape11))
  inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
  # The checkpoint was trained to copy its input, so greedy decoding
  # (temperature 0.0) should reproduce the tokens between the 0 separators.
  s = decoding.autoregressive_sample(pred_model, inputs, max_length=6,
                                     temperature=0.0)
  self.assertEqual(str(s[0]), '[3 7 5 3 2 4]')
def model(mode):
  return serialization_utils.SerializedModel(
      trax_models.TransformerLM(
          mode=mode,
          vocab_size=vocab_size,
          d_model=16,
          d_ff=8,
          n_layers=1,
          n_heads=1,
      ),
      observation_serializer=srl,
      action_serializer=srl,
      significance_decay=0.9,
  )
def test_autoregressive_sample_transformerlm_quality_eval(self):
  eval_model = models.TransformerLM(
      d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2,
      n_layers=2, vocab_size=13, mode='eval')
  model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz')
  eval_model.init_from_file(model_path)
  inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
  # eval_mode=True lets autoregressive_sample drive an 'eval'-mode model
  # (no attention cache) and should match the 'predict'-mode result above.
  s = decoding.autoregressive_sample(eval_model, inputs, eval_mode=True,
                                     max_length=6, temperature=0.0)
  self.assertEqual(str(s[0]), '[3 7 5 3 2 4]')
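# Illustrative sketch (an assumption about the mechanics, not code from
# trax.supervised.decoding): the quality tests above pass temperature=0.0,
# which makes sampling deterministic (argmax). For temperature > 0, one
# standard way to draw from softmax(log_probs / temperature) is the
# Gumbel-max trick shown here; the helper name is hypothetical.
def _sample_with_temperature_sketch(log_probs, temperature, rng=None):
  rng = rng or np.random.default_rng(0)
  if temperature == 0.0:
    return int(np.argmax(log_probs))  # Greedy, as in the tests above.
  # argmax(logits / T + Gumbel(0, 1) noise) is distributed as a draw from
  # the categorical distribution softmax(logits / T).
  gumbel = rng.gumbel(size=np.shape(log_probs))
  return int(np.argmax(np.asarray(log_probs) / temperature + gumbel))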