Example #1
0
def normalize_args_kwargs(target: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any]):
    """Normalize a call to ``target`` into a sorted, immutable kwargs-only mapping.

    Fill in default values for optional args, which are dependent on the
    schema (``target._schema``, converted to a signature first).

    Args:
        target: The op being called; must expose a ``_schema`` attribute.
        args: Positional arguments of the call (any length).
        kwargs: Keyword arguments of the call.

    Returns:
        An ``immutable_collections.immutable_dict`` whose keys are sorted.
        A ``self`` argument is renamed to ``input``. Any tuple/list argument
        whose first element is a ``torch.Tensor`` is replaced by one entry per
        element, keyed ``f"{name}_flattened_{i}"``.
    """
    sig = _torchscript_schema_to_signature(target._schema)
    _, new_kwargs = _args_kwargs_to_normalized_args_kwargs(
        sig, args, kwargs, normalize_to_only_use_kwargs=True
    )
    if "self" in new_kwargs:
        new_kwargs["input"] = new_kwargs.pop("self")

    # Flatten tensor lists for ops that take lists, such as torch.cat.
    # Only the first element is inspected to decide whether a value is a
    # tensor list.
    kept: Dict[str, Any] = {}
    flattened: Dict[str, Any] = {}
    for name, value in new_kwargs.items():
        is_tensor_list = (
            isinstance(value, (tuple, list))
            and len(value) > 0
            and isinstance(value[0], torch.Tensor)
        )
        if is_tensor_list:
            for i, item in enumerate(value):
                flattened[f"{name}_flattened_{i}"] = item
        else:
            kept[name] = value
    # Flattened entries take precedence on any key collision.
    kept.update(flattened)

    # Sort here in order to have consistency across TS graph and MLIR module.
    # NOTE(review): this is a plain string sort, so "x_flattened_10" orders
    # before "x_flattened_2"; the other side presumably relies on this exact
    # order, so do not change it in isolation.
    sorted_kwargs = dict(sorted(kept.items()))
    return immutable_collections.immutable_dict(sorted_kwargs)
Example #2
0
    def test_copy_it(self):
        """A deepcopy of an immutable_dict / immutable_list must equal its source."""
        pairs = [(3, 4), (5, 6)]
        originals = (immutable_dict(pairs), immutable_list(pairs))
        for original in originals:
            self.assertEqual(original, deepcopy(original))
Example #3
0
except ImportError:

    def disable_torchdynamo(x):
        """Identity fallback defined only when the guarded import raised ImportError.

        Hands ``x`` back untouched, so code that applies this helper
        (presumably as a decorator) keeps working without the real one.
        """
        return x


# Teach pytree how to take apart and rebuild the immutable containers.
# immutable_list: the leaves are its elements; no context is needed.
pytree._register_pytree_node(
    immutable_collections.immutable_list,
    lambda lst: (list(lst), None),
    lambda leaves, _ctx: immutable_collections.immutable_list(leaves),
)
# immutable_dict: the leaves are its values, and the keys (in iteration order)
# are kept as the context so the mapping can be rebuilt pairwise.
pytree._register_pytree_node(
    immutable_collections.immutable_dict,
    lambda mapping: (list(mapping.values()), list(mapping.keys())),
    lambda leaves, keys: immutable_collections.immutable_dict(
        dict(zip(keys, leaves))
    ),
)

# TODO: move this to PyTorch core. This overrides the pytree implementation for
# dict to maintain parity with DeepMind's pytree (dm-tree), where dicts are
# flattened in sorted-key order (see _dict_flatten below).
# Alias for the pytree "context" returned alongside the flattened leaves; for
# dicts it is the sorted key list produced by _dict_flatten.
Context = Any


def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    keys = sorted(d.keys())
    values = [d[key] for key in keys]
    return values, keys


def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: