Example #1
0
 def model(self):
     """Return a factory mapping a mode to a SerializedModel.

     The factory calls `self._model(mode=mode)` for the inner sequence model
     and reuses this object's observation/action serializers and significance
     decay. Those attributes are read each time the factory is called, not
     when the factory is created.
     """

     def build(mode):
         return serialization_utils.SerializedModel(
             seq_model=self._model(mode=mode),
             observation_serializer=self._obs_serializer,
             action_serializer=self._action_serializer,
             significance_decay=self._significance_decay,
         )

     return build
    def test_serialized_model_continuous(self):
        """Checks SerializedModel output shapes on a continuous Box space.

        With `BoxSpaceSerializer.precision` bound to 3, asserts that the
        observation representation has shape
        (batch, time, obs_dim * precision), that the logits add a trailing
        vocabulary axis, and that the returned weights match the
        representation's shape.
        """
        precision = 3
        gin.bind_parameter('BoxSpaceSerializer.precision', precision)

        vocab_size = 32
        # One trajectory: 4 two-dimensional observations, 3 actions, and a
        # mask whose last entry is 0.
        observations = np.array(
            [[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
        actions = np.array([[0, 1, 0]])
        mask = np.array([[1, 1, 1, 0]])

        box_space = gym.spaces.Box(shape=(2, ), low=-2, high=2)
        obs_serializer = space_serializer.create(box_space,
                                                 vocab_size=vocab_size)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        # NOTE(review): `TestModel` is defined outside this snippet; it is
        # assumed to emit logits with a trailing axis of size `extra_dim`.
        serialized_model = serialization_utils.SerializedModel(
            TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        batch = (observations, actions, observations, mask)
        serialized_model.init(shapes.signature(batch))
        logits, representation, loss_weights = serialized_model(batch)

        expected_repr_shape = (1, observations.shape[1],
                               observations.shape[2] * precision)
        self.assertEqual(logits.shape,
                         representation.shape + (vocab_size, ))
        self.assertEqual(representation.shape, expected_repr_shape)
        self.assertEqual(representation.shape, loss_weights.shape)
Example #3
0
    def test_extract_inner_model(self):
        """Checks that the inner model's weights/state can be extracted and used.

        Initializes a SerializedModel wrapping a TransformerLM, pulls the inner
        model's part out of the resulting weights and state with
        `serialization_utils.extract_inner_model`, and runs the inner model with
        them.
        """
        vocab_size = 3

        inner_model = transformer.TransformerLM(
            vocab_size=vocab_size, d_model=2, d_ff=2, n_layers=0)
        obs_serializer = space_serializer.create(
            gym.spaces.Discrete(2), vocab_size=vocab_size)
        act_serializer = space_serializer.create(
            gym.spaces.Discrete(2), vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            inner_model,
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        obs_sig = shapes.ShapeDtype((1, 2))
        act_sig = shapes.ShapeDtype((1, 1))
        # Four inputs: observation, action, observation, and a fourth
        # observation-shaped signature.
        input_signature = (obs_sig, act_sig, obs_sig, obs_sig)
        weights, state = serialized_model.init(input_signature=input_signature)

        inner_weights = serialization_utils.extract_inner_model(weights)
        inner_state = serialization_utils.extract_inner_model(state)
        inner_model(np.array([[0]]), weights=inner_weights, state=inner_state)
Example #4
0
 def model(mode):
     """Wrap `inner_model(mode)` in a SerializedModel.

     `srl` serializes both observations and actions; significance decay is 0.7.
     """
     seq_model = inner_model(mode)
     return serialization_utils.SerializedModel(
         seq_model,
         observation_serializer=srl,
         action_serializer=srl,
         significance_decay=0.7,
     )
Example #5
0
    def test_training_loop_cartpole_serialized_init_from_world_model(
            self, two_towers):
        """Trains a tiny serialized world model, then inits a policy from it.

        A SerializedModel around a 1-layer TransformerLM is trained for one
        step on an all-zero dummy stream to write a checkpoint. A trainer from
        `self._make_trainer` is then pointed at that checkpoint through
        `init_policy_from_world_model_output_dir` and run for two epochs on
        CartPole.
        """
        gin.bind_parameter('BoxSpaceSerializer.precision', 1)

        transformer_kwargs = dict(
            d_model=1,
            d_ff=1,
            n_layers=1,
            n_heads=1,
            max_len=128,
        )
        obs_serializer = space_serializer.create(
            gym.spaces.MultiDiscrete([2, 2]), vocab_size=4)
        act_serializer = space_serializer.create(
            gym.spaces.Discrete(2), vocab_size=4)

        def world_model_fn(mode):
            return serialization_utils.SerializedModel(
                seq_model=models.TransformerLM(
                    mode=mode, vocab_size=4, **transformer_kwargs),
                observation_serializer=obs_serializer,
                action_serializer=act_serializer,
                significance_decay=0.9,
            )

        with self.tmp_dir() as output_dir:
            world_model_dir = os.path.join(output_dir, 'model')

            def zeros_stream(_):
                # Endless (obs, act, obs, mask) batches; fresh arrays each step.
                while True:
                    observation = np.zeros((1, 2, 2), dtype=np.int32)
                    action = np.zeros((1, 1), dtype=np.int32)
                    yield (observation, action, observation,
                           np.ones_like(observation))

            inputs = trax_inputs.Inputs(train_stream=zeros_stream,
                                        eval_stream=zeros_stream)
            # Set the input shapes/dtypes directly on the private fields.
            inputs._input_shape = ((2, 2), (1, ))  # pylint: disable=protected-access
            inputs._input_dtype = (np.int32, np.int32)  # pylint: disable=protected-access

            # Run the trainer for one step to initialize a world model
            # checkpoint.
            trainer_lib.train(
                world_model_dir,
                model=world_model_fn,
                inputs=inputs,
                steps=1,
                eval_steps=1,
                has_weights=True,
            )

            policy_dir = os.path.join(output_dir, 'policy')
            policy_trainer = self._make_trainer(
                train_env=self.get_wrapped_env('CartPole-v0', 2),
                eval_env=self.get_wrapped_env('CartPole-v0', 2),
                output_dir=policy_dir,
                model=functools.partial(models.TransformerDecoder,
                                        **transformer_kwargs),
                policy_and_value_vocab_size=4,
                init_policy_from_world_model_output_dir=world_model_dir,
                policy_and_value_two_towers=two_towers,
            )
            policy_trainer.training_loop(n_epochs=2)

    def test_serialized_model_extracts_seq_model_weights_and_state(self):
        """Checks SerializedModel exposes weights/state usable by its seq model.

        A SerializedModel built from a TransformerLM factory is initialized.
        Its `seq_model_weights` and `seq_model_state` are then installed on a
        separately constructed eval-mode TransformerLM, which is run once.
        """
        vocab_size = 3
        seq_model_fn = functools.partial(
            transformer.TransformerLM,
            vocab_size=vocab_size, d_model=2, d_ff=2, n_layers=0)
        seq_model = seq_model_fn(mode='eval')

        obs_serializer = space_serializer.create(
            gym.spaces.Discrete(2), vocab_size=vocab_size)
        act_serializer = space_serializer.create(
            gym.spaces.Discrete(2), vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            seq_model_fn,
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        obs_sig = shapes.ShapeDtype((1, 2))
        act_sig = shapes.ShapeDtype((1, 1))
        signature = (obs_sig, act_sig, obs_sig, obs_sig)
        serialized_model.init(input_signature=signature)

        # Transplant the extracted weights and state onto the standalone model.
        seq_model.weights = serialized_model.seq_model_weights
        seq_model.state = serialized_model.seq_model_state
        # Running it checks that the weights and state have a usable structure.
        seq_model(jnp.array([[0]]))

    def test_serialized_model_discrete(self):
        """Checks SerializedModel on a MultiDiscrete observation space.

        A stub sequence model records every input it receives and echoes it
        back as float32 logits broadcast over a vocabulary axis. This lets the
        test check:
          * the interleaved observation/action serialization fed to the model,
          * the logits shape,
          * that the representation equals the original observations,
          * the returned weights.
        """
        vocab_size = 3
        obs = np.array([[[0, 1], [1, 1], [1, 0], [0, 0]]])
        act = np.array([[1, 0, 0]])
        mask = np.array([[1, 1, 1, 0]])

        # Every input the stub model is called with is appended here.
        test_model_inputs = []

        # pylint: disable=invalid-name
        def TestModelSavingInputs():
            """Builds a stub layer that records its inputs and echoes them."""

            def f(inputs):
                # Save the inputs for a later check.
                test_model_inputs.append(inputs)
                # Change type to np.float32 and add the logit dimension.
                return jnp.broadcast_to(
                    inputs.astype(np.float32)[:, :, None],
                    inputs.shape + (vocab_size, ))

            return layers_base.Fn('TestModelSavingInputs', f)
        # Re-enable at method scope. Pylint pragmas are block-scoped, so an
        # enable placed inside the nested function would not undo the
        # method-level disable above.
        # pylint: enable=invalid-name

        obs_serializer = space_serializer.create(gym.spaces.MultiDiscrete(
            [2, 2]),
                                                 vocab_size=vocab_size)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            TestModelSavingInputs(),  # pylint: disable=no-value-for-parameter
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        example = (obs, act, obs, mask)
        serialized_model.init(shapes.signature(example))

        (obs_logits, obs_repr, weights) = serialized_model(example)
        # Check that the model has been called with the correct input.
        np.testing.assert_array_equal(
            # The model is called multiple times for determining shapes etc.
            # Check the last saved input - that should be the actual concrete array
            # calculated during the forward pass.
            test_model_inputs[-1],
            # Should be serialized observations and actions interleaved.
            [[0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]],
        )
        # Check the output shape.
        self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size, ))
        # Check that obs_logits are the same as obs_repr, just broadcasted over the
        # logit dimension.
        np.testing.assert_array_equal(np.min(obs_logits, axis=-1), obs_repr)
        np.testing.assert_array_equal(np.max(obs_logits, axis=-1), obs_repr)
        # Check that the observations are correct.
        np.testing.assert_array_equal(obs_repr, obs)
        # Check weights.
        np.testing.assert_array_equal(
            weights,
            [[[1., 1.], [1., 1.], [1., 1.], [0., 0.]]],
        )
Example #8
0
def make_serialized_model(seq_model, space, vocab_size):
    """Build a SerializedModel whose observations and actions share a serializer.

    Args:
      seq_model: Sequence-model constructor. It is bound to `vocab_size` with
        `functools.partial` and passed to SerializedModel without being
        called.
      space: Passed to `space_serializer.create` to build the single
        serializer used for both observations and actions.
      vocab_size: Vocabulary size given to both the serializer and the model.

    Returns:
      A `serialization_utils.SerializedModel` with significance decay 0.7.
    """
    shared_serializer = space_serializer.create(space, vocab_size)
    model_factory = functools.partial(seq_model, vocab_size=vocab_size)
    return serialization_utils.SerializedModel(
        model_factory,
        observation_serializer=shared_serializer,
        action_serializer=shared_serializer,
        significance_decay=0.7,
    )

 def model(mode):
     """Wrap a small TransformerLM in a SerializedModel.

     The TransformerLM (d_model=16, d_ff=8, one layer, one head) is built in
     the given `mode` with vocabulary `vocab_size`. `srl` serializes both
     observations and actions; significance decay is 0.9.
     """
     seq_model = trax_models.TransformerLM(
         mode=mode,
         vocab_size=vocab_size,
         d_model=16,
         d_ff=8,
         n_layers=1,
         n_heads=1,
     )
     return serialization_utils.SerializedModel(
         seq_model,
         observation_serializer=srl,
         action_serializer=srl,
         significance_decay=0.9,
     )
Example #10
0
    def test_serialized_model_continuous(self):
        """Checks SerializedModel output shapes on a continuous Box space.

        A stub layer echoes its inputs as float32 logits broadcast over a
        vocabulary axis. With `BoxSpaceSerializer.precision` bound to 3, the
        test asserts that:
          * the representation has shape (batch, time, obs_dim * precision),
          * the logits add a trailing vocabulary axis,
          * the returned weights match the representation's shape.
        """
        precision = 3
        gin.bind_parameter('BoxSpaceSerializer.precision', precision)

        vocab_size = 32
        obs = onp.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
        act = onp.array([[0, 1, 0]])
        mask = onp.array([[1, 1, 1, 0]])

        @layers_base.layer()
        def TestModel(inputs, **unused_kwargs):
            # Cast to onp.float32 and append a logit axis of size vocab_size.
            as_float = inputs.astype(onp.float32)[:, :, None]
            return np.broadcast_to(as_float, inputs.shape + (vocab_size, ))

        box_space = gym.spaces.Box(shape=(2, ), low=-2, high=2)
        obs_serializer = space_serializer.create(box_space,
                                                 vocab_size=vocab_size)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            TestModel(),  # pylint: disable=no-value-for-parameter
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        batch = (obs, act, obs, mask)
        serialized_model.init(shapes.signature(batch))
        logits, representation, loss_weights = serialized_model(batch)

        expected_repr_shape = (1, obs.shape[1], obs.shape[2] * precision)
        self.assertEqual(logits.shape,
                         representation.shape + (vocab_size, ))
        self.assertEqual(representation.shape, expected_repr_shape)
        self.assertEqual(representation.shape, loss_weights.shape)