Example #1
0
def load_checkpoint(state, ckpt_dir, allow_missing=False):
    """Restore a training state from a checkpoint directory.

    Args:
      state: Current state object, used as the deserialization template. It
        must provide `keys()`, an `optimizer` attribute and `replace(**kw)`.
      ckpt_dir: Directory to load checkpoint from.
      allow_missing: Forwarded to `_load_optimizer`. When True, parameters
        that the live optimizer has but the checkpoint lacks are tolerated;
        when False such a mismatch raises.

    Returns:
      `state` itself when no checkpoint is found in `ckpt_dir`, otherwise the
      result of `state.replace(...)` with every field deserialized from the
      checkpoint.
    """
    # The checkpoint is loaded without a target so that it can be indexed by
    # field name below and the optimizer entry handled separately.
    ckpt = checkpoints.restore_checkpoint(ckpt_dir, target=None)
    if ckpt is None:
        logging.info("No checkpoint in %s.", ckpt_dir)
        return state

    logging.info("Restoring state from checkpoint in %s.", ckpt_dir)

    # The optimizer needs dedicated handling (missing-parameter checks), so it
    # is restored first and on its own.
    optimizer = _load_optimizer(
        state.optimizer, ckpt["optimizer"], allow_missing=allow_missing)

    state_dict = {
        key: serialization.from_state_dict(getattr(state, key), ckpt[key])
        for key in state.keys()
        if key != "optimizer"
    }
    state_dict["optimizer"] = optimizer

    return state.replace(**state_dict)
Example #2
0
def _ckpt_restore_state(checkpoint_state, state_dict):
  """Populate a CheckpointState from its serialized form.

  This is the restore hook that lets the class object be checkpointed.

  Args:
    checkpoint_state: the CheckpointState instance to update in place; its
      current `pytree` and `pystate` serve as deserialization templates.
    state_dict: mapping whose 'pytree' and 'pystate' entries hold the
      serialized values to restore.

  Returns:
    The same `checkpoint_state` object, carrying the restored values.
  """
  for field in ('pytree', 'pystate'):
    restored = serialization.from_state_dict(
        getattr(checkpoint_state, field), state_dict[field])
    setattr(checkpoint_state, field, restored)
  return checkpoint_state
Example #3
0
def deserialize_ExactState(vstate, state_dict):
    """Return a reset shallow copy of `vstate` with restored variables.

    The copy's `variables` are deserialized from `state_dict["variables"]`,
    using the variables of the original `vstate` as the template.
    """
    from copy import copy as shallow_copy

    restored = shallow_copy(vstate)
    restored.reset()

    template = vstate.variables
    restored.variables = serialization.from_state_dict(
        template, state_dict["variables"])
    return restored
Example #4
0
def _load_optimizer(optimizer, ckpt, allow_missing=False):
    """Deserialize an optimizer from a checkpoint state dict.

    Parameters that exist in `optimizer.target` but not in `ckpt["target"]`
    are copied from the live optimizer into `ckpt` (their value and, where
    possible, their per-parameter optimizer state) before deserialization.

    Args:
      optimizer: Live optimizer used as the deserialization template.
      ckpt: Optimizer state dict with "target" and "state" entries. It is
        modified in place when parameters are missing.
      allow_missing: If False, raise when the checkpoint lacks parameters
        that the live optimizer has.

    Returns:
      The result of `serialization.from_state_dict(optimizer, ckpt)`.

    Raises:
      ValueError: If `allow_missing` is False and the checkpoint is missing
        parameters, or if a value being written has a different shape from
        the entry it would overwrite.
    """

    def get_path(d, path):
        """Follow `path` (a sequence of keys) into nested container `d`."""
        for k in path:
            d = d[k]
        return d

    def set_path(d, path, value):
        """Store `value` at `path` in nested dict `d`, creating parents."""
        for k in path[:-1]:
            if k not in d:
                d[k] = dict()
            d = d[k]
        k = path[-1]
        # Refuse to overwrite an existing entry with a differently shaped one.
        if k in d and value.shape != d[k].shape:
            raise ValueError("Shape mismatch: %s" % str(
                (k, value.shape, d[k].shape)))
        d[k] = value
        return d

    init_keys = set(dict(tree.flatten_with_path(ckpt["target"])))
    model_keys = set(dict(tree.flatten_with_path(optimizer.target)))
    missing_in_model = init_keys.difference(model_keys)
    missing_in_init = model_keys.difference(init_keys)
    missing = model_keys.symmetric_difference(init_keys)
    # print() does not interpolate extra arguments the way logging does, so
    # the values are formatted into the message explicitly.
    print("init - model keys: %s" % str(missing_in_model))
    print("model - init keys: %s" % str(missing_in_init))
    print("difference: %s" % str(missing))

    if missing_in_init and not allow_missing:
        raise ValueError(
            "Checkpoints must match exactly if `allow_missing=False`. "
            "Checkpoint missing %s" % str(missing_in_init))

    for param_path in missing_in_init:
        # Fill the missing parameter with the live optimizer's value.
        target_param = get_path(optimizer.target, param_path)
        set_path(ckpt["target"], param_path, target_param)

        try:
            target_opt_state = get_path(optimizer.state.param_states,
                                        param_path)
            target_opt_state = serialization.to_state_dict(target_opt_state)
            set_path(ckpt["state"]["param_states"], param_path,
                     target_opt_state)
        except TypeError:
            # Best effort: if the per-parameter state cannot be patched in,
            # fall back to the live optimizer's entire state.
            print(
                f"unable to restore state for {param_path}. Resetting state.")
            ckpt["state"] = serialization.to_state_dict(optimizer.state)

    return serialization.from_state_dict(optimizer, ckpt)
