def test_extract_inner_model(self):
  """Runs the bare inner model on params extracted from a SerializedModel."""
  vocab = 3
  seq_model = transformer.TransformerLM(
      vocab_size=vocab, d_model=2, d_ff=2, n_layers=0)
  observation_srl = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab)
  action_srl = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab)
  wrapped = serialization_utils.SerializedModel(
      seq_model,
      observation_serializer=observation_srl,
      action_serializer=action_srl,
      significance_decay=0.9,
  )
  obs_signature = shapes.ShapeDtype((1, 2))
  act_signature = shapes.ShapeDtype((1, 1))
  full_signature = (obs_signature, act_signature, obs_signature,
                    obs_signature)
  outer_weights, outer_state = wrapped.init(input_signature=full_signature)
  inner_weights, inner_state = map(
      serialization_utils.extract_inner_model, (outer_weights, outer_state))
  # Smoke check: the extracted params must be usable by the bare seq model.
  seq_model(np.array([[0]]), weights=inner_weights, state=inner_state)
def test_training_loop_cartpole_serialized_init_from_world_model(
    self, two_towers):
  """Trains a tiny world model, then initializes a policy from its checkpoint."""
  gin.bind_parameter('BoxSpaceSerializer.precision', 1)
  tiny_transformer = dict(
      d_model=1, d_ff=1, n_layers=1, n_heads=1, max_len=128)
  obs_srl = space_serializer.create(
      gym.spaces.MultiDiscrete([2, 2]), vocab_size=4)
  act_srl = space_serializer.create(gym.spaces.Discrete(2), vocab_size=4)

  def model_fn(mode):
    return serialization_utils.SerializedModel(
        seq_model=models.TransformerLM(
            mode=mode, vocab_size=4, **tiny_transformer),
        observation_serializer=obs_srl,
        action_serializer=act_srl,
        significance_decay=0.9,
    )

  with self.tmp_dir() as output_dir:
    world_model_dir = os.path.join(output_dir, 'model')

    def dummy_stream(_):
      # Endless stream of all-zero observations/actions with an all-ones mask.
      while True:
        observations = np.zeros((1, 2, 2), dtype=np.int32)
        actions = np.zeros((1, 1), dtype=np.int32)
        yield (observations, actions, observations,
               np.ones_like(observations))

    inputs = trax_inputs.Inputs(
        train_stream=dummy_stream, eval_stream=dummy_stream)
    # pylint: disable=protected-access
    inputs._input_shape = ((2, 2), (1,))
    inputs._input_dtype = (np.int32, np.int32)
    # pylint: enable=protected-access

    # Produce a world model checkpoint by running the trainer for one step.
    trainer_lib.train(
        world_model_dir,
        model=model_fn,
        inputs=inputs,
        steps=1,
        eval_steps=1,
        has_weights=True,
    )

    policy_trainer = self._make_trainer(
        train_env=self.get_wrapped_env('CartPole-v0', 2),
        eval_env=self.get_wrapped_env('CartPole-v0', 2),
        output_dir=os.path.join(output_dir, 'policy'),
        model=functools.partial(models.TransformerDecoder, **tiny_transformer),
        policy_and_value_vocab_size=4,
        init_policy_from_world_model_output_dir=world_model_dir,
        policy_and_value_two_towers=two_towers,
    )
    policy_trainer.training_loop(n_epochs=2)
def test_serialized_model_continuous(self):
  """Checks SerializedModel output shapes for a continuous Box observation."""
  precision = 3
  gin.bind_parameter('BoxSpaceSerializer.precision', precision)
  vocab_size = 32
  observations = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
  actions = np.array([[0, 1, 0]])
  mask = np.array([[1, 1, 1, 0]])
  obs_srl = space_serializer.create(
      gym.spaces.Box(shape=(2,), low=-2, high=2), vocab_size=vocab_size)
  act_srl = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  model = serialization_utils.SerializedModel(
      TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
      observation_serializer=obs_srl,
      action_serializer=act_srl,
      significance_decay=0.9,
  )
  batch = (observations, actions, observations, mask)
  model.init(shapes.signature(batch))
  obs_logits, obs_repr, weights = model(batch)

  _, n_steps, obs_dim = observations.shape
  # Logits carry one extra trailing axis of size vocab_size.
  self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size,))
  # Each Box coordinate expands into `precision` symbols.
  self.assertEqual(obs_repr.shape, (1, n_steps, obs_dim * precision))
  self.assertEqual(obs_repr.shape, weights.shape)
