Example #1
0
 def repl(m: Match[str]) -> str:
     """Turn one regex match into a saved attribute plus its replacement text.

     Reads ``info``, ``name`` and ``saved`` from the enclosing scope (not
     visible in this excerpt). Appends exactly one ``SavedAttribute`` to
     ``saved`` and returns the text to substitute for the match (presumably
     this is an ``re.sub`` callback -- confirm against the caller).
     """
     suffix_spec = info['suffix']
     # The suffix is either fixed text or computed from the match itself.
     if callable(suffix_spec):
         suffix: str = suffix_spec(m)
     else:
         suffix = suffix_spec
     # Save-time expression: a custom builder when one is given, otherwise
     # the matched text verbatim.
     if 'expr' in info:
         save_expr: str = info['expr'](name)
     else:
         save_expr = m.group(0)
     attr_name = name + suffix
     attr_type = info['nctype'](attr_name)
     saved.append(SavedAttribute(nctype=attr_type, expr=save_expr))
     # Without a custom eval-time expression, the formula refers to the
     # saved attribute by its name.
     if 'res' not in info:
         return attr_name
     return info['res'](name)
Example #2
0
def saved_variables(
    formula: str,
    nctypes: List[NamedCType],
    var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
    """Rewrite ``formula`` to use cheap precomputed attributes where possible.

    For every argument in ``nctypes``, sub-expressions that can be evaluated
    when the autograd Function is created (``x.sizes()``, ``x.numel()``,
    ``zeros_like(x)``, ...) are replaced by a name such as ``x_sizes``. A
    ``SavedAttribute`` describing what to save is recorded for each
    replacement. If the argument itself still appears in the rewritten
    formula afterwards, the argument is saved as well.

    Args:
        formula: the formula text to rewrite.
        nctypes: named C types of the arguments that may need saving.
        var_names: names the formula computes derivatives for; only consulted
            when a ``.strides()`` replacement is made.

    Returns:
        The rewritten formula, and the saved attributes in discovery order.
        One entry is appended per match, so the same attribute can appear
        more than once.

    Raises:
        AssertionError: if ``x.strides()`` is rewritten while ``var_names`` is
            not exactly ``(x,)``.
    """

    def stride_expr(name: str) -> str:
        # The strides replacement is only supported when the formula
        # differentiates exactly the tensor whose strides are queried.
        assert var_names == (name,), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.')
        return f'strides_or_error({name}, "{name}")'

    # Each entry maps a pattern (``{}`` stands for the argument name) to:
    #   'suffix' - appended to the argument name to form the saved attribute's
    #              name; a string or a callable taking the regex match.
    #   'nctype' - builds the NamedCType of the saved attribute from its name.
    #   'expr'   - optional; builds the save-time expression from the argument
    #              name (defaults to the matched text).
    #   'res'    - optional; builds the replacement text from the argument name
    #              (defaults to the saved attribute's name).
    # The '.' before a method name is escaped so it only matches a literal dot.
    REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
        # replace self.sizes() with self_sizes
        (r'{}\.sizes\(\)', {
            'suffix': '_sizes',
            'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
        }),
        # replace self.options() with self_options
        (r'{}\.options\(\)', {
            'suffix': '_options',
            'nctype': lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
        }),
        # replace zeros_like(self) with self_info
        (r'zeros_like\({}\)', {
            'suffix': '_info',
            'nctype': lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
            'expr': lambda name: name,  # at save-time
            'res': lambda name: name + '_info.zeros()',  # at eval-time
        }),
        # replace self.size(2) with self_argsize_2
        (r'{}\.size\((\w+)\)', {
            'suffix': lambda m: '_argsize_{}'.format(*m.groups()),
            'nctype': lambda name: NamedCType(name, BaseCType(intT)),
        }),
        # replace self.numel() with self_numel
        (r'{}\.numel\(\)', {
            'suffix': '_numel',
            'nctype': lambda name: NamedCType(name, BaseCType(intT)),
        }),
        # replace to_args_sizes(self) with self_args_sizes
        (r'to_args_sizes\({}\)', {
            'suffix': '_args_sizes',
            'nctype': lambda name: NamedCType(name, VectorCType(VectorCType(BaseCType(intT)))),
        }),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (r'to_args_scalartypes\({}\)', {
            'suffix': '_args_scalartypes',
            'nctype': lambda name: NamedCType(name, VectorCType(BaseCType(scalarTypeT))),
        }),
        # replace TensorGeometry(self) with self_geometry
        (r'TensorGeometry\({}\)', {
            'suffix': '_geometry',
            'nctype': lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
        }),
        # replace self.scalar_type() with self_scalar_type
        (r'{}\.scalar_type\(\)', {
            'suffix': '_scalar_type',
            'nctype': lambda name: NamedCType(name, BaseCType(scalarTypeT)),
        }),
        # replace self.dim() with self_dim
        (r'{}\.dim\(\)', {
            'suffix': '_dim',
            'nctype': lambda name: NamedCType(name, BaseCType(intT)),
        }),
        # replace self.strides() with self_strides
        (r'{}\.strides\(\)', {
            'suffix': '_strides',
            'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
            'expr': stride_expr,
        }),
    ]

    # find which arguments need to be saved
    saved: List[SavedAttribute] = []

    for nctype in nctypes:
        name = nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:
            # ``info`` and ``name`` are bound as defaults so the callback does
            # not rely on late binding of the enclosing loop variables.
            def repl(m: Match[str], info: Dict[str, Any] = info, name: str = name) -> str:
                suffix: str = info['suffix'](m) if callable(info['suffix']) else info['suffix']
                expr: str = info['expr'](name) if 'expr' in info else m.group(0)
                saved.append(SavedAttribute(
                    nctype=info['nctype'](name + suffix),
                    expr=expr,
                ))
                if 'res' in info:
                    return info['res'](name)
                return name + suffix

            # ``(?<!\w)`` keeps the pattern from matching inside a longer
            # identifier (e.g. ``grad_self.sizes()`` when ``name`` is
            # ``self``), which would otherwise produce a bogus rewrite.
            # re.escape keeps ``name`` literal.
            pattern = r'(?<!\w)' + regex.format(re.escape(name))
            formula = re.sub(pattern, repl, formula)

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(SavedAttribute(
                nctype=nctype,
                expr=name,
            ))

    return formula, tuple(saved)