def initialize_environments(self,
                              history_stream,
                              batch_size=1,
                              parallelism=1):
    """Initializes the environments.

    Args:
      history_stream: Iterator yielding batches of initial input data for the
        model. The format is implementation-specific.
      batch_size: (int) Number of environments in a batch.
      parallelism: (int) Unused.
    """
    del parallelism

    trax_state = trax.restore_state(self._output_dir)
    # TODO(lukaszkaiser): both model state and parameters by default include
    # the loss layer. Currently, we access the pure-model parameters by just
    # indexing, [0] here. But we should make it more explicit in a better API.
    model_params = trax_state.opt_state.params[0]
    self._model_state = trax_state.model_state[0]

    def predict_fn(inputs, rng):
      (output, self._model_state) = self._model_predict(
          inputs, params=model_params, state=self._model_state, rng=rng
      )
      return output

    self._predict_fn = predict_fn
    self._history_stream = history_stream
    self._steps = np.zeros(batch_size, dtype=np.int32)
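
The format of history_stream is left implementation-specific above. As a rough, hypothetical sketch (the array shape and dtype below are assumptions, not a requirement of any particular environment class), the stream can be an ordinary Python generator that yields one batch of initial model inputs per reset:

import numpy as np

def make_history_stream(batch_size=2, history_length=4, obs_dim=3):
    # Yields one batch of initial model inputs per reset; the
    # (batch, history, observation) layout here is illustrative only.
    while True:
        yield np.zeros((batch_size, history_length, obs_dim), dtype=np.float32)

# Hypothetical call site, assuming env has the method defined above:
# env.initialize_environments(history_stream=make_history_stream(), batch_size=2)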
Example #2
    def initialize_environments(self,
                                history_stream,
                                batch_size=1,
                                parallelism=1):
        """Initializes the environments.

        Args:
          history_stream: Iterator yielding batches of initial input data for
            the model. The format is implementation-specific.
          batch_size: (int) Number of environments in a batch.
          parallelism: (int) Unused.
        """
        del parallelism

        trax_state = trax.restore_state(self._output_dir)
        model_params = trax_state.opt_state.params

        # For initializing model state and resetting it.
        self._model_state_override = trax_state.model_state

        def predict_fn(*args, **kwargs):
            kwargs["params"] = model_params
            if self._model_state_override is not None:
                kwargs["state"] = self._model_state_override
            return self._model_predict(*args, **kwargs)

        self._predict_fn = predict_fn
        self._history_stream = history_stream

        self._steps = np.zeros(batch_size, dtype=np.int32)
Example #3
    def initialize_environments(self,
                                history_stream,
                                batch_size=1,
                                parallelism=1):
        """Initializes the environments.

        Args:
          history_stream: Iterator yielding batches of initial input data for
            the model. The format is implementation-specific.
          batch_size: (int) Number of environments in a batch.
          parallelism: (int) Unused.
        """
        del parallelism

        trax_state = trax.restore_state(self._output_dir)
        model_params = trax_state.opt_state.params
        self._model_state = trax_state.model_state

        def predict_fn(inputs, rng):
            (output,
             self._model_state) = self._model_predict(inputs,
                                                      params=model_params,
                                                      state=self._model_state,
                                                      rng=rng)
            return output

        self._predict_fn = predict_fn
        self._history_stream = history_stream

        self._steps = np.zeros(batch_size, dtype=np.int32)
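
The two variants above differ in how model state is handled: the first overrides the state passed into every call, while the second threads the state returned by the model back into the next call through the closure. A toy, self-contained sketch of that threading pattern (not the trax API; all names below are made up):

class StatefulPredictor(object):
    # The wrapped function returns (output, new_state); the wrapper stores
    # new_state so the next call continues from it, mirroring predict_fn above.
    def __init__(self, model_predict, params, state):
        self._model_predict = model_predict
        self._params = params
        self._state = state

    def __call__(self, inputs, rng=None):
        (output, self._state) = self._model_predict(
            inputs, params=self._params, state=self._state, rng=rng)
        return output

def fake_model_predict(inputs, params, state, rng=None):
    # Returns the new state instead of mutating the old one.
    return inputs + params, state + 1

predict_fn = StatefulPredictor(fake_model_predict, params=1.0, state=0)
predict_fn(2.0)  # -> 3.0, internal state becomes 1
predict_fn(2.0)  # -> 3.0, internal state becomes 2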
Example #4
    def initialize_environments(self,
                                initial_observation_stream,
                                batch_size=1,
                                parallelism=1):
        """Initializes the environments.

        Args:
          initial_observation_stream: Iterator yielding batches of initial
            observations for the model.
          batch_size: (int) Number of environments in a batch.
          parallelism: (int) Unused.
        """
        del parallelism

        model_state = trax.restore_state(self._output_dir)
        self._model_params = model_state.opt_state.params
        self._initial_observation_stream = initial_observation_stream

        self._history = None
        self._steps = np.zeros(batch_size)
Example #5
  def test_inits_policy_by_world_model_checkpoint(self):
    transformer_kwargs = {
        "d_model": 1,
        "d_ff": 1,
        "n_layers": 1,
        "n_heads": 1,
        "max_len": 128,
        "mode": "train",
    }
    rng = jax_random.PRNGKey(123)
    init_kwargs = {
        "input_shapes": (1, 1),
        "input_dtype": np.int32,
        "rng": rng,
    }
    model = models.TransformerLM(vocab_size=4, **transformer_kwargs)
    (model_params, _) = model.initialize_once(**init_kwargs)
    policy = ppo.policy_and_value_net(
        n_actions=3,
        n_controls=2,
        vocab_size=4,
        bottom_layers_fn=functools.partial(
            models.TransformerDecoder, **transformer_kwargs
        ),
        two_towers=False,
    )
    (policy_params, policy_state) = policy.initialize_once(**init_kwargs)
    output_dir = self.get_temp_dir()
    # Initialize state by restoring from a nonexistent checkpoint.
    trax_state = trax.restore_state(output_dir)
    trax_state = trax_state._replace(opt_state=(model_params, None))
    # Save world model parameters.
    trax.save_state(trax_state, output_dir)
    # Initialize policy parameters from world model parameters.
    new_policy_params = ppo.init_policy_from_world_model_checkpoint(
        policy_params, output_dir
    )
    # Try to run the policy with new parameters.
    observations = np.zeros((1, 100), dtype=np.int32)
    policy(observations, params=new_policy_params, state=policy_state, rng=rng)
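
The test leans on trax_state behaving like a namedtuple: _replace returns a copy with the chosen fields swapped, which is how the world-model parameters are slotted into opt_state before saving. A minimal illustration of that idiom with a made-up two-field state (not the real trax state object):

import collections

FakeState = collections.namedtuple("FakeState", ["opt_state", "model_state"])

state = FakeState(opt_state=None, model_state=None)
# _replace returns a new tuple; the original is left untouched.
state = state._replace(opt_state=("model_params", None))
print(state.opt_state)  # ('model_params', None)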
Example #6
    def initialize_environments(self,
                                history_stream,
                                batch_size=1,
                                parallelism=1):
        """Initializes the environments.

        Args:
          history_stream: Iterator yielding batches of initial input data for
            the model. The format is implementation-specific.
          batch_size: (int) Number of environments in a batch.
          parallelism: (int) Unused.
        """
        del parallelism

        model_state = trax.restore_state(self._output_dir)
        model_params = model_state.opt_state.params
        self._predict_fn = functools.partial(
            self._model_predict,
            params=model_params,
        )
        self._history_stream = history_stream

        self._steps = np.zeros(batch_size, dtype=np.int32)
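
The example above binds the restored parameters once with functools.partial, so later callers of self._predict_fn only pass the remaining arguments. A minimal, generic illustration of that binding (the model function here is made up):

import functools

def model_predict(inputs, params, rng=None):
    # Stand-in model: scales each input by the bound parameter.
    return [params * x for x in inputs]

predict_fn = functools.partial(model_predict, params=2)
print(predict_fn([1, 2, 3]))  # [2, 4, 6]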