Ejemplo n.º 1
0
  def wrapped(clz):
    """Turn `clz` into a dataclass that is also a JAX pytree.

    `frozen` and `unsafe_hash` are taken from the enclosing scope and passed
    to `dataclasses.dataclass`. Fields whose metadata sets
    ``pytree_node=False`` become static pytree metadata; every other field is
    a pytree child. Only the pytree (data) fields take part in flax
    serialization. A `replace` method is attached to the class.

    Args:
      clz: the class to transform.

    Returns:
      The dataclass built from `clz`, registered with
      `jax.tree_util.register_pytree_node` and with flax serialization.
    """
    data_clz = dataclasses.dataclass(
        frozen=frozen, unsafe_hash=unsafe_hash)(
            clz)
    meta_fields = []
    data_fields = []
    for name, field_info in data_clz.__dataclass_fields__.items():
      is_pytree_node = field_info.metadata.get('pytree_node', True)
      if is_pytree_node:
        data_fields.append(name)
      else:
        meta_fields.append(name)

    def replace(self, **updates):
      """Returns a new object replacing the specified fields with new values."""
      return dataclasses.replace(self, **updates)

    data_clz.replace = replace

    def iterate_clz(x):
      # Flatten: data fields are the pytree children, meta fields are the
      # static auxiliary data.
      meta = tuple(getattr(x, name) for name in meta_fields)
      data = tuple(getattr(x, name) for name in data_fields)
      return data, meta

    def clz_from_iterable(meta, data):
      # Unflatten: rebuild the instance through its constructor, pairing the
      # values back up with the field names recorded above.
      meta_args = tuple(zip(meta_fields, meta))
      data_args = tuple(zip(data_fields, data))
      kwargs = dict(meta_args + data_args)
      return data_clz(**kwargs)

    jax.tree_util.register_pytree_node(data_clz, iterate_clz, clz_from_iterable)

    def to_state_dict(x):
      # Only data fields are stored; meta fields are not part of the state
      # dict and are kept from the target instance on restore.
      state_dict = {
          name: serialization.to_state_dict(getattr(x, name))
          for name in data_fields
      }
      return state_dict

    def from_state_dict(x, state):
      """Restore the state of a data class.

      Args:
        x: the target instance, providing the structure of each field.
        state: the state dict produced by `to_state_dict`.

      Returns:
        A copy of `x` with every data field restored from `state`.

      Raises:
        ValueError: if a data field is missing from `state`, or `state`
          contains keys that are not data fields.
      """
      state = state.copy()  # copy the state so we can pop the restored fields.
      updates = {}
      for name in data_fields:
        if name not in state:
          raise ValueError(f'Missing field {name} in state dict while restoring'
                           f' an instance of {clz.__name__}')
        value = getattr(x, name)
        value_state = state.pop(name)
        updates[name] = serialization.from_state_dict(value, value_state)
      # Anything left over was not consumed by a data field.
      if state:
        names = ','.join(state.keys())
        raise ValueError(f'Unknown field(s) "{names}" in state dict while'
                         f' restoring an instance of {clz.__name__}')
      return x.replace(**updates)

    serialization.register_serialization_state(data_clz, to_state_dict,
                                               from_state_dict)

    return data_clz
Ejemplo n.º 2
0
def register_graph_as_flax_state_dict(cls: Type[T]) -> None:
    """Register the graph type `cls` with flax serialization.

    The state dict has two entries:

    * ``'nodes'``: node -> node-attribute dict.
    * ``'edges'``: a dict of dicts, ``source -> target -> edge-attribute dict``.

    flax's dict state-dict handler converts mapping keys with ``str()``
    (verify against the pinned flax version), so a stored state dict may be
    keyed by ``str(node)`` rather than by the node itself. Restoring maps such
    keys back to the node objects of the target graph. As a result, graphs
    whose nodes are not strings (e.g. integers) round-trip; string nodes
    behave exactly as before.
    """
    def ty_to_state_dict(graph: T) -> Dict[str, Any]:
        # Regroup the flat ``(source, target) -> attrs`` edge mapping into
        # nested dicts so that each state-dict key is a single node.
        edge_dict_of_dicts = defaultdict[Any, Dict[Any, Any]](dict)
        for (source, target), edge_dict in dict(graph.edges).items():
            edge_dict_of_dicts[source][target] = edge_dict
        return {
            'nodes': to_state_dict(dict(graph.nodes)),
            'edges': to_state_dict(dict(edge_dict_of_dicts))
        }

    def ty_from_state_dict(graph: T, state_dict: Dict[str, Any]) -> T:
        # Build a fresh graph of the target's type. Attribute dicts are
        # restored using the target graph's node/edge attributes as templates.
        retval = type(graph)()
        # Map (possibly stringified) state-dict keys back to the target's
        # nodes. A key matching no node is used unchanged, so the lookups
        # below still raise for nodes the target graph does not contain.
        node_by_key = {str(node): node for node in graph.nodes}
        for node_key, node_dict in state_dict['nodes'].items():
            node = node_by_key.get(node_key, node_key)
            retval.add_node(
                node, **from_state_dict(graph.nodes[node], node_dict))
        for source_key, target_and_edge_dict in state_dict['edges'].items():
            source = node_by_key.get(source_key, source_key)
            for target_key, edge_dict in target_and_edge_dict.items():
                target = node_by_key.get(target_key, target_key)
                retval.add_edge(
                    source, target,
                    **from_state_dict(graph.edges[source, target], edge_dict))
        return retval

    register_serialization_state(
        cls,
        ty_to_state_dict,  # type: ignore[no-untyped-call]
        ty_from_state_dict)
