Example #1
 def add_axis_name(x):
     if x is not None:
         if x in self.identifiers:
             raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
         if x == _ellipsis:
             self.identifiers.add(_ellipsis)
             if bracket_group is None:
                 self.composition.append(_ellipsis)
                 self.has_ellipsis_parenthesized = False
             else:
                 bracket_group.append(_ellipsis)
                 self.has_ellipsis_parenthesized = True
         else:
             is_number = str.isdecimal(x)
             if is_number and int(x) == 1:
                 # handling the case of anonymous axis of length 1
                 if bracket_group is None:
                     self.composition.append([])
                 else:
                     pass  # no need to think about 1s inside parenthesis
                 return
             is_axis_name, reason = self.check_axis_name(x, return_reason=True)
             if not (is_number or is_axis_name):
                 raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
             if is_number:
                 x = AnonymousAxis(x)
             self.identifiers.add(x)
             if is_number:
                 self.has_non_unitary_anonymous_axes = True
             if bracket_group is None:
                 self.composition.append([x])
             else:
                 bracket_group.append(x)
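
This function is a closure from ParsedExpression.__init__ (shown in full in Example #5): self, bracket_group and _ellipsis come from the enclosing scope. A quick sketch of the state it builds for a few patterns; note that einops.parsing is internal API and may change between versions:

    from einops.parsing import ParsedExpression

    expr = ParsedExpression('b (h w) c')
    print(expr.composition)          # [['b'], ['h', 'w'], ['c']]
    print(sorted(expr.identifiers))  # ['b', 'c', 'h', 'w']

    # an anonymous axis of length 1 becomes an empty group in the composition:
    print(ParsedExpression('b 1 c').composition)   # [['b'], [], ['c']]

    # a non-unitary anonymous axis is wrapped in AnonymousAxis:
    print(ParsedExpression('b 2 c').has_non_unitary_anonymous_axes)  # True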
Example #2
 def __init__(self, value: str):
     self.value = int(value)
     if self.value <= 1:
         if self.value == 1:
             raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue')
         else:
             raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value))
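
A short behavioral sketch of this constructor (AnonymousAxis is defined in einops.parsing, an internal module):

    from einops.parsing import AnonymousAxis

    print(AnonymousAxis('3').value)   # 3

    for bad in ['1', '0']:
        try:
            AnonymousAxis(bad)
        except Exception as e:
            print(type(e).__name__, e)
    # EinopsError: No need to create anonymous axis of length 1. ...
    # EinopsError: Anonymous axis should have positive length, not 0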
Example #3
 def reshape(self, x, shape):
     if len(shape) == 0:
         return x  # poor support of scalars in mxnet
     if any(isinstance(dimension, UnknownSize) for dimension in shape):
         from einops import EinopsError
         raise EinopsError("Mxnet couldn't infer all dimensions statically, please provide those with axes_lengths")
     return x.reshape(shape)
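
This method comes from einops' mxnet backend, where UnknownSize marks a dimension mxnet cannot infer statically. A self-contained sketch of the same guard with stand-in names (all names below are ours, not einops'):

    import numpy as np

    class UnknownSize:
        """Stand-in for a dimension that could not be inferred statically."""

    def safe_reshape(x, shape):
        if len(shape) == 0:
            return x  # skip reshape for scalars
        if any(isinstance(dim, UnknownSize) for dim in shape):
            raise ValueError("unknown dimensions: please provide them via axes_lengths")
        return x.reshape(shape)

    print(safe_reshape(np.arange(6), (2, 3)).shape)   # (2, 3)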
Example #4
    def decompose(self, x, known_axes_lengths: dict[str, int]):
        xp = x.__array_namespace__()
        shape = x.shape

        flat_shape = []

        for i, axis_group in enumerate(self.composed_shape):
            unknown_axis_name = None
            known_sizes_prod = 1
            for axis_name in axis_group:
                if axis_name in known_axes_lengths:
                    known_sizes_prod *= known_axes_lengths[axis_name]
                else:
                    if unknown_axis_name is None:
                        unknown_axis_name = axis_name
                    else:
                        raise EinopsError("Can't infer the size")

            if unknown_axis_name is None:
                assert shape[i] == known_sizes_prod
            else:
                known_axes_lengths[unknown_axis_name] = shape[i] // known_sizes_prod

            for axis in axis_group:
                flat_shape.append(known_axes_lengths[axis])

        x = xp.reshape(x, flat_shape)
        return xp.permute_dims(x, self.decompose_transposition)
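
The interesting part is the size inference: within each composed group, at most one axis may have an unknown length, which is then recovered by division. A self-contained illustration of that rule (names below are ours, not einops'):

    from math import prod

    known_axes_lengths = {'h': 4}
    axis_group, group_size = ['h', 'w'], 12   # one composed dimension of total size 12

    known_prod = prod(known_axes_lengths[a] for a in axis_group if a in known_axes_lengths)
    unknown = [a for a in axis_group if a not in known_axes_lengths]
    assert len(unknown) <= 1, "can't infer two unknown axes in one group"
    if unknown:
        known_axes_lengths[unknown[0]] = group_size // known_prod
    print(known_axes_lengths)   # {'h': 4, 'w': 3}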
Example #5
    def __init__(self, expression):
        self.has_ellipsis = False
        self.has_ellipsis_parenthesized = None
        self.identifiers = set()
        # anonymous axes like 2, 3 or 5; axes of size 1 are special-cased and replaced with an empty composition
        self.has_non_unitary_anonymous_axes = False
        # composition keeps structure of composite axes, see how different corner cases are handled in tests
        self.composition = []
        if '.' in expression:
            if '...' not in expression:
                raise EinopsError('Expression may contain dots only inside ellipsis (...)')
            if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
                raise EinopsError(
                    'Expression may contain dots only inside ellipsis (...); only one ellipsis per tensor')
            expression = expression.replace('...', _ellipsis)
            self.has_ellipsis = True

        bracket_group = None

        def add_axis_name(x):
            if x is not None:
                if x in self.identifiers:
                    raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
                if x == _ellipsis:
                    self.identifiers.add(_ellipsis)
                    if bracket_group is None:
                        self.composition.append(_ellipsis)
                        self.has_ellipsis_parenthesized = False
                    else:
                        bracket_group.append(_ellipsis)
                        self.has_ellipsis_parenthesized = True
                else:
                    is_number = str.isdecimal(x)
                    if is_number and int(x) == 1:
                        # handling the case of anonymous axis of length 1
                        if bracket_group is None:
                            self.composition.append([])
                        else:
                            pass  # no need to think about 1s inside parenthesis
                        return
                    is_axis_name, reason = self.check_axis_name(x, return_reason=True)
                    if not (is_number or is_axis_name):
                        raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
                    if is_number:
                        x = AnonymousAxis(x)
                    self.identifiers.add(x)
                    if is_number:
                        self.has_non_unitary_anonymous_axes = True
                    if bracket_group is None:
                        self.composition.append([x])
                    else:
                        bracket_group.append(x)

        current_identifier = None
        for char in expression:
            if char in '() ':
                add_axis_name(current_identifier)
                current_identifier = None
                if char == '(':
                    if bracket_group is not None:
                        raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
                    bracket_group = []
                elif char == ')':
                    if bracket_group is None:
                        raise EinopsError('Brackets are not balanced')
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ['_', _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise EinopsError("Unknown character '{}'".format(char))

        if bracket_group is not None:
            raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
        add_axis_name(current_identifier)
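
A behavioral sketch of the constructor above (again via the internal einops.parsing module, not public API):

    from einops.parsing import ParsedExpression

    expr = ParsedExpression('b ... c')
    print(expr.has_ellipsis)                # True
    print(expr.has_ellipsis_parenthesized)  # False

    try:
        ParsedExpression('b ((h) w)')
    except Exception as e:
        print(e)   # Axis composition is one-level (brackets inside brackets not allowed)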
Example #6
    def __init__(self, pattern: str):
        """
        :param pattern: example 'b t c <- b hsel wsel c, [hsel, wsel] b t'
        """
        self.pattern = pattern
        left, right = pattern.split('<-')
        arg_split = right.index(',')
        arr_pattern, ind_pattern = right[:arg_split], right[arg_split + 1:]
        ind_pattern = ind_pattern.strip()
        assert ind_pattern.startswith(
            '['
        ), "composed axes should go first in the indexer (second argument), e.g. '[h w] i j k'"
        composition_start = ind_pattern.index('[')
        composition_end = ind_pattern.index(']')
        composition = ind_pattern[composition_start + 1:composition_end]
        ind_other_axes = ind_pattern[composition_end + 1:]

        self.result_axes_names = left.split()
        self.array_axes_names = arr_pattern.split()
        self.indexing_axes_names = [x.strip() for x in composition.split(',')]
        self.indexer_other_axes_names = ind_other_axes.split()

        for group_name, group in [
            ('result', self.result_axes_names),
            ('array', self.array_axes_names),
            ('indexer', self.indexing_axes_names + self.indexer_other_axes_names),
        ]:
            if len(set(group)) != len(group):
                # TODO: report which axis is duplicated
                raise EinopsError(
                    f'{group_name} pattern ({group}) contains a duplicated axis'
                )

        axis_groups = [
            self.result_axes_names,
            self.array_axes_names,
            self.indexing_axes_names,
            self.indexer_other_axes_names,
        ]

        all_axes = set()
        for group in axis_groups:
            all_axes.update(group)

        self.indexer_axes = []
        self.batch_axes = []
        self.result_and_index_axes = []
        self.result_and_array_axes = []

        for axis in all_axes:
            presence = tuple(axis in g for g in axis_groups)
            # want match-case here. sweet dreams
            if presence == (False, True, True, False):
                self.indexer_axes.append(axis)
            elif presence[2]:
                raise EinopsError(f'Wrong usage of indexer variable {axis}')
            elif presence == (True, True, False, True):
                self.batch_axes.append(axis)
            elif presence == (True, False, False, True):
                self.result_and_index_axes.append(axis)
            elif presence == (True, True, False, False):
                self.result_and_array_axes.append(axis)
            else:
                # TODO better categorization of wrong usage patterns
                raise EinopsError(f'{axis} is used incorrectly in {pattern}')

        assert set(self.indexer_axes) == set(self.indexing_axes_names)
        # order of these variables matters, since we can't lose mapping here
        self.indexer_axes = self.indexing_axes_names

        self.array_composition = CompositionDecomposition(
            decomposed_shape=self.array_axes_names,
            composed_shape=[
                self.batch_axes + self.indexer_axes, self.result_and_array_axes
            ],
        )

        self.index_composition = CompositionDecomposition(
            decomposed_shape=self.indexer_other_axes_names,
            # single axis after composition
            composed_shape=[self.batch_axes + self.result_and_index_axes],
        )

        self.result_composition = CompositionDecomposition(
            decomposed_shape=self.result_axes_names,
            composed_shape=[
                self.batch_axes + self.result_and_index_axes,
                self.result_and_array_axes
            ],
        )
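
To make the presence-tuple dispatch concrete, here is a hedged, self-contained walk-through for the docstring's pattern 'b t c <- b hsel wsel c, [hsel, wsel] b t' (group order matches axis_groups above: result, array, indexing, indexer-other):

    groups = [
        ['b', 't', 'c'],              # result
        ['b', 'hsel', 'wsel', 'c'],   # array
        ['hsel', 'wsel'],             # indexing
        ['b', 't'],                   # indexer-other
    ]
    for axis in sorted({a for g in groups for a in g}):
        print(axis, tuple(axis in g for g in groups))
    # b    (True, True, False, True)   -> batch axis
    # c    (True, True, False, False)  -> result-and-array axis
    # hsel (False, True, True, False)  -> indexer axis
    # t    (True, False, False, True)  -> result-and-index axis
    # wsel (False, True, True, False)  -> indexer axis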
Example #7
def _report_axes(axes: set, report_message: str):
    if len(axes) > 0:
        raise EinopsError(report_message.format(axes))
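
A usage sketch of this helper (EinopsError is importable from einops; _report_axes itself is a private helper):

    from einops import EinopsError

    try:
        _report_axes({'w'}, 'Axes {} are not used in pattern')
    except EinopsError as e:
        print(e)   # Axes {'w'} are not used in pattern

    _report_axes(set(), 'never raised for an empty set {}')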
Example #8
    def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths):
        """
        EinMix - Einstein summation with automated tensor management and axis packing/unpacking.

        EinMix is an advanced tool, helpful tutorial:
        https://github.com/arogozhnikov/einops/blob/master/docs/3-einmix-layer.ipynb

        Imagine taking einsum with two arguments: one is the input, the other is a tensor with weights
        >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight)

        This layer manages the weights for you, and the syntax highlights the separate role of the weight matrix
        >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out')
        But otherwise it is the same einsum under the hood.

        Simple linear layer with bias term (you have one like that in your framework)
        >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20)
        There is no restriction to mixing only the last axis. Let's mix along height
        >>> EinMix('h w c-> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32)
        Channel-wise multiplication (like one used in normalizations)
        >>> EinMix('t b c -> t b c', weight_shape='c', c=128)
        Separate dense layer within each head, no connection between different heads
        >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...)

        ... ah yes, you need to specify all dimensions of weight shape/bias shape in parameters.

        Use cases:
        - when channel dimension is not last, use EinMix, not transposition
        - patch/segment embeddings
        - when only within-group connections are needed (reduces the number of weights and computations)
        - perfect as a part of sequential models
        - next-gen MLPs (follow tutorial to learn more)

        Uniform He initialization is applied to the weight tensor and accounts for the number of elements mixed.

        Parameters
        :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output
        :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized by the layer
        :param bias_shape: axes of bias added to output.
        :param axes_lengths: dimensions of weight tensor
        """
        super().__init__()
        self.pattern = pattern
        self.weight_shape = weight_shape
        self.bias_shape = bias_shape
        self.axes_lengths = axes_lengths

        left_pattern, right_pattern = pattern.split('->')
        left = ParsedExpression(left_pattern)
        right = ParsedExpression(right_pattern)
        weight = ParsedExpression(weight_shape)
        _report_axes(
            set.difference(right.identifiers,
                           {*left.identifiers, *weight.identifiers}),
            'Unrecognized identifiers on the right side of EinMix {}')

        if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis:
            raise EinopsError(
                'Ellipsis is not supported in EinMix (right now)')
        if any(x.has_non_unitary_anonymous_axes
               for x in [left, right, weight]):
            raise EinopsError(
                'Anonymous axes (numbers) are not allowed in EinMix')
        if '(' in weight_shape or ')' in weight_shape:
            raise EinopsError(
                f'Parenthesis is not allowed in weight shape: {weight_shape}')

        pre_reshape_pattern = None
        pre_reshape_lengths = None
        post_reshape_pattern = None
        if any(len(group) != 1 for group in left.composition):
            names = []
            for group in left.composition:
                names += group
            composition = ' '.join(names)
            pre_reshape_pattern = f'{left_pattern}->{composition}'
            pre_reshape_lengths = {
                name: length
                for name, length in self.axes_lengths.items() if name in names
            }

        if any(len(group) != 1 for group in right.composition):
            names = []
            for group in right.composition:
                names += group
            composition = ' '.join(names)
            post_reshape_pattern = f'{composition}->{right_pattern}'

        self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths,
                                      post_reshape_pattern, {})

        for axis in weight.identifiers:
            if axis not in axes_lengths:
                raise EinopsError(
                    'Dimension {} of weight should be specified'.format(axis))
        _report_axes(
            set.difference(set(axes_lengths),
                           {*left.identifiers, *weight.identifiers}),
            'Axes {} are not used in pattern',
        )
        _report_axes(
            set.difference(weight.identifiers,
                           {*left.identifiers, *right.identifiers}),
            'Weight axes {} are redundant')
        if len(weight.identifiers) == 0:
            warnings.warn(
                'EinMix: weight has no dimensions (means multiplication by a number)'
            )

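        # note: `for axis, in` below unpacks single-axis groups (weight_shape has no
        # parentheses, so each composition group holds exactly one axis)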
        _weight_shape = [axes_lengths[axis] for axis, in weight.composition]
        # single output element is a combination of fan_in input elements
        _fan_in = _product([
            axes_lengths[axis] for axis, in weight.composition
            if axis not in right.identifiers
        ])
        if bias_shape is not None:
            if not isinstance(bias_shape, str):
                raise EinopsError(
                    'bias shape should be string specifying which axes bias depends on'
                )
            bias = ParsedExpression(bias_shape)
            _report_axes(set.difference(bias.identifiers, right.identifiers),
                         'Bias axes {} not present in output')
            _report_axes(
                set.difference(bias.identifiers, set(axes_lengths)),
                'Sizes not provided for bias axes {}',
            )

            _bias_shape = []
            for axes in right.composition:
                for axis in axes:
                    if axis in bias.identifiers:
                        _bias_shape.append(axes_lengths[axis])
                    else:
                        _bias_shape.append(1)
        else:
            _bias_shape = None

        weight_bound = (3 / _fan_in)**0.5
        bias_bound = (1 / _fan_in)**0.5
        self._create_parameters(_weight_shape, weight_bound, _bias_shape,
                                bias_bound)

        # rewrite einsum expression with single-letter latin identifiers so that
        # expression will be understood by any framework
        mapping2letters = {
            *left.identifiers, *right.identifiers, *weight.identifiers
        }
        mapping2letters = {
            k: letter
            for letter, k in zip(string.ascii_lowercase, mapping2letters)
        }

        def write_flat(axes: list):
            return ''.join(mapping2letters[axis] for axis in axes)

        self.einsum_pattern: str = '{},{}->{}'.format(
            write_flat(left.flat_axes_order()),
            write_flat(weight.flat_axes_order()),
            write_flat(right.flat_axes_order()),
        )
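
A hedged usage sketch of the layer above, assuming the PyTorch binding (einops.layers.torch.EinMix):

    import torch
    from einops.layers.torch import EinMix

    # per-head dense layer: heads do not exchange information
    mix = EinMix('t b (head cin) -> t b (head cout)',
                 weight_shape='head cin cout', head=4, cin=16, cout=16)
    x = torch.randn(10, 2, 4 * 16)
    print(mix(x).shape)   # torch.Size([10, 2, 64])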