def test_serialized_model_extracts_seq_model_weights_and_state(self):
  """Checks that seq_model_weights/state fit a separately built seq model."""
  vocab_size = 3
  make_seq_model = functools.partial(
      transformer.TransformerLM,
      vocab_size=vocab_size,
      d_model=2,
      d_ff=2,
      n_layers=0,
  )
  eval_model = make_seq_model(mode='eval')
  obs_srl = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  act_srl = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  wrapped = serialization_utils.SerializedModel(
      make_seq_model,
      observation_serializer=obs_srl,
      action_serializer=act_srl,
      significance_decay=0.9,
  )
  obs_sig = shapes.ShapeDtype((1, 2))
  act_sig = shapes.ShapeDtype((1, 1))
  wrapped.init(input_signature=(obs_sig, act_sig, obs_sig, obs_sig))
  eval_model.weights = wrapped.seq_model_weights
  eval_model.state = wrapped.seq_model_state
  # Running the model verifies that the weights and state have the right
  # structure.
  eval_model(jnp.array([[0]]))
def test_serialized_model_discrete(self):
  """Checks SerializedModel end to end on MultiDiscrete observations."""
  vocab_size = 3
  obs = np.array([[[0, 1], [1, 1], [1, 0], [0, 0]]])
  act = np.array([[1, 0, 0]])
  mask = np.array([[1, 1, 1, 0]])
  recorded_inputs = []

  # pylint: disable=invalid-name
  def TestModelSavingInputs():

    def record_and_broadcast(inputs):
      # Remember what the model was called with, for the check below.
      recorded_inputs.append(inputs)
      # Cast to float32 and tile along a new trailing logit axis.
      as_float = inputs.astype(np.float32)[:, :, None]
      return jnp.broadcast_to(as_float, inputs.shape + (vocab_size,))

    return layers_base.Fn('TestModelSavingInputs', record_and_broadcast)
  # pylint: enable=invalid-name

  obs_srl = space_serializer.create(
      gym.spaces.MultiDiscrete([2, 2]), vocab_size=vocab_size)
  act_srl = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  model = serialization_utils.SerializedModel(
      TestModelSavingInputs(),  # pylint: disable=no-value-for-parameter
      observation_serializer=obs_srl,
      action_serializer=act_srl,
      significance_decay=0.9,
  )
  example = (obs, act, obs, mask)
  model.init(shapes.signature(example))
  obs_logits, obs_repr, weights = model(example)

  # The model is invoked several times (e.g. while determining shapes); the
  # last recorded input is the concrete array from the real forward pass. It
  # should hold serialized observations and actions interleaved.
  np.testing.assert_array_equal(
      recorded_inputs[-1], [[0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]])

  # Logits add one trailing axis of size vocab_size.
  self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size,))
  # The test model only broadcasts, so both the min and the max over the logit
  # axis must equal obs_repr.
  for reduce_fn in (np.min, np.max):
    np.testing.assert_array_equal(reduce_fn(obs_logits, axis=-1), obs_repr)
  # The observation representation matches the raw observations.
  np.testing.assert_array_equal(obs_repr, obs)
  # Weights are zero on the masked-out final timestep.
  np.testing.assert_array_equal(
      weights, [[[1., 1.], [1., 1.], [1., 1.], [0., 0.]]])
def __init__(self, model, reward_fn, done_fn, vocab_size,
             max_trajectory_length, observation_space, action_space,
             significance_decay=1.0, **kwargs):
  """Initializes the env.

  Args:
    model: trax model to use for simulation. It's assumed to take keyword
      arguments vocab_size and mode, where vocab_size is the number of symbols
      in the vocabulary and mode is either 'train' or 'eval'.
    reward_fn: Function (previous_observation, current_observation) -> reward.
    done_fn: Function (previous_observation, current_observation) -> done.
    vocab_size: (int) Number of symbols in the vocabulary.
    max_trajectory_length: (int) Maximum length of a trajectory unrolled from
      the model.
    observation_space: (gym.Space) Observation space.
    action_space: (gym.Space) Action space.
    significance_decay: (float) Decay for training weights of progressively
      less significant symbols in the representation.
    **kwargs: (dict) Keyword arguments passed to the base class.
  """
  self._reward_fn = reward_fn
  self._done_fn = done_fn
  self._vocab_size = vocab_size
  self._max_trajectory_length = max_trajectory_length
  self._significance_decay = significance_decay
  # Placeholders; this constructor leaves them as None.
  self._steps = None
  self._observation_space = None
  self._action_space = None
  self._last_observations = None

  self._obs_serializer = space_serializer.create(
      observation_space, self._vocab_size)
  self._action_serializer = space_serializer.create(
      action_space, self._vocab_size)
  obs_len = self._obs_serializer.representation_length
  act_len = self._action_serializer.representation_length
  self._obs_repr_length = obs_len
  self._act_repr_length = act_len
  # Symbols per environment step: one observation plus one action.
  self._step_repr_length = obs_len + act_len

  # The model factory is expected to accept vocab_size (e.g. TransformerLM),
  # so bind it before handing the factory to the base class.
  bound_model = functools.partial(model, vocab_size=vocab_size)
  super(SerializedSequenceSimulatedEnvProblem, self).__init__(
      model=bound_model,
      observation_space=observation_space,
      action_space=action_space,
      **kwargs)
