Code example #1
0
 def test_convert_checkpoint(self):
     """End-to-end check: a pre-Linen train state converts into a usable Linen state.

     Builds a train state with the deprecated flax.nn ("pre-Linen") API,
     serializes it to a state dict, converts that dict with
     checkpoints.convert_pre_linen, and verifies the converted optimizer
     params and batch stats can drive a Linen model.apply without error.
     """
     inputs = jnp.ones([2, 5, 5, 1])  # dummy batch; presumably NHWC images -- confirm against Model
     rng = jax.random.PRNGKey(0)
     # pre-Linen.
     # Legacy flax.nn collects mutable collections (e.g. batch stats) via a
     # stateful() context rather than returning them from init.
     with flax.nn.stateful() as model_state:
         y, params = ModelPreLinen.init(rng, inputs)
     pre_linen_optimizer = flax.optim.GradientDescent(0.1).create(params)
     train_state = TrainState(optimizer=pre_linen_optimizer,
                              model_state=model_state)
     # Serialize exactly the way a legacy checkpoint would have been stored.
     state_dict = flax.serialization.to_state_dict(train_state)
     # Linen.
     model = Model()
     variables = model.init(rng, inputs)
     optimizer = flax.optim.GradientDescent(0.1).create(variables['params'])
     # Restore the Linen optimizer from the converted pre-Linen state dict;
     # unfreeze because restore_state expects a plain (mutable) pytree.
     optimizer = optimizer.restore_state(
         flax.core.unfreeze(
             checkpoints.convert_pre_linen(state_dict['optimizer'])))
     optimizer = optimizer.apply_gradient(variables['params'])
     # Convert batch stats separately: strip the leading '/' path component
     # from the legacy keys, rebuild the nested dict, then rename modules.
     batch_stats = checkpoints.convert_pre_linen(
         flax.traverse_util.unflatten_dict({
             tuple(k.split('/')[1:]): v
             for k, v in model_state.as_dict().items()
         }))
     # Smoke test: the converted variables must be accepted by model.apply.
     y, updated_state = model.apply(dict(params=optimizer.target,
                                         batch_stats=batch_stats),
                                    inputs,
                                    mutable=['batch_stats'])
     del y, updated_state  # not used.
Code example #2
0
    def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
        """Restores the agent from a checkpoint.

        Restores the agent's Python objects from bundle_dictionary and its
        TensorFlow objects from checkpoint_dir. If checkpoint_dir does not
        exist, the agent's state is not reset.

        Args:
          checkpoint_dir: str, path to the checkpoint saved.
          iteration_number: int, checkpoint version, used when restoring the
            replay buffer.
          bundle_dictionary: dict, containing additional Python objects owned
            by the agent.

        Returns:
          bool, True if unbundling was successful.
        """
        try:
            # load() raises NotFoundError when any replay file is missing.
            self._replay.load(checkpoint_dir, iteration_number)
        except tf.errors.NotFoundError:
            if not self.allow_partial_reload:
                # Partial reloads are disallowed, so this is a hard failure.
                return False
            logging.warning('Unable to reload replay buffer!')
        if bundle_dictionary is None:
            # No Python-side state to restore.
            if not self.allow_partial_reload:
                return False
            logging.warning("Unable to reload the agent's parameters!")
            return True
        self.state = bundle_dictionary['state']
        self.training_steps = bundle_dictionary['training_steps']
        online = bundle_dictionary['online_params']
        if isinstance(online, core.FrozenDict):
            # Already a Linen-style checkpoint; use the trees as-is.
            self.online_params = online
            self.target_network_params = bundle_dictionary['target_params']
        else:
            # Pre-Linen checkpoint: convert each tree and nest it under
            # the 'params' collection expected by Linen.
            self.online_params = core.FrozenDict({
                'params':
                checkpoints.convert_pre_linen(online).unfreeze()
            })
            self.target_network_params = core.FrozenDict({
                'params':
                checkpoints.convert_pre_linen(
                    bundle_dictionary['target_params']).unfreeze()
            })
        # Rebuild the optimizer around the freshly restored online weights.
        self.optimizer = create_optimizer(self._optimizer_name)
        if 'optimizer_state' in bundle_dictionary:
            self.optimizer_state = bundle_dictionary['optimizer_state']
        else:
            self.optimizer_state = self.optimizer.init(self.online_params)
        return True
Code example #3
0
 def test_convert_pre_linen(self):
     """convert_pre_linen renumbers per-type module suffixes densely from 0."""
     legacy_tree = {
         'mod_0': {
             'submod1_0': {},
             'submod2_1': {},
             'submod1_2': {},
         },
         'mod2_2': {'submod2_2_0': {}},
         'mod2_11': {'submod2_11_0': {}},
         'mod2_1': {'submod2_1_0': {}},
     }
     # Suffixes are reassigned in numeric order within each name prefix,
     # so e.g. mod2_{1,2,11} become mod2_{0,1,2}.
     expected_tree = {
         'mod_0': {
             'submod1_0': {},
             'submod1_1': {},
             'submod2_0': {},
         },
         'mod2_0': {'submod2_1_0': {}},
         'mod2_1': {'submod2_2_0': {}},
         'mod2_2': {'submod2_11_0': {}},
     }
     converted = checkpoints.convert_pre_linen(legacy_tree)
     self.assertDictEqual(core.unfreeze(converted), expected_tree)
Code example #4
0
def load(path):
    """Loads params from a checkpoint previously stored with `save()`."""
    with gfile.GFile(path, 'rb') as f:
        # allow_pickle=False: the archive holds plain arrays, never objects.
        archive = np.load(f, allow_pickle=False)
        names, tensors = zip(*list(archive.items()))
    # Rebuild the nested tree from flat names, then rename pre-Linen modules.
    params = checkpoints.convert_pre_linen(recover_tree(names, tensors))
    if isinstance(params, flax.core.FrozenDict):
        return params.unfreeze()
    return params
Code example #5
0
def reload_jax_checkpoint(agent, bundle_dictionary):
    """Reload variables from a fully specified checkpoint."""
    if bundle_dictionary is None:
        # Nothing to restore.
        return
    agent.state = bundle_dictionary['state']
    online = bundle_dictionary['online_params']
    if isinstance(online, core.FrozenDict):
        # Linen-style checkpoint: take the tree as-is.
        agent.online_params = online
    else:
        # Pre-Linen checkpoint: convert and nest under 'params'.
        agent.online_params = core.FrozenDict({
            'params':
            flax_checkpoints.convert_pre_linen(online).unfreeze()
        })
    # Rebuild the optimizer around the restored online weights.
    # pylint: disable=protected-access
    agent.optimizer = dqn_agent.create_optimizer(agent._optimizer_name)
    # pylint: enable=protected-access
    if 'optimizer_state' in bundle_dictionary:
        agent.optimizer_state = bundle_dictionary['optimizer_state']
    else:
        agent.optimizer_state = agent.optimizer.init(agent.online_params)
    logging.info('Done restoring!')