Ejemplo n.º 3
0
def register_serialization_functions():
    """Register haiku's ``FlatMapping`` type with flax serialization.

    Safe to call repeatedly: the module-level ``already_registered`` flag
    makes the registration happen at most once. haiku is imported lazily, so
    it is only required when this function is actually called.
    """
    global already_registered  # noqa: W0603
    if not already_registered:
        already_registered = True
        import haiku  # noqa: E0401

        # Obtain the FlatMapping type from an instance built by haiku itself.
        FlatMappingType = type(haiku.data_structures.to_immutable_dict({"ciao": 1}))

        def serialize_flat_mapping(flat_mapping):
            # Serialize as a plain dict. Nested values (inner mappings, arrays)
            # are converted by flax's own registered handlers, so the result
            # contains no FlatMapping objects.
            return serialization.to_state_dict(dict(flat_mapping))

        def deserialize_flat_mapping(flat_mapping, state):
            # flax calls this as (target, state). Every entry of the target is
            # restored from `state`; returning the target alone would make
            # restoring a silent no-op. Then re-wrap as an immutable mapping.
            restored = serialization.from_state_dict(dict(flat_mapping), state)
            return haiku.data_structures.to_immutable_dict(restored)

        serialization.register_serialization_state(
            FlatMappingType,
            serialize_flat_mapping,
            deserialize_flat_mapping,
        )
Ejemplo n.º 4
0
        raise ValueError(
            'A mutable collection should not be transformed by Jax.')
    meta = (type(collection), collection._anchor)  # pylint: disable=protected-access
    return (collection.state, ), meta


def collection_from_iterable(meta, state):
    """Rebuild a collection from its flattened pytree form.

    Args:
        meta: the ``(collection type, anchor)`` pair stored as aux data.
        state: the pytree children; the first one is the collection state.

    Returns:
        A new instance of the stored type with the anchor reattached.
    """
    collection_type, saved_anchor = meta
    rebuilt = collection_type(state[0])
    # The anchor is not a constructor argument, so it is set by hand.
    rebuilt._anchor = saved_anchor  # pylint: disable=protected-access
    return rebuilt


# make sure a collection is traced.
# Registering Collection as a pytree lets JAX transformations see inside it.
# The collection state is the single pytree child; its type and anchor travel
# as static aux data (see iterate_collection / collection_from_iterable).
jax.tree_util.register_pytree_node(Collection, iterate_collection,
                                   collection_from_iterable)


def _collection_state_dict(collection):
    """Serialize `collection` via flax's plain-dict state-dict handler."""
    as_plain_dict = collection.as_dict()
    return serialization._dict_state_dict(as_plain_dict)  # pylint: disable=protected-access


def _collection_from_state_dict(xs, state):
    """Restore a collection from `state`, using `xs` as the target structure.

    Returns a new ``Collection`` built only from the restored dict; no anchor
    is set on it here.
    """
    template = xs.as_dict()
    return Collection(
        serialization._restore_dict(template, state))  # pylint: disable=protected-access


# Serialize Collection through flax's plain-dict helpers; restoring produces a
# new Collection built from the restored dict.
serialization.register_serialization_state(Collection, _collection_state_dict,
                                           _collection_from_state_dict)
