def model(self):
    """Return a factory that builds a `SerializedModel` for a given mode.

    The returned callable accepts `mode`, builds the underlying sequence
    model with `self._model(mode=mode)` and wraps it using this object's
    observation/action serializers and significance decay.
    """

    def build_serialized_model(mode):
        return serialization_utils.SerializedModel(
            seq_model=self._model(mode=mode),
            observation_serializer=self._obs_serializer,
            action_serializer=self._action_serializer,
            significance_decay=self._significance_decay,
        )

    return build_serialized_model
def test_serialized_model_continuous(self):
    """Check output shapes of `SerializedModel` for a Box observation space.

    The test expects the serialized observation representation to hold
    `precision` entries per observation dimension, the logits to add a
    trailing `vocab_size` axis, and the weights to match the representation
    shape.
    """
    precision = 3
    # Configure the Box serializer precision through gin.
    gin.bind_parameter('BoxSpaceSerializer.precision', precision)
    vocab_size = 32

    observations = np.array(
        [[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
    actions = np.array([[0, 1, 0]])
    mask = np.array([[1, 1, 1, 0]])

    observation_space = gym.spaces.Box(shape=(2, ), low=-2, high=2)
    obs_serializer = space_serializer.create(
        observation_space, vocab_size=vocab_size)
    action_space = gym.spaces.Discrete(2)
    act_serializer = space_serializer.create(
        action_space, vocab_size=vocab_size)

    serialized_model = serialization_utils.SerializedModel(
        TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
        observation_serializer=obs_serializer,
        action_serializer=act_serializer,
        significance_decay=0.9,
    )

    example = (observations, actions, observations, mask)
    serialized_model.init(shapes.signature(example))
    obs_logits, obs_repr, weights = serialized_model(example)

    expected_repr_shape = (
        1, observations.shape[1], observations.shape[2] * precision)
    self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size, ))
    self.assertEqual(obs_repr.shape, expected_repr_shape)
    self.assertEqual(obs_repr.shape, weights.shape)
def test_extract_inner_model(self):
    """Check that extracted weights and state can drive the inner model.

    A `TransformerLM` is wrapped in a `SerializedModel` and initialized;
    `extract_inner_model` is then applied to the resulting weights and state,
    and the bare inner model is called with the extracted values.
    """
    vocab_size = 3
    inner_model = transformer.TransformerLM(
        vocab_size=vocab_size, d_model=2, d_ff=2, n_layers=0)

    discrete_obs_space = gym.spaces.Discrete(2)
    obs_serializer = space_serializer.create(
        discrete_obs_space, vocab_size=vocab_size)
    discrete_act_space = gym.spaces.Discrete(2)
    act_serializer = space_serializer.create(
        discrete_act_space, vocab_size=vocab_size)

    serialized_model = serialization_utils.SerializedModel(
        inner_model,
        observation_serializer=obs_serializer,
        action_serializer=act_serializer,
        significance_decay=0.9,
    )

    obs_sig = shapes.ShapeDtype((1, 2))
    act_sig = shapes.ShapeDtype((1, 1))
    signature = (obs_sig, act_sig, obs_sig, obs_sig)
    weights, state = serialized_model.init(input_signature=signature)

    # Extract the inner model's part from the weights first, then the state.
    inner_weights = serialization_utils.extract_inner_model(weights)
    inner_state = serialization_utils.extract_inner_model(state)

    # Calling the inner model verifies the extracted structures are usable.
    inner_model(np.array([[0]]), weights=inner_weights, state=inner_state)
def model(mode):
    """Wrap `inner_model(mode)` in a `SerializedModel`.

    The same serializer `srl` is used for both observations and actions.
    """
    seq_model = inner_model(mode)
    return serialization_utils.SerializedModel(
        seq_model,
        observation_serializer=srl,
        action_serializer=srl,
        significance_decay=0.7,
    )
def test_training_loop_cartpole_serialized_init_from_world_model(
        self, two_towers):
    """Run a policy training loop initialized from a world-model checkpoint.

    A serialized `TransformerLM` world model is first trained for one step
    into `<output_dir>/model`. A trainer is then created with
    `init_policy_from_world_model_output_dir` pointing at that directory and
    its training loop is run for two epochs.

    Args:
        two_towers: Forwarded to the trainer as `policy_and_value_two_towers`.
    """
    gin.bind_parameter('BoxSpaceSerializer.precision', 1)

    transformer_kwargs = dict(
        d_model=1,
        d_ff=1,
        n_layers=1,
        n_heads=1,
        max_len=128,
    )
    obs_serializer = space_serializer.create(
        gym.spaces.MultiDiscrete([2, 2]), vocab_size=4)
    act_serializer = space_serializer.create(
        gym.spaces.Discrete(2), vocab_size=4)

    def model_fn(mode):
        # World model: a TransformerLM wrapped with the serializers above.
        return serialization_utils.SerializedModel(
            seq_model=models.TransformerLM(
                mode=mode, vocab_size=4, **transformer_kwargs),
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

    with self.tmp_dir() as output_dir:
        model_dir = os.path.join(output_dir, 'model')

        def dummy_stream(_):
            # Endless stream of all-zero (obs, act, obs, mask) batches with
            # an all-ones mask; fresh arrays are built for every batch.
            while True:
                obs = np.zeros((1, 2, 2), dtype=np.int32)
                act = np.zeros((1, 1), dtype=np.int32)
                mask = np.ones_like(obs)
                batch = (obs, act, obs, mask)
                yield batch

        inputs = trax_inputs.Inputs(
            train_stream=dummy_stream,
            eval_stream=dummy_stream,
        )
        inputs._input_shape = ((2, 2), (1, ))  # pylint: disable=protected-access
        inputs._input_dtype = (np.int32, np.int32)  # pylint: disable=protected-access

        # Running the trainer for one step produces the world model
        # checkpoint that the policy is initialized from below.
        trainer_lib.train(
            model_dir,
            model=model_fn,
            inputs=inputs,
            steps=1,
            eval_steps=1,
            has_weights=True,
        )

        train_env = self.get_wrapped_env('CartPole-v0', 2)
        eval_env = self.get_wrapped_env('CartPole-v0', 2)
        policy_dir = os.path.join(output_dir, 'policy')
        policy_model = functools.partial(
            models.TransformerDecoder, **transformer_kwargs)
        trainer = self._make_trainer(
            train_env=train_env,
            eval_env=eval_env,
            output_dir=policy_dir,
            model=policy_model,
            policy_and_value_vocab_size=4,
            init_policy_from_world_model_output_dir=model_dir,
            policy_and_value_two_towers=two_towers,
        )
        trainer.training_loop(n_epochs=2)
def test_serialized_model_extracts_seq_model_weights_and_state(self):
    """Check `seq_model_weights` / `seq_model_state` fit a bare seq model.

    A `SerializedModel` is built from a `TransformerLM` factory and
    initialized. Its `seq_model_weights` and `seq_model_state` are then
    assigned to a separately constructed `TransformerLM`, which is run once.
    """
    vocab_size = 3
    seq_model_fn = functools.partial(
        transformer.TransformerLM,
        vocab_size=vocab_size,
        d_model=2,
        d_ff=2,
        n_layers=0,
    )
    standalone_model = seq_model_fn(mode='eval')

    obs_serializer = space_serializer.create(
        gym.spaces.Discrete(2), vocab_size=vocab_size)
    act_serializer = space_serializer.create(
        gym.spaces.Discrete(2), vocab_size=vocab_size)
    serialized_model = serialization_utils.SerializedModel(
        seq_model_fn,
        observation_serializer=obs_serializer,
        action_serializer=act_serializer,
        significance_decay=0.9,
    )

    obs_sig = shapes.ShapeDtype((1, 2))
    act_sig = shapes.ShapeDtype((1, 1))
    signature = (obs_sig, act_sig, obs_sig, obs_sig)
    serialized_model.init(input_signature=signature)

    standalone_model.weights = serialized_model.seq_model_weights
    standalone_model.state = serialized_model.seq_model_state
    # Running the model checks that the weights and state are structured
    # the way the sequence model expects.
    standalone_model(jnp.array([[0]]))
def test_serialized_model_discrete(self):
    """Check `SerializedModel` inputs and outputs for discrete spaces.

    A stand-in sequence model records every input it receives and echoes it
    back, broadcast over a trailing `vocab_size` axis. The test then checks
    the input passed to that model, the output shapes, that the observation
    representation equals the original observations, and the weights.
    """
    vocab_size = 3
    observations = np.array([[[0, 1], [1, 1], [1, 0], [0, 0]]])
    actions = np.array([[1, 0, 0]])
    mask = np.array([[1, 1, 1, 0]])

    test_model_inputs = []

    # pylint: disable=invalid-name
    def TestModelSavingInputs():
        def f(inputs):
            # Record the inputs so they can be inspected after the call.
            test_model_inputs.append(inputs)
            # Cast to np.float32 and append the logit dimension.
            as_float = inputs.astype(np.float32)[:, :, None]
            return jnp.broadcast_to(as_float, inputs.shape + (vocab_size, ))

        return layers_base.Fn('TestModelSavingInputs', f)
    # pylint: enable=invalid-name

    obs_serializer = space_serializer.create(
        gym.spaces.MultiDiscrete([2, 2]), vocab_size=vocab_size)
    act_serializer = space_serializer.create(
        gym.spaces.Discrete(2), vocab_size=vocab_size)
    serialized_model = serialization_utils.SerializedModel(
        TestModelSavingInputs(),  # pylint: disable=no-value-for-parameter
        observation_serializer=obs_serializer,
        action_serializer=act_serializer,
        significance_decay=0.9,
    )

    example = (observations, actions, observations, mask)
    serialized_model.init(shapes.signature(example))
    obs_logits, obs_repr, weights = serialized_model(example)

    # The model is called several times (e.g. for determining shapes), so
    # only the last recorded input is inspected - that should be the concrete
    # array from the forward pass: serialized observations and actions
    # interleaved.
    expected_model_input = [[0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]]
    np.testing.assert_array_equal(
        test_model_inputs[-1], expected_model_input)

    # Logits carry one extra trailing axis of size vocab_size.
    self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size, ))

    # The logits are obs_repr broadcast over the logit axis, so both the
    # minimum and the maximum over that axis must equal obs_repr.
    np.testing.assert_array_equal(np.min(obs_logits, axis=-1), obs_repr)
    np.testing.assert_array_equal(np.max(obs_logits, axis=-1), obs_repr)

    # The representation must reproduce the original observations.
    np.testing.assert_array_equal(obs_repr, observations)

    # Weights are one on the first three timesteps and zero on the last.
    expected_weights = [[[1., 1.], [1., 1.], [1., 1.], [0., 0.]]]
    np.testing.assert_array_equal(weights, expected_weights)
