def load_checkpoint(state, ckpt_dir, allow_missing=False):
  """Load checkpoint from directory into optimizer.

  Args:
    state: Flax optimizer state.
    ckpt_dir: Directory to load checkpoint from.
    allow_missing: Allows missing keys in checkpoint.

  Returns:
    deserialized optimizer.
  """
  ckpt = checkpoints.restore_checkpoint(ckpt_dir, target=None)
  if ckpt is None:
    logging.info("No checkpoint in %s.", ckpt_dir)
    return state
  logging.info("Restoring checkpoint from %s.", ckpt_dir)
  # Fix: `allow_missing` was previously `del`-eted and silently ignored even
  # though the docstring promises partial restores; forward it so it works.
  optimizer = _load_optimizer(
      state.optimizer, ckpt["optimizer"], allow_missing=allow_missing)
  # Restore every other field of the state pytree from the checkpoint.
  keys = [key for key in state.keys() if key != "optimizer"]
  state_dict = {
      k: serialization.from_state_dict(getattr(state, k), ckpt[k])
      for k in keys
  }
  state_dict["optimizer"] = optimizer
  return state.replace(**state_dict)
def _ckpt_restore_state(checkpoint_state, state_dict):
  """Restore the state from the state dict.

  Allows for checkpointing the class object.

  Args:
    checkpoint_state: an instance of a CheckpointState.
    state_dict: a state dict containing the desired new state of the object.

  Returns:
    The restored class object.
  """
  # Restore both tracked attributes in place from their serialized form.
  for attr in ("pytree", "pystate"):
    restored = serialization.from_state_dict(
        getattr(checkpoint_state, attr), state_dict[attr])
    setattr(checkpoint_state, attr, restored)
  return checkpoint_state
def deserialize_ExactState(vstate, state_dict):
  """Return a shallow copy of `vstate` with variables restored from `state_dict`."""
  import copy
  restored_vstate = copy.copy(vstate)
  # Reset cached quantities before installing the deserialized variables.
  restored_vstate.reset()
  restored_vstate.variables = serialization.from_state_dict(
      vstate.variables, state_dict["variables"])
  return restored_vstate
def _load_optimizer(optimizer, ckpt, allow_missing=False):
  """Loads the optimizer from the state dict.

  Args:
    optimizer: Flax optimizer providing the reference target/state pytrees.
    ckpt: Checkpoint state dict with "target" and "state" entries.
    allow_missing: If True, parameters present in the model but missing from
      the checkpoint are filled in from the model's current values (their
      optimizer state is reset when it cannot be copied over).

  Returns:
    The optimizer deserialized from the (possibly patched) state dict.

  Raises:
    ValueError: If `allow_missing` is False and the checkpoint is missing
      parameters the model expects, or on a shape mismatch while patching.
  """
  init_keys = set(dict(tree.flatten_with_path(ckpt["target"])))
  model_keys = set(dict(tree.flatten_with_path(optimizer.target)))
  missing_in_model = init_keys.difference(model_keys)
  missing_in_init = model_keys.difference(init_keys)
  # Fix: these were `print("... %s", str(x))` calls, which printed the literal
  # "%s" instead of formatting; use lazy logging formatting instead.
  logging.info("init - model keys: %s", missing_in_model)
  logging.info("model - init keys: %s", missing_in_init)
  logging.info("difference: %s", model_keys.symmetric_difference(init_keys))
  if not allow_missing and missing_in_init:
    raise ValueError(
        "Checkpoints must match exactly if `allow_missing=False`. "
        "Checkpoint missing %s" % str(missing_in_init))

  def get_path(d, path):
    # Walk nested dicts following `path` and return the leaf value.
    for k in path:
      d = d[k]
    return d

  def set_path(d, path, value):
    # Set the leaf at `path`, creating intermediate dicts as needed and
    # enforcing that an existing leaf has a matching shape.
    for k in path[:-1]:
      if k not in d:
        d[k] = dict()
      d = d[k]
    k = path[-1]
    if k in d and value.shape != d[k].shape:
      raise ValueError(
          "Shape mismatch: %s" % str((k, value.shape, d[k].shape)))
    d[k] = value
    return d

  # Fix: the helpers above were re-defined inside this loop on every
  # iteration; they are loop-invariant, so define them once.
  for param_path in missing_in_init:
    # Fill the missing parameter from the model's current value.
    target_param = get_path(optimizer.target, param_path)
    set_path(ckpt["target"], param_path, target_param)
    try:
      target_opt_state = get_path(optimizer.state.param_states, param_path)
      target_opt_state = serialization.to_state_dict(target_opt_state)
      set_path(ckpt["state"]["param_states"], param_path, target_opt_state)
    except TypeError:
      logging.warning(
          "unable to restore state for %s. Resetting state.", param_path)
      ckpt["state"] = serialization.to_state_dict(optimizer.state)
  return serialization.from_state_dict(optimizer, ckpt)