Ejemplo n.º 5
0
        "sampler_state": serialization.to_state_dict(vstate.sampler_state),
        "n_samples": vstate.n_samples,
        "n_discard": vstate.n_discard,
    }
    return state_dict


def deserialize_MCState(vstate, state_dict):
    """Return a copy of `vstate` whose state is loaded from `state_dict`.

    `vstate` itself is not modified. It acts as the template: a shallow copy
    is taken and ``reset()`` is called on the copy. Then ``variables`` and
    ``sampler_state`` are restored through flax, and ``n_samples`` and
    ``n_discard`` are copied verbatim from `state_dict`.
    """
    import copy

    restored = copy.copy(vstate)
    restored.reset()

    # Restored against the original vstate's values, which give flax the
    # target structure.
    for attr in ("variables", "sampler_state"):
        setattr(
            restored,
            attr,
            serialization.from_state_dict(getattr(vstate, attr), state_dict[attr]),
        )
    # Sample-count settings are stored as-is.
    for attr in ("n_samples", "n_discard"):
        setattr(restored, attr, state_dict[attr])

    return restored


# Make flax's state-dict machinery use serialize_MCState / deserialize_MCState
# for MCState objects.
serialization.register_serialization_state(
    MCState,
    serialize_MCState,
    deserialize_MCState,
)
Ejemplo n.º 6
0
    def __dir__(self):
        if isinstance(self._data, dict):
            return list(self._data.keys())
        elif isinstance(self._data, FrozenDict):
            return list(self._data._dict.keys())
        else:
            return []

    def __repr__(self):
        return f'{self._data}'

    def __hash__(self):
        # Note: will only work when wrapping FrozenDict.
        return hash(self._data)

    def copy(self, **kwargs):
        # Copy the wrapped mapping (forwarding kwargs to its own copy()) and
        # rewrap the result in that mapping's class, not in DotGetter.
        source = self._data
        wrapper_cls = source.__class__
        duplicate = source.copy(**kwargs)
        return wrapper_cls(duplicate)


# DotGetter as a pytree: its only child is the wrapped mapping, and it has no
# aux data. Unflattening returns that child directly, so after a
# flatten/unflatten round trip the raw mapping comes back, not a DotGetter.
tree_util.register_pytree_node(
    DotGetter,
    lambda x: ((x._data, ), ()),  # pylint: disable=protected-access
    lambda _, data: data[0])

# Note: restores as raw dict, intentionally.
# Both directions reuse flax's plain-dict handlers, so no DotGetter wrapper is
# rebuilt when restoring.
serialization.register_serialization_state(
    DotGetter,
    serialization._dict_state_dict,  # pylint: disable=protected-access
    serialization._restore_dict)  # pylint: disable=protected-access
