def __init__(self, start=None, forced_prefix="",
        key_translate_func=make_identifier_from_name,
        name_generator=None):
    """Create the map, optionally pre-populated from *start*.

    :arg start: initial key-to-name entries. Anything accepted by
        :class:`dict` works (a mapping or an iterable of key/value pairs).
        The entries are copied, so later changes to *start* are not seen.
    :arg forced_prefix: prefix given to the :class:`UniqueNameGenerator`
        that is created when *name_generator* is not supplied.
    :arg key_translate_func: callable handed, together with the name
        generator, to :class:`_KeyTranslatingUniqueNameGeneratorWrapper`.
    :arg name_generator: an existing :class:`UniqueNameGenerator` to use
        instead of creating a new one. Names from *start* that carry its
        ``forced_prefix`` are registered with it via ``add_name``, so the
        passed-in generator is mutated.
    :raises TypeError: if both *forced_prefix* and *name_generator* are
        passed.
    """
    # Copy once up front and work from the copy only. *start* may then be
    # any dict()-compatible input, including a one-shot iterator of pairs.
    self._dict = dict(start) if start is not None else {}

    if name_generator is None:
        name_generator = UniqueNameGenerator(forced_prefix=forced_prefix)
    elif forced_prefix:
        raise TypeError("passing 'forced_prefix' is not allowed when "
                "passing a pre-existing name generator")

    # Only names carrying the generator's forced prefix are registered.
    # Presumably generated names always carry that prefix, so other names
    # cannot collide with them (TODO confirm against UniqueNameGenerator).
    prefix = name_generator.forced_prefix
    for existing_name in self._dict.values():
        if existing_name.startswith(prefix):
            name_generator.add_name(existing_name)

    self._generator = _KeyTranslatingUniqueNameGeneratorWrapper(
            name_generator, key_translate_func)
class Namespace(Mapping[str, "Array"]):
    # Possible future extension: .parent attribute
    r"""
    Represents a mapping from :term:`identifier` strings to
    :term:`array expression`\ s or *None*, where *None* indicates that the
    name may not be used. (:class:`pytato.array.Placeholder` instances
    register their names in this way to avoid ambiguity.)

    .. attribute:: name_gen

    .. automethod:: __contains__
    .. automethod:: __getitem__
    .. automethod:: __iter__
    .. automethod:: __len__
    .. automethod:: assign
    .. automethod:: copy
    .. automethod:: ref
    """

    name_gen: UniqueNameGenerator

    def __init__(self) -> None:
        # Storage for the name -> array bindings exposed through the
        # Mapping interface below.
        self._symbol_table: Dict[str, Array] = {}
        # Generator that tracks every assigned name.
        self.name_gen = UniqueNameGenerator()

    def __contains__(self, name: object) -> bool:
        return name in self._symbol_table

    def __getitem__(self, name: str) -> Array:
        return self._symbol_table[name]

    def __iter__(self) -> Iterator[str]:
        return iter(self._symbol_table)

    def __len__(self) -> int:
        return len(self._symbol_table)

    def copy(self) -> Namespace:
        """Return a copy of this namespace. The copy is built by
        ``copy_namespace`` with a :class:`CopyMapper` targeting a fresh
        :class:`Namespace`.
        """
        from pytato.transform import CopyMapper, copy_namespace
        return copy_namespace(self, CopyMapper(Namespace()))

    def assign(self, name: str, value: Array) -> str:
        """Declare a new array.

        :param name: a Python identifier
        :param value: the array object

        :returns: *name*
        :raises ValueError: if *name* is already bound in this namespace.
        """
        if name in self._symbol_table:
            raise ValueError(f"'{name}' is already assigned")

        # Register the name with the generator unless it already knows it,
        # so that generated names will not clash with it.
        if not self.name_gen.is_name_conflicting(name):
            self.name_gen.add_name(name)

        self._symbol_table[name] = value
        return name

    def ref(self, name: str) -> Array:
        """
        :returns: An :term:`array expression` referring to *name*.

        .. note::

            Reads ``shape`` and ``dtype`` from the bound value, so *name*
            must be bound to an array (not *None*).
        """
        value = self[name]

        var_ref = prim.Variable(name)
        if value.shape:
            # Subscript with one index variable per axis: _0, _1, ....
            # These must be pymbolic Variables. Bare strings would end up
            # in the expression as string constants, not index names.
            var_ref = var_ref[tuple(
                prim.Variable("_%d" % i) for i in range(len(value.shape)))]

        return IndexLambda(self, expr=var_ref, shape=value.shape,
                dtype=value.dtype)