def wrapped(clz):
  data_clz = dataclasses.dataclass(frozen=frozen, unsafe_hash=unsafe_hash)(clz)
  meta_fields = []
  data_fields = []
  for name, field_info in data_clz.__dataclass_fields__.items():
    is_pytree_node = field_info.metadata.get('pytree_node', True)
    if is_pytree_node:
      data_fields.append(name)
    else:
      meta_fields.append(name)

  def replace(self, **updates):
    """Returns a new object replacing the specified fields with new values."""
    return dataclasses.replace(self, **updates)

  data_clz.replace = replace

  def iterate_clz(x):
    meta = tuple(getattr(x, name) for name in meta_fields)
    data = tuple(getattr(x, name) for name in data_fields)
    return data, meta

  def clz_from_iterable(meta, data):
    meta_args = tuple(zip(meta_fields, meta))
    data_args = tuple(zip(data_fields, data))
    kwargs = dict(meta_args + data_args)
    return data_clz(**kwargs)

  jax.tree_util.register_pytree_node(data_clz, iterate_clz, clz_from_iterable)

  def to_state_dict(x):
    state_dict = {
        name: serialization.to_state_dict(getattr(x, name))
        for name in data_fields
    }
    return state_dict

  def from_state_dict(x, state):
    """Restore the state of a data class."""
    state = state.copy()  # copy the state so we can pop the restored fields.
    updates = {}
    for name in data_fields:
      if name not in state:
        raise ValueError(f'Missing field {name} in state dict while restoring'
                         f' an instance of {clz.__name__}')
      value = getattr(x, name)
      value_state = state.pop(name)
      updates[name] = serialization.from_state_dict(value, value_state)
    if state:
      names = ','.join(state.keys())
      raise ValueError(f'Unknown field(s) "{names}" in state dict while'
                       f' restoring an instance of {clz.__name__}')
    return x.replace(**updates)

  serialization.register_serialization_state(data_clz, to_state_dict,
                                              from_state_dict)
  return data_clz
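# --- Hedged usage sketch (not part of the implementation above) ---
# Assuming `wrapped` is the inner function of a decorator exposed like
# `flax.struct.dataclass`, a decorated class becomes a frozen dataclass that is
# both a JAX pytree and flax-serializable. The class `Point`, its fields, and
# the `pytree_node=False` metadata below are illustrative assumptions.
import jax
import jax.numpy as jnp
from flax import serialization, struct


@struct.dataclass
class Point:
  x: jnp.ndarray
  y: jnp.ndarray
  name: str = struct.field(pytree_node=False, default="origin")


p = Point(x=jnp.ones(3), y=jnp.zeros(3))
p2 = jax.tree_util.tree_map(lambda a: a + 1.0, p)  # `name` is metadata, untouched
p3 = p.replace(y=jnp.full(3, 2.0))                 # functional update
state = serialization.to_state_dict(p)             # only pytree-node fields serialized
restored = serialization.from_state_dict(p, state)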
def register_graph_as_flax_state_dict(cls: Type[T]) -> None:
  def ty_to_state_dict(graph: T) -> Dict[str, Any]:
    edge_dict_of_dicts = defaultdict[Any, Dict[Any, Any]](dict)
    for (source, target), edge_dict in dict(graph.edges).items():
      edge_dict_of_dicts[source][target] = edge_dict
    return {
        'nodes': to_state_dict(dict(graph.nodes)),
        'edges': to_state_dict(dict(edge_dict_of_dicts))
    }

  def ty_from_state_dict(graph: T, state_dict: Dict[str, Any]) -> T:
    retval = type(graph)()
    for node_name, node_dict in state_dict['nodes'].items():
      retval.add_node(node_name,
                      **from_state_dict(graph.nodes[node_name], node_dict))
    for source, target_and_edge_dict in state_dict['edges'].items():
      for target, edge_dict in target_and_edge_dict.items():
        retval.add_edge(source, target,
                        **from_state_dict(graph.edges[source, target],
                                          edge_dict))
    return retval

  register_serialization_state(
      cls,
      ty_to_state_dict,  # type: ignore[no-untyped-call]
      ty_from_state_dict)
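# --- Hedged usage sketch ---
# Assuming `cls` is a networkx-style graph class (anything with `.nodes`,
# `.edges`, `add_node`, `add_edge`) whose attribute dicts hold JAX arrays,
# registration lets whole graphs round-trip through flax's state-dict
# machinery. `MyGraph` and the attribute names below are illustrative
# assumptions.
import jax.numpy as jnp
import networkx as nx
from flax.serialization import from_state_dict, to_state_dict


class MyGraph(nx.DiGraph):
  pass


register_graph_as_flax_state_dict(MyGraph)

g = MyGraph()
g.add_node("a", value=jnp.zeros(2))
g.add_node("b", value=jnp.ones(2))
g.add_edge("a", "b", weight=jnp.array(0.5))

state = to_state_dict(g)                # {'nodes': ..., 'edges': ...}
g_restored = from_state_dict(g, state)  # rebuilt graph of the same type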
def register_serialization_functions():
    global already_registered  # noqa: W0603
    if not already_registered:
        already_registered = True
        import haiku  # noqa: E0401

        FlatMappingType = type(haiku.data_structures.to_immutable_dict({"ciao": 1}))

        def serialize_flat_mapping(flat_mapping):
            return dict(flat_mapping)

        def deserialize_flat_mapping(flat_mapping, _):
            return haiku.data_structures.to_immutable_dict(flat_mapping)

        serialization.register_serialization_state(
            FlatMappingType,
            serialize_flat_mapping,
            deserialize_flat_mapping,
        )
    raise ValueError(
        'A mutable collection should not be transformed by Jax.')
  meta = (type(collection), collection._anchor)  # pylint: disable=protected-access
  return (collection.state,), meta


def collection_from_iterable(meta, state):
  ty, anchor = meta
  coll = ty(state[0])
  coll._anchor = anchor  # pylint: disable=protected-access
  return coll


# make sure a collection is traced.
jax.tree_util.register_pytree_node(Collection, iterate_collection,
                                   collection_from_iterable)


def _collection_state_dict(collection):
  return serialization._dict_state_dict(collection.as_dict())  # pylint: disable=protected-access


def _collection_from_state_dict(xs, state):
  restored_state = serialization._restore_dict(xs.as_dict(), state)  # pylint: disable=protected-access
  return Collection(restored_state)


serialization.register_serialization_state(Collection, _collection_state_dict,
                                            _collection_from_state_dict)
"sampler_state": serialization.to_state_dict(vstate.sampler_state), "n_samples": vstate.n_samples, "n_discard": vstate.n_discard, } return state_dict def deserialize_MCState(vstate, state_dict): import copy new_vstate = copy.copy(vstate) new_vstate.reset() new_vstate.variables = serialization.from_state_dict( vstate.variables, state_dict["variables"] ) new_vstate.sampler_state = serialization.from_state_dict( vstate.sampler_state, state_dict["sampler_state"] ) new_vstate.n_samples = state_dict["n_samples"] new_vstate.n_discard = state_dict["n_discard"] return new_vstate serialization.register_serialization_state( MCState, serialize_MCState, deserialize_MCState, )
  def __dir__(self):
    if isinstance(self._data, dict):
      return list(self._data.keys())
    elif isinstance(self._data, FrozenDict):
      return list(self._data._dict.keys())
    else:
      return []

  def __repr__(self):
    return f'{self._data}'

  def __hash__(self):
    # Note: will only work when wrapping FrozenDict.
    return hash(self._data)

  def copy(self, **kwargs):
    return self._data.__class__(self._data.copy(**kwargs))


tree_util.register_pytree_node(
    DotGetter,
    lambda x: ((x._data,), ()),  # pylint: disable=protected-access
    lambda _, data: data[0])  # Note: restores as raw dict, intentionally.

serialization.register_serialization_state(
    DotGetter,
    serialization._dict_state_dict,  # pylint: disable=protected-access
    serialization._restore_dict)  # pylint: disable=protected-access
""" if isinstance(x, FrozenDict): # deep copy internal state of a FrozenDict # the dict branch would also work here but # it is much less performant because jax.tree_map # uses an optimized C implementation. return jax.tree_map(lambda y: y, x._dict) elif isinstance(x, dict): ys = {} for key, value in x.items(): ys[key] = unfreeze(value) return ys else: return x def _frozen_dict_state_dict(xs): return {key: serialization.to_state_dict(value) for key, value in xs.items()} def _restore_frozen_dict(xs, states): return FrozenDict( {key: serialization.from_state_dict(value, states[key]) for key, value in xs.items()}) serialization.register_serialization_state( FrozenDict, _frozen_dict_state_dict, _restore_frozen_dict)
def dataclass(clz=None, *, init_doc=MISSING, cache_hash=False):
    """
    Decorator creating a NetKet-flavour dataclass.
    This behaves like a flax dataclass, that is a frozen python dataclass, with
    a twist! See their documentation for standard behaviour.

    The new functionalities added by NetKet are:

    - It is possible to define a method
      `__pre_init__(*args, **kwargs) -> Tuple[Tuple, Dict]` that processes the
      arguments and keyword arguments provided to the dataclass constructor.
      This makes it possible to deprecate argument names and to add logic
      customising the constructor. The function should return a tuple of the
      edited `(args, kwargs)`. If inheriting from other classes it is
      recommended (though not mandated) to call the same method in parent
      classes. The returned arguments and keyword arguments must match the
      standard dataclass constructor. The function may not be called in some
      internal cases, so its execution should not be a strict requirement.

    - Cached properties. It is possible to mark properties of a NetKet
      dataclass with `@property_cached`. This makes the property behave like a
      standard property, but its value is cached and reset every time the
      dataclass is manipulated. Cached properties can be part of the flattened
      pytree or not. See :ref:`netket.utils.struct.property_cached` for more
      info.

    Optional Args:
        init_doc: the docstring for the init method. Otherwise it is inherited
            from `__pre_init__`.
        cache_hash: If True the hash is computed only once and cached. Use if
            the computation is expensive.
    """
    if clz is None:
        return partial(dataclass, init_doc=init_doc, cache_hash=cache_hash)

    # get globals of the class to put generated methods in there
    _globals = get_class_globals(clz)
    _globals["Uninitialized"] = Uninitialized

    # process all cached properties
    process_cached_properties(clz, globals=_globals)

    # create the dataclass
    data_clz = dataclasses.dataclass(frozen=True)(clz)

    purge_cache_fields(data_clz)

    # attach the custom preprocessing of init arguments
    attach_preprocess_init(
        data_clz, globals=_globals, init_doc=init_doc, cache_hash=cache_hash
    )

    if cache_hash:
        replace_hash_method(data_clz, globals=_globals)

    # flax stuff: identify meta and data fields
    meta_fields = []
    data_fields = []
    for name, field_info in getattr(data_clz, _FIELDS, {}).items():
        is_pytree_node = field_info.metadata.get("pytree_node", True)
        if is_pytree_node:
            data_fields.append(name)
        else:
            meta_fields.append(name)

    # list the cache fields
    cache_fields = []
    for name, cp in getattr(data_clz, _CACHES, {}).items():
        cache_fields.append(cp.cache_name)
        # they count as meta fields
        meta_fields.append(cp.cache_name)

    def replace(self, **updates):
        """Returns a new object replacing the specified fields with new values."""
        # reset cached fields
        for name in cache_fields:
            updates[name] = Uninitialized
        return dataclasses.replace(self, **updates)

    data_clz.replace = replace

    # support for jax pytree flattening/unflattening
    def iterate_clz(x):
        meta = tuple(getattr(x, name) for name in meta_fields)
        data = tuple(getattr(x, name) for name in data_fields)
        return data, meta

    def clz_from_iterable(meta, data):
        meta_args = tuple(zip(meta_fields, meta))
        data_args = tuple(zip(data_fields, data))
        kwargs = dict(meta_args + data_args)
        return data_clz(__skip_preprocess=True, **kwargs)

    jax.tree_util.register_pytree_node(data_clz, iterate_clz, clz_from_iterable)

    # flax serialization
    skip_serialize_fields = []
    for name, field_info in data_clz.__dataclass_fields__.items():
        if not field_info.metadata.get("serialize", True):
            skip_serialize_fields.append(name)

    def to_state_dict(x):
        state_dict = {
            name: serialization.to_state_dict(getattr(x, name))
            for name in data_fields
            if name not in skip_serialize_fields
        }
        return state_dict

    def from_state_dict(x, state):
        """Restore the state of a data class."""
        state = state.copy()  # copy the state so we can pop the restored fields.
        updates = {}
        for name in data_fields:
            if name not in skip_serialize_fields:
                if name not in state:
                    raise ValueError(
                        f"Missing field {name} in state dict while restoring"
                        f" an instance of {clz.__name__}"
                    )
                value = getattr(x, name)
                value_state = state.pop(name)
                updates[name] = serialization.from_state_dict(value, value_state)
        if state:
            names = ",".join(state.keys())
            raise ValueError(
                f'Unknown field(s) "{names}" in state dict while'
                f" restoring an instance of {clz.__name__}"
            )
        return x.replace(**updates)

    serialization.register_serialization_state(data_clz, to_state_dict, from_state_dict)

    return data_clz
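# --- Hedged usage sketch of the decorator above ---
# A NetKet-style dataclass with a static field and a cached property. The class
# name, the field helper, and the import path `netket.utils.struct` are
# illustrative assumptions; `property_cached` is referenced in the docstring
# above.
import jax
import jax.numpy as jnp
from netket.utils import struct


@struct.dataclass
class Box:
    extent: jnp.ndarray
    label: str = struct.field(pytree_node=False, default="box")

    @struct.property_cached
    def volume(self) -> float:
        return float(jnp.prod(self.extent))


b = Box(extent=jnp.array([1.0, 2.0, 3.0]))
b.volume                                    # computed once, then cached
b2 = b.replace(extent=jnp.ones(3))          # cache is reset on the new instance
leaves, _ = jax.tree_util.tree_flatten(b)   # `label` stays out of the leaves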
def _register_dataclass_pytree(
    cls: Type[T],
    static_fields: Sequence[str] = [],
    make_immutable: bool = True,
) -> Type[T]:
    assert dataclasses.is_dataclass(cls)

    # Respect static field registration from superclasses
    static_fields_list = list(static_fields)
    del static_fields
    for parent_class in filter(lambda x: x in _registered_static_fields, cls.mro()):
        static_fields_list.extend(_registered_static_fields[parent_class])

    static_fields_set = set(static_fields_list)
    assert len(static_fields_list) == len(
        static_fields_set), "Found repeated field names!"
    _registered_static_fields[cls] = static_fields_set

    # Get a list of fields in our dataclass
    field: dataclasses.Field
    field_names = [field.name for field in dataclasses.fields(cls)]
    children_fields = [
        name for name in field_names if name not in static_fields_set
    ]
    assert set(field_names) == set(children_fields) | set(
        static_fields_set), "Field name anomaly; check static fields list!"

    # Define flatten/unflatten operations: this simply converts our dataclass to a
    # list of fields.
    def _flatten(obj):
        return [getattr(obj, key) for key in children_fields
                ], tuple(getattr(obj, key) for key in static_fields_set)

    def _unflatten(treedef, children):
        return cls(
            **dict(zip(children_fields, children)),
            **dict(zip(static_fields_set, treedef)),
        )
        # Alternative:
        # return dataclasses.replace(
        #     cls.__new__(cls),
        #     **dict(zip(children_fields, children)),
        #     **dict(zip(static_fields_set, treedef)),
        # )

    jax.tree_util.register_pytree_node(cls, _flatten, _unflatten)

    # Serialization: this is mostly copied from `flax.struct.dataclass`
    def _to_state_dict(x: T):
        state_dict = {
            name: serialization.to_state_dict(getattr(x, name))
            for name in field_names
        }
        return state_dict

    def _from_state_dict(x: T, state: Dict):
        state = state.copy()  # copy the state so we can pop the restored fields.
        updates = {}
        for name in field_names:
            if name not in state:
                raise ValueError(
                    f"Missing field {name} in state dict while restoring"
                    f" an instance of {cls.__name__}")
            value = getattr(x, name)
            value_state = state.pop(name)
            updates[name] = serialization.from_state_dict(value, value_state)
        if state:
            names = ",".join(state.keys())
            raise ValueError(f'Unknown field(s) "{names}" in state dict while'
                             f" restoring an instance of {cls.__name__}")
        return dataclasses.replace(x, **updates)

    serialization.register_serialization_state(cls, _to_state_dict,
                                                _from_state_dict)

    # Make dataclass immutable after __init__ is called.
    # Similar to dataclasses.dataclass(frozen=True), but a bit friendlier for
    # custom __init__ functions.
    if make_immutable:
        original_init = cls.__init__ if hasattr(cls, "__init__") else None

        def disabled_setattr(*args, **kwargs):
            raise dataclasses.FrozenInstanceError(
                "Dataclasses registered as pytrees are immutable!")

        def new_init(self, *args, **kwargs):
            cls.__setattr__ = object.__setattr__
            if original_init is not None:
                original_init(self, *args, **kwargs)
            cls.__setattr__ = disabled_setattr

        cls.__setattr__ = disabled_setattr  # type: ignore
        cls.__init__ = new_init  # type: ignore

    return cls
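# --- Hedged usage sketch ---
# Registering a plain dataclass as a pytree with one static field. The
# dataclass and field names are illustrative assumptions; the helper is assumed
# to be importable from wherever this module lives.
import dataclasses
import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Pose:
    translation: jnp.ndarray
    rotation: jnp.ndarray
    frame_name: str  # static: kept out of the traced leaves


_register_dataclass_pytree(Pose, static_fields=("frame_name",))

pose = Pose(jnp.zeros(3), jnp.eye(3), "world")
moved = jax.tree_util.tree_map(lambda x: x + 1.0, pose)  # arrays only
# pose.translation = jnp.ones(3)  # would raise FrozenInstanceError (make_immutable)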
    def __str__(self):
        return "ExactState(" + "hilbert = {}, ".format(self.hilbert)


# serialization
def serialize_ExactState(vstate):
    state_dict = {
        "variables": serialization.to_state_dict(vstate.variables),
    }
    return state_dict


def deserialize_ExactState(vstate, state_dict):
    import copy

    new_vstate = copy.copy(vstate)
    new_vstate.reset()

    new_vstate.variables = serialization.from_state_dict(
        vstate.variables, state_dict["variables"]
    )
    return new_vstate


serialization.register_serialization_state(
    ExactState,
    serialize_ExactState,
    deserialize_ExactState,
)
    Args:
      state: the class state.
      state_dict: the state dict containing the desired new state of the
        object.

    Returns:
      The restored class object.
    """
    state = serialization.from_state_dict(state, state_dict['state'])
    return self.replace(state=state)


serialization.register_serialization_state(
    TrainingMetricsGrabber,
    TrainingMetricsGrabber.state_dict,
    TrainingMetricsGrabber.restore_state,
    override=True)


def run_in_parallel(function, list_of_kwargs_to_function, num_workers):
  """Run a function on a list of kwargs in parallel with ThreadPoolExecutor.

  Adapted from code by mlbileschi.

  Args:
    function: a function.
    list_of_kwargs_to_function: list of dictionary from string to argument
      value. These will be passed into `function` as kwargs.
    num_workers: int.

  Returns:
    state_dict: a state dict containing the desired new state of the object.

  Returns:
    The restored class object.
  """
  checkpoint_state.pytree = serialization.from_state_dict(
      checkpoint_state.pytree, state_dict['pytree'])
  checkpoint_state.pystate = serialization.from_state_dict(
      checkpoint_state.pystate, state_dict['pystate'])
  return checkpoint_state


# Note that this will only be used if we call
# `flax_checkpoints.restore_checkpoint` with the `target` arg set to a
# CheckpointState object to be filled in with values.
serialization.register_serialization_state(
    CheckpointState, _ckpt_state_dict, _ckpt_restore_state, override=True)


def _save_checkpoint_background_catch_error(*args, **kwargs):
  """Call save_checkpoint with provided args, store exception if any."""
  global _save_checkpoint_background_error
  try:
    save_checkpoint(*args, **kwargs)
    _save_checkpoint_background_error = None
  except BaseException as err:  # pylint: disable=broad-except
    logging.exception('Error while saving checkpoint in background.')
    _save_checkpoint_background_error = err


def wait_for_checkpoint_save():
  """Wait until last checkpoint save (if any) to finish."""