Ejemplo n.º 7
0
  """
  if isinstance(x, FrozenDict):
    # deep copy internal state of a FrozenDict
    # the dict branch would also work here but
    # it is much less performant because jax.tree_map
    # uses an optimized C implementation.
    return jax.tree_map(lambda y: y, x._dict)
  elif isinstance(x, dict):
    ys = {}
    for key, value in x.items():
      ys[key] = unfreeze(value)
    return ys
  else:
    return x


def _frozen_dict_state_dict(xs):
  """Convert a FrozenDict into a plain dict of per-value state dicts."""
  state = {}
  for key, value in xs.items():
    state[key] = serialization.to_state_dict(value)
  return state


def _restore_frozen_dict(xs, states):
  """Restore a FrozenDict from `states`, using `xs` as the target structure.

  Each value of `xs` is restored from the entry of `states` with the same
  key. Entries of `states` that are not keys of `xs` are ignored.

  Args:
    xs: the target FrozenDict giving the structure of each value.
    states: the state dict to restore from.

  Returns:
    A new FrozenDict with the restored values.

  Raises:
    KeyError: if `states` lacks any key of `xs`. All missing keys are
      reported together, instead of a bare KeyError for the first one hit.
  """
  missing = [key for key in xs.keys() if key not in states]
  if missing:
    raise KeyError(
        f'State dict is missing key(s) {missing} needed to restore a '
        f'FrozenDict.')
  return FrozenDict(
      {key: serialization.from_state_dict(value, states[key])
       for key, value in xs.items()})


# FrozenDict is stored as a plain dict of per-value state dicts and restored
# into a new FrozenDict using the target's keys.
serialization.register_serialization_state(
    FrozenDict,
    _frozen_dict_state_dict,
    _restore_frozen_dict)
Ejemplo n.º 8
0
def dataclass(clz=None, *, init_doc=MISSING, cache_hash=False):
    """
    Decorator creating a NetKet-flavour dataclass.
    This behaves as a flax dataclass, that is a Frozen python dataclass, with a twist!
    See their documentation for standard behaviour.

    The new functionalities added by NetKet are:
     - it is possible to define a method `__pre_init__(*args, **kwargs) -> Tuple[Tuple,Dict]` that processes the arguments
       and keyword arguments provided to the dataclass constructor. This allows to deprecate argument
       names and add some logic to customize the constructors.
       This function should return a tuple of the edited `(args, kwargs)`. If inheriting from other classes it is recommended
       (though not mandated) to call the same method in parent classes.
       The function should return arguments and keyword arguments that will match the standard dataclass constructor.
       The function may not be called in some internal cases, so it should not be a strict requirement to execute it.

     - Cached Properties. It is possible to mark properties of a netket dataclass with `@property_cached`. This will make the
       property behave as a standard property, but its value is cached and reset every time a dataclass is manipulated.
       Cached properties can be part of the flattened pytree or not. See :ref:`netket.utils.struct.property_cached` for more info.

     - Fields whose metadata sets ``serialize=False`` are left out of the flax state dict; when restoring, they keep
       the value of the target instance.

    Args:
        clz: the class to decorate. When None, a decorator with the given options is returned, so both
            ``@dataclass`` and ``@dataclass(cache_hash=True)`` work.

    Optional Args:
        init_doc: the docstring for the init method. Otherwise it's inherited from `__pre_init__`.
        cache_hash: If True the hash is computed only once and cached. Use if the computation is expensive.

    """

    if clz is None:
        # Used as ``@dataclass(...)``: return a decorator bound to the options.
        return partial(dataclass, init_doc=init_doc, cache_hash=cache_hash)

    # get globals of the class to put generated methods in there
    _globals = get_class_globals(clz)
    _globals["Uninitialized"] = Uninitialized
    # process all cached properties
    process_cached_properties(clz, globals=_globals)
    # create the dataclass
    data_clz = dataclasses.dataclass(frozen=True)(clz)
    purge_cache_fields(data_clz)
    # attach the custom preprocessing of init arguments
    attach_preprocess_init(
        data_clz, globals=_globals, init_doc=init_doc, cache_hash=cache_hash
    )
    if cache_hash:
        replace_hash_method(data_clz, globals=_globals)

    # flax stuff: identify states
    meta_fields = []
    data_fields = []
    for name, field_info in getattr(data_clz, _FIELDS, {}).items():
        is_pytree_node = field_info.metadata.get("pytree_node", True)
        if is_pytree_node:
            data_fields.append(name)
        else:
            meta_fields.append(name)

    # List the cache fields
    cache_fields = []
    for name, cp in getattr(data_clz, _CACHES, {}).items():
        cache_fields.append(cp.cache_name)
        # they count as meta fields
        meta_fields.append(cp.cache_name)

    def replace(self, **updates):
        """Returns a new object replacing the specified fields with new values."""
        # reset cached fields: any change may invalidate cached values
        for name in cache_fields:
            updates[name] = Uninitialized

        return dataclasses.replace(self, **updates)

    data_clz.replace = replace

    # support for jax pytree flattening unflattening
    def iterate_clz(x):
        meta = tuple(getattr(x, name) for name in meta_fields)
        data = tuple(getattr(x, name) for name in data_fields)
        return data, meta

    def clz_from_iterable(meta, data):
        meta_args = tuple(zip(meta_fields, meta))
        data_args = tuple(zip(data_fields, data))
        kwargs = dict(meta_args + data_args)
        # The values already went through preprocessing when the instance was
        # first built, so ask the constructor to skip it.
        return data_clz(__skip_preprocess=True, **kwargs)

    jax.tree_util.register_pytree_node(data_clz, iterate_clz, clz_from_iterable)

    # flax serialization
    # Fields with ``serialize=False`` metadata are not written to or read from
    # the state dict (a set, since it is queried once per data field).
    skip_serialize_fields = {
        name
        for name, field_info in data_clz.__dataclass_fields__.items()
        if not field_info.metadata.get("serialize", True)
    }

    def to_state_dict(x):
        state_dict = {
            name: serialization.to_state_dict(getattr(x, name))
            for name in data_fields
            if name not in skip_serialize_fields
        }
        return state_dict

    def from_state_dict(x, state):
        """Restore the state of a data class.

        Raises:
            ValueError: if a serialized data field is missing from `state`, or
                `state` contains keys that are not serialized data fields.
        """
        state = state.copy()  # copy the state so we can pop the restored fields.
        updates = {}
        for name in data_fields:
            if name not in skip_serialize_fields:
                if name not in state:
                    raise ValueError(
                        f"Missing field {name} in state dict while restoring"
                        f" an instance of {clz.__name__}"
                    )
                value = getattr(x, name)
                value_state = state.pop(name)
                updates[name] = serialization.from_state_dict(value, value_state)
        if state:
            names = ",".join(state.keys())
            raise ValueError(
                f'Unknown field(s) "{names}" in state dict while'
                f" restoring an instance of {clz.__name__}"
            )
        return x.replace(**updates)

    serialization.register_serialization_state(data_clz, to_state_dict, from_state_dict)

    return data_clz
Ejemplo n.º 9
0
def _register_dataclass_pytree(
    cls: Type[T],
    static_fields: Sequence[str] = (),
    make_immutable: bool = True,
) -> Type[T]:
    """Register dataclass `cls` as a JAX pytree and with flax serialization.

    Args:
        cls: a class already decorated with ``@dataclasses.dataclass``.
        static_fields: names of fields stored as static pytree aux data rather
            than as children. Static fields registered (through this
            function) for superclasses of `cls` are added automatically.
        make_immutable: if True, assigning attributes on an instance outside
            of ``__init__`` raises ``dataclasses.FrozenInstanceError``.

    Returns:
        `cls` itself, modified in place.
    """

    assert dataclasses.is_dataclass(cls)

    # Respect static field registration from superclasses
    static_fields_list = list(static_fields)
    del static_fields

    for parent_class in filter(lambda x: x in _registered_static_fields,
                               cls.mro()):
        static_fields_list.extend(_registered_static_fields[parent_class])

    static_fields_set = set(static_fields_list)
    assert len(static_fields_list) == len(
        static_fields_set), "Found repeated field names!"

    _registered_static_fields[cls] = static_fields_set

    # Get a list of fields in our dataclass
    field_names = [f.name for f in dataclasses.fields(cls)]
    children_fields = [
        name for name in field_names if name not in static_fields_set
    ]
    # Static fields in declaration order. Iterating the set directly would give
    # an order that can differ between processes (string hash randomization),
    # making the pytree aux data layout non-deterministic.
    static_field_names = [
        name for name in field_names if name in static_fields_set
    ]
    assert set(field_names) == set(children_fields) | set(
        static_fields_set), "Field name anomoly; check static fields list!"

    # Define flatten, unflatten operations: this simply converts our dataclass
    # to a list of fields.
    def _flatten(obj):
        children = [getattr(obj, key) for key in children_fields]
        aux_data = tuple(getattr(obj, key) for key in static_field_names)
        return children, aux_data

    def _unflatten(treedef, children):
        # Alternative that bypasses __init__:
        #     return dataclasses.replace(
        #         cls.__new__(cls),
        #         **dict(zip(children_fields, children)),
        #         **dict(zip(static_field_names, treedef)),
        #     )
        return cls(
            **dict(zip(children_fields, children)),
            **dict(zip(static_field_names, treedef)),
        )

    jax.tree_util.register_pytree_node(cls, _flatten, _unflatten)

    # Serialization: this is mostly copied from `flax.struct.dataclass`.
    # Unlike flax, every field (static ones included) goes into the state dict.
    def _to_state_dict(x: T):
        state_dict = {
            name: serialization.to_state_dict(getattr(x, name))
            for name in field_names
        }
        return state_dict

    def _from_state_dict(x: T, state: Dict):
        # Copy the state so we can pop the restored fields.
        state = state.copy()
        updates = {}
        for name in field_names:
            if name not in state:
                raise ValueError(
                    f"Missing field {name} in state dict while restoring"
                    f" an instance of {cls.__name__}")
            value = getattr(x, name)
            value_state = state.pop(name)
            updates[name] = serialization.from_state_dict(value, value_state)
        if state:
            names = ",".join(state.keys())
            raise ValueError(f'Unknown field(s) "{names}" in state dict while'
                             f" restoring an instance of {cls.__name__}")
        return dataclasses.replace(x, **updates)

    serialization.register_serialization_state(cls, _to_state_dict,
                                               _from_state_dict)

    # Make dataclass immutable after __init__ is called
    # Similar to dataclasses.dataclass(frozen=True), but a bit friendlier for custom
    # __init__ functions
    if make_immutable:
        original_init = cls.__init__ if hasattr(cls, "__init__") else None

        def disabled_setattr(*args, **kwargs):
            raise dataclasses.FrozenInstanceError(
                "Dataclass registered as PyTrees is immutable!")

        def new_init(self, *args, **kwargs):
            # Temporarily allow attribute assignment while __init__ runs.
            cls.__setattr__ = object.__setattr__
            try:
                if original_init is not None:
                    original_init(self, *args, **kwargs)
            finally:
                # Re-freeze even if __init__ raised; otherwise the class would
                # stay mutable for every instance from then on.
                cls.__setattr__ = disabled_setattr

        cls.__setattr__ = disabled_setattr  # type: ignore
        cls.__init__ = new_init  # type: ignore

    return cls
Ejemplo n.º 10
0
    def __str__(self):
        return "ExactState(" + "hilbert = {}, ".format(self.hilbert)


# serialization


def serialize_ExactState(vstate):
    """Return the flax state dict of an ExactState, which holds only its variables."""
    return dict(variables=serialization.to_state_dict(vstate.variables))


def deserialize_ExactState(vstate, state_dict):
    """Return a copy of `vstate` whose variables are loaded from `state_dict`.

    `vstate` is not modified: a shallow copy is taken and ``reset()`` is
    called on it. Its ``variables`` are then restored through flax, using the
    original `vstate.variables` as the target structure.
    """
    import copy

    restored = copy.copy(vstate)
    restored.reset()

    loaded_variables = serialization.from_state_dict(
        vstate.variables, state_dict["variables"]
    )
    restored.variables = loaded_variables
    return restored


# Make flax's state-dict machinery use serialize_ExactState /
# deserialize_ExactState for ExactState objects.
serialization.register_serialization_state(
    ExactState,
    serialize_ExactState,
    deserialize_ExactState,
)
Ejemplo n.º 11
0
    Args:
      state: the class state.
      state_dict: the state dict containing the desired new state of the object.

    Returns:
      The restored class object.
    """

        state = serialization.from_state_dict(state, state_dict['state'])
        return self.replace(state=state)