def restore_state(self, state_dict):
  """Restore parameter and optimizer state from state dictionary.

  Adapted from
  https://github.com/google-research/t5x/blob/main/t5x/optimizers.py.
  Includes support to handle `optax.EmptyState`.

  Args:
    state_dict: Contains desired new parameters and optimizer state

  Returns:
    Updated train state.
  """
  new_params = serialization.from_state_dict(self.params, state_dict["params"])

  # Flatten the reference optimizer state, keeping empty nodes so that paths
  # pruned from the serialized form can be re-inserted below.
  ref_flat = traverse_util.flatten_dict(
      serialization.to_state_dict(self.opt_state),
      keep_empty_nodes=True,
      sep="/")
  src_flat = dict(traverse_util.flatten_dict(state_dict["opt_state"], sep="/"))

  # Re-insert paths dropped from the source dict; only empty nodes (e.g.
  # `optax.EmptyState`) are allowed to be absent.
  for k, ref_value in ref_flat.items():
    if k in src_flat:
      continue
    if ref_value != traverse_util.empty_node:
      raise ValueError(
          f"Failed to restore optimizer state, path {k} is not present "
          "in the input optimizer state dict.")
    src_flat[k] = ref_value

  # Restore state from the enhanced state dict.
  new_opt_state = serialization.from_state_dict(
      self.opt_state, traverse_util.unflatten_dict(src_flat, sep="/"))
  return self.replace(params=new_params, opt_state=new_opt_state)
def restore_state(self, state, state_dict):
  """Restore the state from the state dict.

  Allows for checkpointing the class object.

  Args:
    state: the class state.
    state_dict: the state dict containing the desired new state of the object.

  Returns:
    The restored class object.
  """
  restored = serialization.from_state_dict(state, state_dict['state'])
  return self.replace(state=restored)
def _from_state_dict(x: T, state: Dict):
  """Restore an instance of `cls` from `state`, consuming exactly the known fields."""
  # Work on a copy so popping restored fields does not mutate the caller's
  # dict.
  remaining = state.copy()
  updates = {}
  for name in field_names:
    if name not in remaining:
      raise ValueError(
          f"Missing field {name} in state dict while restoring"
          f" an instance of {cls.__name__}")
    updates[name] = serialization.from_state_dict(
        getattr(x, name), remaining.pop(name))
  # Anything left over is a field the class does not declare.
  if remaining:
    names = ",".join(remaining.keys())
    raise ValueError(f'Unknown field(s) "{names}" in state dict while'
                     f" restoring an instance of {cls.__name__}")
  return dataclasses.replace(x, **updates)
def test_namedtuple_restore_legacy(self):
  """Legacy namedtuple encodings (name/fields/values) restore correctly."""
  Foo = collections.namedtuple('Foo', 'a b c')
  expected = Foo(a=1, b=2, c=3)
  legacy_encoding = {
      'name': 'Foo',
      'fields': {'0': 'a', '1': 'b', '2': 'c'},
      'values': {'0': 1, '1': 2, '2': 3},
  }
  target = Foo(a=0, b=0, c=0)
  restored = serialization.from_state_dict(target, legacy_encoding)
  self.assertEqual(type(expected), type(restored))
  self.assertEqual(expected, restored)