def _init_serialization(self, vocab_size):
  """Builds serialization kwargs for the training environment's spaces.

  Args:
    vocab_size: Number of symbols in the serialization vocabulary.

  Returns:
    A dict with "observation_serializer", "action_serializer" and
    "representation_length"; the length covers self._max_timestep + 1
    timesteps of observation plus action symbols.
  """
  observation_srl = space_serializer.create(
      self.train_env.observation_space, vocab_size=vocab_size)
  action_srl = space_serializer.create(
      self.train_env.action_space, vocab_size=vocab_size)
  symbols_per_step = (observation_srl.representation_length +
                      action_srl.representation_length)
  n_timesteps = self._max_timestep + 1
  return {
      "observation_serializer": observation_srl,
      "action_serializer": action_srl,
      "representation_length": symbols_per_step * n_timesteps,
  }
def init_serialization(vocab_size, observation_space, action_space,
                       n_timesteps):
  """Initializes serialization keyword arguments.

  Args:
    vocab_size: Number of symbols in the serialization vocabulary.
    observation_space: Space to build the observation serializer for.
    action_space: Space to build the action serializer for.
    n_timesteps: Number of (observation, action) steps the representation
      spans.

  Returns:
    A dict with 'observation_serializer', 'action_serializer' and
    'representation_length'.
  """
  kwargs = {
      'observation_serializer': space_serializer.create(
          observation_space, vocab_size=vocab_size),
      'action_serializer': space_serializer.create(
          action_space, vocab_size=vocab_size),
  }
  # Symbols in a single step: observation representation + action one.
  per_step = sum(srl.representation_length for srl in kwargs.values())
  kwargs['representation_length'] = per_step * n_timesteps
  return kwargs
def test_significance_map(self):
  """Checks the significance map for Box observations and MultiDiscrete acts."""
  gin.bind_parameter('BoxSpaceSerializer.precision', 3)
  box_srl = space_serializer.create(
      gym.spaces.Box(low=0, high=1, shape=(2,)), vocab_size=2)
  multi_srl = space_serializer.create(
      gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2)
  sig_map = serialization_utils.significance_map(
      observation_serializer=box_srl,
      action_serializer=multi_srl,
      representation_length=20,
  )
  # Layout: obs1, act1, obs2, act2, then obs3 cut after its 4th symbol.
  expected = [0, 1, 2, 0, 1, 2, 0, 0, 0, 1, 2, 0, 1, 2, 0, 0, 0, 1, 2, 0]
  np.testing.assert_array_equal(sig_map, expected)
def test_rewards_to_actions_map(self):
  """Checks that each step's reward is spread over that step's action symbols.

  With a 3-symbol observation and a 2-symbol action, one step spans 5
  symbols, so a representation of length 16 holds three full steps plus one
  trailing observation symbol.
  """
  rewards = onp.array([1, 2, 3])
  obs_srl = space_serializer.create(
      gym.spaces.MultiDiscrete(nvec=[2, 2, 2]), vocab_size=2)
  act_srl = space_serializer.create(
      gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2)
  r2a_map = serialization_utils.rewards_to_actions_map(
      observation_serializer=obs_srl,
      action_serializer=act_srl,
      n_timesteps=len(rewards),
      representation_length=16,
  )
  broadcast_rewards = onp.dot(rewards, r2a_map)
  onp.testing.assert_array_equal(
      broadcast_rewards,
      # obs1, act1, obs2, act2, obs3, act3, then obs4 cut after its 1st symbol.
      [0, 0, 0, 1, 1, 0, 0, 0, 2, 2, 0, 0, 0, 3, 3, 0],
  )
def make_serialized_model(seq_model, space, vocab_size):
  """Wraps seq_model in a SerializedModel sharing one serializer for obs/acts.

  Args:
    seq_model: Sequence model constructor; it is given vocab_size by keyword.
    space: Space used for both observations and actions.
    vocab_size: Number of symbols in the serialization vocabulary.

  Returns:
    A SerializedModel with significance_decay=0.7.
  """
  shared_serializer = space_serializer.create(space, vocab_size)
  seq_model_fn = functools.partial(seq_model, vocab_size=vocab_size)
  return serialization_utils.SerializedModel(
      seq_model_fn,
      observation_serializer=shared_serializer,
      action_serializer=shared_serializer,
      significance_decay=0.7,
  )