# The unbound methods are used directly: flax calls `state_dict(obj)` to save
# and `restore_state(obj, state_dict)` to load. `override=True` is passed,
# presumably so that re-running this module does not fail on an existing
# registration for the type -- confirm against the flax version in use.
serialization.register_serialization_state(
    TrainingMetricsGrabber,
    TrainingMetricsGrabber.state_dict,
    TrainingMetricsGrabber.restore_state,
    override=True)


def run_in_parallel(function, list_of_kwargs_to_function, num_workers):
    """Run a function on a list of kwargs in parallel with ThreadPoolExecutor.

  Adapted from code by mlbileschi.
  Args:
    function: a function.
    list_of_kwargs_to_function: list of dictionary from string to argument
      value. These will be passed into `function` as kwargs.
    num_workers: int.

  Returns:
Ejemplo n.º 12
0
    state_dict: a state dict containing the desired new state of the object.

  Returns:
    The restored class object.
  """
  checkpoint_state.pytree = serialization.from_state_dict(
      checkpoint_state.pytree, state_dict['pytree'])
  checkpoint_state.pystate = serialization.from_state_dict(
      checkpoint_state.pystate, state_dict['pystate'])
  return checkpoint_state


# Note that this will only be used if we call
# `flax_checkpoints.restore_checkpoint` with the `target` arg set to a
# CheckpointState object to be filled in with values.
# `override=True` presumably lets this replace an earlier registration for
# CheckpointState -- confirm against the flax version in use.
serialization.register_serialization_state(
    CheckpointState, _ckpt_state_dict, _ckpt_restore_state, override=True)


def _save_checkpoint_background_catch_error(*args, **kwargs):
  """Run save_checkpoint(*args, **kwargs), recording any raised exception.

  Nothing is propagated to the caller. On success the module-level
  `_save_checkpoint_background_error` is reset to None. On failure (any
  BaseException) the error is logged with its traceback and stored there.
  """
  global _save_checkpoint_background_error
  try:
    save_checkpoint(*args, **kwargs)
  except BaseException as err:  # pylint: disable=broad-except
    logging.exception('Error while saving checkpoint in background.')
    _save_checkpoint_background_error = err
  else:
    _save_checkpoint_background_error = None


def wait_for_checkpoint_save():
  """Wait until last checkpoint save (if any) to finish."""