def make_serialized_model(seq_model, space, vocab_size):
    """Build a `SerializedModel` around `seq_model` for the given space.

    A single serializer created from `space` and `vocab_size` is used for
    both observations and actions; `seq_model` is passed on with
    `vocab_size` pre-bound.
    """
    serializer = space_serializer.create(space, vocab_size)
    seq_model_fn = functools.partial(seq_model, vocab_size=vocab_size)
    return serialization_utils.SerializedModel(
        seq_model_fn,
        observation_serializer=serializer,
        action_serializer=serializer,
        significance_decay=0.7,
    )
def model(mode):
    """Build a `SerializedModel` around a small `TransformerLM` for `mode`.

    The same serializer `srl` is used for both observations and actions.
    """
    transformer_lm = trax_models.TransformerLM(
        mode=mode,
        vocab_size=vocab_size,
        d_model=16,
        d_ff=8,
        n_layers=1,
        n_heads=1,
    )
    return serialization_utils.SerializedModel(
        transformer_lm,
        observation_serializer=srl,
        action_serializer=srl,
        significance_decay=0.9,
    )
def test_serialized_model_continuous(self):
    """Check output shapes of `SerializedModel` for a Box observation space.

    A stand-in layer echoes its input, broadcast over a trailing
    `vocab_size` axis. The test expects the observation representation to
    hold `precision` entries per observation dimension, the logits to add a
    `vocab_size` axis, and the weights to match the representation shape.
    """
    precision = 3
    # Configure the Box serializer precision through gin.
    gin.bind_parameter('BoxSpaceSerializer.precision', precision)
    vocab_size = 32

    observations = onp.array(
        [[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
    actions = onp.array([[0, 1, 0]])
    mask = onp.array([[1, 1, 1, 0]])

    @layers_base.layer()
    def TestModel(inputs, **unused_kwargs):
        # Cast to onp.float32 and append the logit dimension.
        as_float = inputs.astype(onp.float32)[:, :, None]
        return np.broadcast_to(as_float, inputs.shape + (vocab_size, ))

    observation_space = gym.spaces.Box(shape=(2, ), low=-2, high=2)
    obs_serializer = space_serializer.create(
        observation_space, vocab_size=vocab_size)
    action_space = gym.spaces.Discrete(2)
    act_serializer = space_serializer.create(
        action_space, vocab_size=vocab_size)

    serialized_model = serialization_utils.SerializedModel(
        TestModel(),  # pylint: disable=no-value-for-parameter
        observation_serializer=obs_serializer,
        action_serializer=act_serializer,
        significance_decay=0.9,
    )

    example = (observations, actions, observations, mask)
    serialized_model.init(shapes.signature(example))
    obs_logits, obs_repr, weights = serialized_model(example)

    expected_repr_shape = (
        1, observations.shape[1], observations.shape[2] * precision)
    self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size, ))
    self.assertEqual(obs_repr.shape, expected_repr_shape)
    self.assertEqual(obs_repr.shape, weights.shape)