Example #5
0
  def restore_state(self, state_dict):
    """Restore parameter and optimizer state from state dictionary.

    Adapted from
    https://github.com/google-research/t5x/blob/main/t5x/optimizers.py. Includes
    support to handle `optax.EmptyState`.

    Paths that are empty nodes in the current optimizer state and absent from
    the input are re-inserted before deserialization, so the input only has
    to contain the non-empty leaves.

    Args:
      state_dict: Contains desired new parameters and optimizer state, under
        the "params" and "opt_state" keys respectively.

    Returns:
      Updated train state.

    Raises:
      ValueError: If a non-empty path of the current optimizer state is not
        present in `state_dict["opt_state"]`.
    """
    params = serialization.from_state_dict(self.params, state_dict["params"])

    # Get all the possible keys in the reference optimizer state.
    flat_ref_opt_state_dict = traverse_util.flatten_dict(
        serialization.to_state_dict(self.opt_state),
        keep_empty_nodes=True,
        sep="/")

    flat_src_opt_state_dict = dict(
        traverse_util.flatten_dict(state_dict["opt_state"], sep="/"))
    # Adding the empty paths back to flat_src_opt_state_dict.
    for k, v in flat_ref_opt_state_dict.items():
      if k in flat_src_opt_state_dict:
        continue
      # The key is not in the input state dict, presumably because it
      # corresponds to an empty dict. Identity is used instead of `!=` because
      # `v` may be an array leaf, for which a value comparison does not yield
      # a plain bool and would mask the error below.
      if v is not traverse_util.empty_node:
        raise ValueError(
            f"Failed to restore optimizer state, path {k} is not present "
            "in the input optimizer state dict.")
      flat_src_opt_state_dict[k] = v

    # Restore state from the enhanced state dict.
    opt_state = serialization.from_state_dict(
        self.opt_state,
        traverse_util.unflatten_dict(flat_src_opt_state_dict, sep="/"))
    return self.replace(params=params, opt_state=opt_state)
Example #6
0
    def restore_state(self, state, state_dict):
        """Rebuild this object from a serialized state dict.

        Restore hook that allows the class object to be checkpointed.

        Args:
          state: the current class state, used as the deserialization
            template.
          state_dict: dict whose 'state' entry holds the desired new state.

        Returns:
          The result of `self.replace(state=...)` with the restored state.
        """
        restored = serialization.from_state_dict(state, state_dict['state'])
        return self.replace(state=restored)
Example #7
0
 def _from_state_dict(x: T, state: Dict):
     """Restore dataclass instance `x` from `state`, one field at a time.

     `field_names` and `cls` come from the enclosing scope. Every name in
     `field_names` must be present in `state`, and `state` must not contain
     any other key; either violation raises ValueError.
     """
     # Work on a copy so consumed entries can be removed; whatever is left
     # after the loop is an unknown field.
     remaining = state.copy()
     restored = {}
     for field in field_names:
         if field not in remaining:
             raise ValueError(
                 f"Missing field {field} in state dict while restoring"
                 f" an instance of {cls.__name__}")
         current = getattr(x, field)
         restored[field] = serialization.from_state_dict(
             current, remaining.pop(field))
     if remaining:
         unknown = ",".join(remaining.keys())
         raise ValueError(f'Unknown field(s) "{unknown}" in state dict while'
                          f" restoring an instance of {cls.__name__}")
     return dataclasses.replace(x, **restored)
Example #8
0
 def test_namedtuple_restore_legacy(self):
     """A legacy name/fields/values encoding restores into a namedtuple."""
     foo_class = collections.namedtuple('Foo', 'a b c')
     expected = foo_class(a=1, b=2, c=3)
     template = foo_class(a=0, b=0, c=0)
     # Legacy layout: field names and values are stored in separate dicts,
     # both keyed by the stringified field position.
     legacy_encoding = dict(
         name='Foo',
         fields={'0': 'a', '1': 'b', '2': 'c'},
         values={'0': 1, '1': 2, '2': 3},
     )
     restored = serialization.from_state_dict(template, legacy_encoding)
     self.assertEqual(type(expected), type(restored))
     self.assertEqual(expected, restored)