Beispiel #1
0
from .partitioners import default_partition
from .named_members_polyfill import _named_parameters, _named_buffers
from typing import Callable, List, Dict, Any, Tuple, Optional
from functools import wraps

try:
    from torchdynamo import disable as disable_torchdynamo
except ImportError:
    # torchdynamo is optional. When it is not installed, fall back to a
    # pass-through so callers can use ``disable_torchdynamo`` unconditionally.

    def disable_torchdynamo(x):
        """Stand-in for ``torchdynamo.disable`` that returns ``x`` untouched."""
        passthrough = x
        return passthrough


# Teach pytree how to take apart and rebuild fx's immutable containers.
# immutable_list: the elements are the leaves; no context is needed.
pytree._register_pytree_node(
    immutable_collections.immutable_list,
    lambda seq: (list(seq), None),
    lambda leaves, _unused_ctx: immutable_collections.immutable_list(leaves),
)
# immutable_dict: the values are the leaves; the key list is carried as
# context so the dict can be rebuilt with the same keys in the same order.
pytree._register_pytree_node(
    immutable_collections.immutable_dict,
    lambda mapping: (list(mapping.values()), list(mapping.keys())),
    lambda leaves, keys: immutable_collections.immutable_dict(
        dict(zip(keys, leaves))),
)

# TODO: move this to PyTorch core. This overrides the pytree implementation
# for dict to maintain parity with DeepMind's pytree.
# Alias for the opaque "context" value returned by a pytree flatten function
# alongside the leaves, and passed back to the matching unflatten function.
Context = Any

Beispiel #2
0
)

# Presumably the accepted type of vmap's ``in_dims``: one int for all inputs,
# or a tuple (possibly nested / containing None) describing each input.
# TODO confirm against the vmap signature.
in_dims_t = Union[int, Tuple]
# Presumably the accepted type of vmap's ``out_dims``: one int, or a tuple of
# ints (one per output). TODO confirm.
out_dims_t = Union[int, Tuple[int, ...]]


# Temporary OrderedDict registration as pytree
def _odict_flatten(d):
    return list(d.values()), list(d.keys())


def _odict_unflatten(values, context):
    return OrderedDict((key, value) for key, value in zip(context, values))


# Register OrderedDict as a pytree container. Its values become leaves, and its
# keys are kept as context so the dict can be rebuilt in the original order.
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)

# Checks that all args-to-be-batched have the same batch dim size


def _validate_and_get_batch_size(flat_in_dims: List[Optional[int]],
                                 flat_args: List) -> int:
    batch_sizes = [
        arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args)
        if in_dim is not None
    ]
    if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
        raise ValueError(
            f'vmap: Expected all tensors to have the same size in the mapped '
            f'dimension, got sizes {batch_sizes} for the mapped dimension')
    return batch_sizes[0]
    dict, ['__delitem__', '__setitem__', 'clear', 'pop', 'popitem', 'update'])
# Pickle support: an immutable_dict is reconstructed by calling the class on
# an iterator over this instance's (key, value) pairs.
immutable_dict.__reduce__ = lambda self: (immutable_dict,
                                          (iter(self.items()), ))
# NOTE(review): presumably marks immutable_dict as backward-compatible public
# API via the torch.fx ``compatibility`` decorator. Confirm.
compatibility(is_backward_compatible=True)(immutable_dict)

# Register immutable collections for PyTree operations


def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    """Flatten a dict into its values (the leaves) and its keys (the context)."""
    leaves = list(d.values())
    keys = list(d.keys())
    return leaves, keys


def _immutable_dict_unflatten(values: List[Any],
                              context: Context) -> Dict[Any, Any]:
    """Rebuild an immutable_dict by pairing the keys in ``context`` with ``values``."""
    pairs = dict(zip(context, values))
    return immutable_dict(pairs)


def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
    """Flatten a list: the list itself (not a copy) is the leaves; there is no context."""
    no_context = None
    return d, no_context


def _immutable_list_unflatten(values: List[Any],
                              context: Context) -> List[Any]:
    """Rebuild an immutable_list from ``values``. ``context`` is ignored."""
    rebuilt = immutable_list(values)
    return rebuilt


# Register immutable_dict and immutable_list as pytree containers, using the
# flatten/unflatten helper pairs defined above.
_register_pytree_node(immutable_dict, _immutable_dict_flatten,
                      _immutable_dict_unflatten)
_register_pytree_node(immutable_list, _immutable_list_flatten,
                      _immutable_list_unflatten)