def test_autoregressive_sample_transformerlm(self):
     """Samples from a tiny predict-mode TransformerLM via trainer_lib.

     Three scenarios are exercised on the same model, re-initialized before
     each one:
       1. one sequence with eos_id=-1, expected to run to max_length (10);
       2. a batch of two with the default eos_id, whose length must not
          exceed max_length;
       3. one sequence seeded with prefix [[1, 2, 3]], whose first three
          output tokens must be that prefix.
     """
     lm = models.TransformerLM(10,
                               d_model=32,
                               d_ff=64,
                               n_layers=1,
                               n_heads=2,
                               mode='predict')
     single = shapes.ShapeDtype((1, 1), dtype=jnp.int32)

     # Scenario 1: eos_id=-1, so the test expects the full max_length.
     lm.init(single)
     unconditioned = trainer_lib.autoregressive_sample(
         lm, batch_size=1, eos_id=-1, max_length=10)
     self.assertEqual(unconditioned.shape[0], 1)
     self.assertEqual(unconditioned.shape[1], 10)

     # Scenario 2: batch of two; init uses a per-device batch of
     # 2 // device_count().
     lm.init(shapes.ShapeDtype((2 // fastmath.device_count(), 1),
                               dtype=jnp.int32))
     batched = trainer_lib.autoregressive_sample(
         lm, batch_size=2, max_length=10)
     self.assertEqual(batched.shape[0], 2)
     self.assertLess(batched.shape[1], 11)

     # Scenario 3: the prefix tokens are expected at the start of the output.
     lm.init(single)
     seeded = trainer_lib.autoregressive_sample(
         lm, eos_id=-1, max_length=10, batch_size=1,
         prefix=jnp.array([[1, 2, 3]]))
     self.assertEqual(seeded.shape[0], 1)
     for position, expected_token in enumerate((1, 2, 3)):
         self.assertEqual(int(seeded[0][position]), expected_token)
Beispiel #2
0
 def test_autoregressive_sample_transformerlm_tfnp(self):
     """Samples from a tiny predict-mode TransformerLM under the TFNP backend.

     Uses `decoding.autoregressive_sample` for three scenarios: a single
     sequence with eos_id=-1, a batch of two with the default eos_id, and a
     single sequence conditioned on [[1, 2, 3]] (only its shape is checked).
     """
     with fastmath.use_backend(fastmath.Backend.TFNP):
         lm = models.TransformerLM(10,
                                   d_model=32,
                                   d_ff=64,
                                   n_layers=1,
                                   n_heads=2,
                                   mode='predict')
         one_by_one = shapes.ShapeDtype((1, 1), dtype=np.int32)

         # eos_id=-1: the test expects the output to fill max_length.
         lm.init(one_by_one)
         free_run = decoding.autoregressive_sample(
             lm, batch_size=1, eos_id=-1, max_length=10)
         self.assertEqual(free_run.shape[0], 1)
         self.assertEqual(free_run.shape[1], 10)

         # Batch of two; init uses a per-device batch of 2 // device_count().
         devices = fastmath.device_count()
         lm.init(shapes.ShapeDtype((2 // devices, 1), dtype=np.int32))
         batched = decoding.autoregressive_sample(
             lm, batch_size=2, max_length=10)
         self.assertEqual(batched.shape[0], 2)
         self.assertLess(batched.shape[1], 11)

         # Conditioned on [[1, 2, 3]] passed positionally; shape check only.
         lm.init(one_by_one)
         conditioned = decoding.autoregressive_sample(
             lm, np.array([[1, 2, 3]]),
             eos_id=-1, max_length=10, batch_size=1)
         self.assertEqual(conditioned.shape[0], 1)
         self.assertEqual(conditioned.shape[1], 10)
Beispiel #3
0
    def test_training_loop_cartpole_serialized_init_from_world_model(
            self, two_towers):
        """Initializes a policy from a freshly trained serialized world model.

        Trains a tiny serialized TransformerLM world model for one step on a
        constant dummy stream, writing to ``<tmp>/model``. It then runs a
        two-epoch CartPole-v0 policy training loop whose
        ``init_policy_from_world_model_output_dir`` points at that directory.

        Args:
          two_towers: passed through as ``policy_and_value_two_towers``.
        """
        gin.bind_parameter('BoxSpaceSerializer.precision', 1)

        # Tiny architecture shared by the world model and the policy network.
        transformer_kwargs = dict(
            d_model=1,
            d_ff=1,
            n_layers=1,
            n_heads=1,
            max_len=128,
        )
        obs_serializer = space_serializer.create(
            gym.spaces.MultiDiscrete([2, 2]), vocab_size=4)
        act_serializer = space_serializer.create(
            gym.spaces.Discrete(2), vocab_size=4)

        def model_fn(mode):
            # World model: a TransformerLM wrapped in SerializedModel with the
            # observation/action serializers built above.
            lm = models.TransformerLM(
                mode=mode, vocab_size=4, **transformer_kwargs)
            return serialization_utils.SerializedModel(
                seq_model=lm,
                observation_serializer=obs_serializer,
                action_serializer=act_serializer,
                significance_decay=0.9,
            )

        def dummy_stream(_):
            # Endless all-zero batches yielded as (obs, act, obs, mask), where
            # mask is all ones with obs's shape. Fresh arrays every iteration.
            while True:
                zero_obs = np.zeros((1, 2, 2), dtype=np.int32)
                zero_act = np.zeros((1, 1), dtype=np.int32)
                yield (zero_obs, zero_act, zero_obs, np.ones_like(zero_obs))

        with self.tmp_dir() as output_dir:
            model_dir = os.path.join(output_dir, 'model')
            policy_dir = os.path.join(output_dir, 'policy')

            inputs = trax_inputs.Inputs(train_stream=dummy_stream,
                                        eval_stream=dummy_stream)
            # Declare the (obs, act) input shapes and dtypes directly.
            inputs._input_shape = ((2, 2), (1,))  # pylint: disable=protected-access
            inputs._input_dtype = (np.int32, np.int32)  # pylint: disable=protected-access

            # Run the trainer once so model_dir holds a world-model checkpoint.
            trainer_lib.train(
                model_dir,
                model=model_fn,
                inputs=inputs,
                steps=1,
                eval_steps=1,
                has_weights=True,
            )

            policy_trainer = self._make_trainer(
                train_env=self.get_wrapped_env('CartPole-v0', 2),
                eval_env=self.get_wrapped_env('CartPole-v0', 2),
                output_dir=policy_dir,
                model=functools.partial(models.TransformerDecoder,
                                        **transformer_kwargs),
                policy_and_value_vocab_size=4,
                init_policy_from_world_model_output_dir=model_dir,
                policy_and_value_two_towers=two_towers,
            )
            policy_trainer.training_loop(n_epochs=2)
Beispiel #4
0
 def inner_model(mode):
     """Builds a minimal TransformerLM for `mode`.

     `vocab_size` is read from the enclosing scope.
     """
     tiny_dims = dict(d_model=2, d_ff=4, n_layers=1, n_heads=1)
     return models.TransformerLM(mode=mode, vocab_size=vocab_size, **tiny_dims)
Beispiel #5
0
 def test_autoregressive_sample_transformerlm_quality(self):
   """Greedy-decodes with pretrained weights and checks the exact tokens.

   Restores weights from `transformerlm_copy.pkl.gz` into a predict-mode
   TransformerLM initialized with (1, 1) int32 signatures. It then expects
   temperature-0.0 sampling on [[0, 3, 7, 5, 3, 2, 4, 0]] to return the
   tokens 3 7 5 3 2 4, i.e. inputs[0][1:7].
   """
   pred_model = models.TransformerLM(
       d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2,
       n_layers=2, vocab_size=13, mode='predict')
   shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
   model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz')
   pred_model.init_from_file(model_path, weights_only=True,
                             input_signature=(shape11, shape11))
   inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
   s = decoding.autoregressive_sample(pred_model, inputs,
                                      max_length=6, temperature=0.0)
   # Compare token values rather than str(s[0]): the string form depends on
   # NumPy's print options and gives no element-wise diff on failure.
   self.assertEqual([int(token) for token in s[0]], [3, 7, 5, 3, 2, 4])
 def model(mode):
     """Wraps a small TransformerLM in a SerializedModel for `mode`.

     `vocab_size` and the serializer `srl` come from the enclosing scope. The
     same serializer is used for observations and actions.
     """
     seq_model = trax_models.TransformerLM(
         mode=mode, vocab_size=vocab_size,
         d_model=16, d_ff=8, n_layers=1, n_heads=1)
     return serialization_utils.SerializedModel(
         seq_model,
         observation_serializer=srl,
         action_serializer=srl,
         significance_decay=0.9,
     )
Beispiel #7
0
 def test_autoregressive_sample_transformerlm_quality_eval(self):
     """Greedy-decodes in eval mode from pretrained weights.

     Restores `transformerlm_copy.pkl.gz` into an eval-mode TransformerLM and
     samples with eval_mode=True and temperature 0.0. It expects the tokens
     3 7 5 3 2 4 (inputs[0][1:7]) for input [[0, 3, 7, 5, 3, 2, 4, 0]].
     """
     eval_model = models.TransformerLM(d_model=64,
                                       d_ff=128,
                                       dropout=0.05,
                                       max_len=256,
                                       n_heads=2,
                                       n_layers=2,
                                       vocab_size=13,
                                       mode='eval')
     model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz')
     eval_model.init_from_file(model_path)
     inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
     s = decoding.autoregressive_sample(eval_model,
                                        inputs,
                                        eval_mode=True,
                                        max_length=6,
                                        temperature=0.0)
     # Compare token values rather than str(s[0]): the string form depends on
     # NumPy's print options and gives no element-wise diff on failure.
     self.assertEqual([int(token) for token in s[0]], [3, 7, 5, 3, 2, 4])