def initialize_environments(self, history_stream, batch_size=1, parallelism=1):
    """Initializes the environments.

    Args:
      history_stream: Iterator yielding batches of initial input data for the
        model. The format is implementation-specific.
      batch_size: (int) Number of environments in a batch.
      parallelism: (int) Unused.
    """
    del parallelism
    restored = trax.restore_state(self._output_dir)
    # TODO(lukaszkaiser): both model state and parameters by default include
    # the loss layer. Currently, we access the pure-model parameters by just
    # indexing, [0] here. But we should make it more explicit in a better API.
    params = restored.opt_state.params[0]
    self._model_state = restored.model_state[0]

    def predict_fn(inputs, rng):
        # Run the model and keep its updated state for the next prediction.
        output, self._model_state = self._model_predict(
            inputs, params=params, state=self._model_state, rng=rng
        )
        return output

    self._predict_fn = predict_fn
    self._history_stream = history_stream
    # Per-environment step counters.
    self._steps = np.zeros(batch_size, dtype=np.int32)
def initialize_environments(self, history_stream, batch_size=1, parallelism=1):
    """Initializes the environments.

    Args:
      history_stream: Iterator yielding batches of initial input data for the
        model. The format is implementation-specific.
      batch_size: (int) Number of environments in a batch.
      parallelism: (int) Unused.
    """
    del parallelism
    checkpoint = trax.restore_state(self._output_dir)
    params = checkpoint.opt_state.params
    # For initializing model state and resetting it.
    self._model_state_override = checkpoint.model_state

    def predict_fn(*args, **kwargs):
        # Always inject the restored parameters; only override the model
        # state when an override is present.
        kwargs["params"] = params
        override = self._model_state_override
        if override is not None:
            kwargs["state"] = override
        return self._model_predict(*args, **kwargs)

    self._predict_fn = predict_fn
    self._history_stream = history_stream
    # Per-environment step counters.
    self._steps = np.zeros(batch_size, dtype=np.int32)
def initialize_environments(self, history_stream, batch_size=1, parallelism=1):
    """Initializes the environments.

    Args:
      history_stream: Iterator yielding batches of initial input data for the
        model. The format is implementation-specific.
      batch_size: (int) Number of environments in a batch.
      parallelism: (int) Unused.
    """
    del parallelism
    restored = trax.restore_state(self._output_dir)
    params = restored.opt_state.params
    self._model_state = restored.model_state

    def predict_fn(inputs, rng):
        # Run the model, threading the persisted state through each call.
        output, self._model_state = self._model_predict(
            inputs, params=params, state=self._model_state, rng=rng
        )
        return output

    self._predict_fn = predict_fn
    self._history_stream = history_stream
    # Per-environment step counters.
    self._steps = np.zeros(batch_size, dtype=np.int32)
def initialize_environments(self, initial_observation_stream, batch_size=1, parallelism=1):
    """Initializes the environments.

    Args:
      initial_observation_stream: Iterator yielding batches of initial
        observations for the model.
      batch_size: (int) Number of environments in a batch.
      parallelism: (int) Unused.
    """
    del parallelism
    model_state = trax.restore_state(self._output_dir)
    self._model_params = model_state.opt_state.params
    self._initial_observation_stream = initial_observation_stream
    # No history accumulated yet.
    self._history = None
    # Per-environment step counters. Fix: the original omitted dtype, so
    # np.zeros produced float64 counters — inconsistent with every sibling
    # initialize_environments variant, which uses int32.
    self._steps = np.zeros(batch_size, dtype=np.int32)
def test_inits_policy_by_world_model_checkpoint(self):
    """Policy params initialized from a world-model checkpoint should run."""
    transformer_kwargs = {
        "d_model": 1,
        "d_ff": 1,
        "n_layers": 1,
        "n_heads": 1,
        "max_len": 128,
        "mode": "train",
    }
    rng = jax_random.PRNGKey(123)
    init_kwargs = {
        "input_shapes": (1, 1),
        "input_dtype": np.int32,
        "rng": rng,
    }
    world_model = models.TransformerLM(vocab_size=4, **transformer_kwargs)
    model_params, _ = world_model.initialize_once(**init_kwargs)
    policy = ppo.policy_and_value_net(
        n_actions=3,
        n_controls=2,
        vocab_size=4,
        bottom_layers_fn=functools.partial(
            models.TransformerDecoder, **transformer_kwargs
        ),
        two_towers=False,
    )
    policy_params, policy_state = policy.initialize_once(**init_kwargs)

    output_dir = self.get_temp_dir()
    # Initialize state by restoring from a nonexistent checkpoint.
    trax_state = trax.restore_state(output_dir)
    trax_state = trax_state._replace(opt_state=(model_params, None))
    # Save world model parameters.
    trax.save_state(trax_state, output_dir)

    # Initialize policy parameters from world model parameters.
    new_policy_params = ppo.init_policy_from_world_model_checkpoint(
        policy_params, output_dir
    )
    # Try to run the policy with new parameters.
    observations = np.zeros((1, 100), dtype=np.int32)
    policy(observations, params=new_policy_params, state=policy_state, rng=rng)
def initialize_environments(self, history_stream, batch_size=1, parallelism=1):
    """Initializes the environments.

    Args:
      history_stream: Iterator yielding batches of initial input data for the
        model. The format is implementation-specific.
      batch_size: (int) Number of environments in a batch.
      parallelism: (int) Unused.
    """
    del parallelism
    restored = trax.restore_state(self._output_dir)
    # Bind the restored parameters once; every prediction reuses them.
    self._predict_fn = functools.partial(
        self._model_predict,
        params=restored.opt_state.params,
    )
    self._history_stream = history_stream
    # Per-environment step counters.
    self._steps = np.zeros(batch_size, dtype=np.int32)