def setUp(self):
  """Prepares a binary Discrete serializer and shared kwargs for the tests."""
  super(SerializationTest, self).setUp()
  serializer = space_serializer.create(gym.spaces.Discrete(2), vocab_size=2)
  self._serializer = serializer
  self._repr_length = 100
  self._serialization_utils_kwargs = dict(
      observation_serializer=serializer,
      action_serializer=serializer,
      representation_length=self._repr_length,
  )
def setUp(self):
  """Prepares a binary Discrete serializer, shared kwargs and the tmpdir flag."""
  super().setUp()
  serializer = space_serializer.create(gym.spaces.Discrete(2), vocab_size=2)
  self._serializer = serializer
  self._repr_length = 100
  self._serialization_utils_kwargs = dict(
      observation_serializer=serializer,
      action_serializer=serializer,
      representation_length=self._repr_length,
  )
  test_utils.ensure_flag('test_tmpdir')
def wrap_policy(seq_model, observation_space, action_space, vocab_size):  # pylint: disable=invalid-name
  """Wraps a sequence model in either RawPolicy or SerializedPolicy.

  Args:
    seq_model: Trax sequence model.
    observation_space: Gym observation space.
    action_space: Gym action space.
    vocab_size: Either the number of symbols for a serialized policy, or None.

  Returns:
    RawPolicy if vocab_size is None, else SerializedPolicy.
  """
  n_controls, n_actions = analyze_action_space(action_space)
  if vocab_size is None:
    return RawPolicy(seq_model, n_controls, n_actions)
  obs_srl = space_serializer.create(observation_space, vocab_size)
  act_srl = space_serializer.create(action_space, vocab_size)
  return SerializedPolicy(
      seq_model,
      n_controls,
      n_actions,
      observation_serializer=obs_srl,
      action_serializer=act_srl,
  )
def _make_space_and_serializer(
    self,
    low=-10,
    high=10,
    shape=(2,),
    # Weird vocab_size to test that it doesn't only work with powers of 2.
    vocab_size=257,
    # Enough precision to represent float32s accurately.
    precision=4,
):
  """Binds the Box serializer precision and builds a Box space + serializer.

  Returns:
    A (space, serializer) tuple.
  """
  gin.bind_parameter("BoxSpaceSerializer.precision", precision)
  box = gym.spaces.Box(low=low, high=high, shape=shape)
  return (box, space_serializer.create(box, vocab_size=vocab_size))
def test_serialized_model_continuous(self):
  """Checks SerializedModel output shapes for a continuous Box observation."""
  precision = 3
  gin.bind_parameter('BoxSpaceSerializer.precision', precision)
  vocab_size = 32
  obs = onp.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
  act = onp.array([[0, 1, 0]])
  mask = onp.array([[1, 1, 1, 0]])

  @layers_base.layer()
  def TestModel(inputs, **unused_kwargs):
    # Cast to onp.float32 and add a trailing logit axis of size vocab_size.
    with_logit_axis = inputs.astype(onp.float32)[:, :, None]
    return np.broadcast_to(with_logit_axis, inputs.shape + (vocab_size,))

  obs_srl = space_serializer.create(
      gym.spaces.Box(shape=(2,), low=-2, high=2), vocab_size=vocab_size)
  act_srl = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  model = serialization_utils.SerializedModel(
      TestModel(),  # pylint: disable=no-value-for-parameter
      observation_serializer=obs_srl,
      action_serializer=act_srl,
      significance_decay=0.9,
  )
  example = (obs, act, obs, mask)
  model.init(shapes.signature(example))
  obs_logits, obs_repr, weights = model(example)

  _, n_steps, obs_dim = obs.shape
  self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size,))
  # Each Box coordinate expands into `precision` symbols.
  self.assertEqual(obs_repr.shape, (1, n_steps, obs_dim * precision))
  self.assertEqual(obs_repr.shape, weights.shape)
def setUp(self):
  """Builds a 2x2 MultiDiscrete space and its binary serializer."""
  super(MultiDiscreteSpaceSerializerTest, self).setUp()
  space = gym.spaces.MultiDiscrete(nvec=[2, 2])
  self._space = space
  self._serializer = space_serializer.create(space, vocab_size=2)
def setUp(self):
  """Builds a two-valued Discrete space and its binary serializer."""
  super(DiscreteSpaceSerializerTest, self).setUp()
  space = gym.spaces.Discrete(n=2)
  self._space = space
  self._serializer = space_serializer.create(space, vocab_size=2)