def __init__(self, value: str):
    self.value = int(value)
    if self.value <= 1:
        if self.value == 1:
            raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue')
        else:
            raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value))
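# A usage sketch for the class this constructor belongs to (AnonymousAxis, as used by the parser below);
# hedged: assumes EinopsError is importable from einops.
# >>> AnonymousAxis('2').value
# 2
# >>> AnonymousAxis('1')  # raises EinopsError: unit axes are handled by the parser instead
# >>> AnonymousAxis('0')  # raises EinopsError: length must be positive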
def reshape(self, x, shape):
    if len(shape) == 0:
        return x  # poor support of scalars in mxnet
    if any(isinstance(dimension, UnknownSize) for dimension in shape):
        from einops import EinopsError
        raise EinopsError("Mxnet couldn't infer all dimensions statically, please provide those with axes_lengths")
    return x.reshape(shape)
def decompose(self, x, known_axes_lengths: dict[str, int]):
    xp = x.__array_namespace__()
    shape = x.shape

    flat_shape = []

    for i, axis_group in enumerate(self.composed_shape):
        unknown_axis_name = None
        known_sizes_prod = 1
        for axis_name in axis_group:
            if axis_name in known_axes_lengths:
                known_sizes_prod *= known_axes_lengths[axis_name]
            else:
                if unknown_axis_name is None:
                    unknown_axis_name = axis_name
                else:
                    raise EinopsError("Can't infer the size")

        if unknown_axis_name is None:
            assert shape[i] == known_sizes_prod
        else:
            known_axes_lengths[unknown_axis_name] = shape[i] // known_sizes_prod

        for axis in axis_group:
            flat_shape.append(known_axes_lengths[axis])

    x = xp.reshape(x, flat_shape)
    return xp.permute_dims(x, self.decompose_transposition)
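# A behavioral sketch of decompose (hypothetical setup, not from the source): for a
# CompositionDecomposition with composed_shape=[['b'], ['h', 'w']] and an identity
# decompose_transposition, given x of shape (2, 12) and known_axes_lengths={'h': 3},
# the unknown 'w' is inferred as 12 // 3 == 4, known_axes_lengths is updated in place
# to {'h': 3, 'w': 4}, and x is reshaped to the flat shape (2, 3, 4) before transposition.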
def __init__(self, expression):
    self.has_ellipsis = False
    self.has_ellipsis_parenthesized = None
    self.identifiers = set()
    # that's axes like 2, 3 or 5. Axes with size 1 are exceptional and replaced with empty composition
    self.has_non_unitary_anonymous_axes = False
    # composition keeps structure of composite axes, see how different corner cases are handled in tests
    self.composition = []
    if '.' in expression:
        if '...' not in expression:
            raise EinopsError('Expression may contain dots only inside ellipsis (...)')
        if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
            raise EinopsError(
                'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ')
        expression = expression.replace('...', _ellipsis)
        self.has_ellipsis = True

    bracket_group = None

    def add_axis_name(x):
        if x is not None:
            if x in self.identifiers:
                raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
            if x == _ellipsis:
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
            else:
                is_number = str.isdecimal(x)
                if is_number and int(x) == 1:
                    # handling the case of anonymous axis of length 1
                    if bracket_group is None:
                        self.composition.append([])
                    else:
                        pass  # no need to think about 1s inside parenthesis
                    return
                is_axis_name, reason = self.check_axis_name(x, return_reason=True)
                if not (is_number or is_axis_name):
                    raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
                if is_number:
                    x = AnonymousAxis(x)
                self.identifiers.add(x)
                if is_number:
                    self.has_non_unitary_anonymous_axes = True
                if bracket_group is None:
                    self.composition.append([x])
                else:
                    bracket_group.append(x)

    current_identifier = None
    for char in expression:
        if char in '() ':
            add_axis_name(current_identifier)
            current_identifier = None
            if char == '(':
                if bracket_group is not None:
                    raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
                bracket_group = []
            elif char == ')':
                if bracket_group is None:
                    raise EinopsError('Brackets are not balanced')
                self.composition.append(bracket_group)
                bracket_group = None
        elif str.isalnum(char) or char in ['_', _ellipsis]:
            if current_identifier is None:
                current_identifier = char
            else:
                current_identifier += char
        else:
            raise EinopsError("Unknown character '{}'".format(char))

    if bracket_group is not None:
        raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
    add_axis_name(current_identifier)
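# A parsing sketch (doctest-style; assumes check_axis_name, defined elsewhere, accepts plain
# lowercase names, as it does in einops):
# >>> p = ParsedExpression('b (h w) c')
# >>> p.composition
# [['b'], ['h', 'w'], ['c']]
# >>> sorted(p.identifiers)
# ['b', 'c', 'h', 'w']
# >>> ParsedExpression('b 1 c').composition  # unit axes become empty groups
# [['b'], [], ['c']]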
def __init__(self, pattern: str):
    """
    :param pattern: example 'b t c <- b hsel wsel c, [hsel, wsel] b t'
    """
    self.pattern = pattern
    left, right = pattern.split('<-')
    arg_split = right.index(',')
    arr_pattern, ind_pattern = right[:arg_split], right[arg_split + 1:]
    ind_pattern = ind_pattern.strip()
    assert ind_pattern.startswith('['), \
        'composition axis should go first in indexer (second argument) [h w] i j k'
    composition_start = ind_pattern.index('[')
    composition_end = ind_pattern.index(']')
    composition = ind_pattern[composition_start + 1:composition_end]
    ind_other_axes = ind_pattern[composition_end + 1:]

    self.result_axes_names = left.split()
    self.array_axes_names = arr_pattern.split()
    self.indexing_axes_names = [x.strip() for x in composition.split(',')]
    self.indexer_other_axes_names = ind_other_axes.split()

    for group_name, group in [
        ('result', self.result_axes_names),
        ('array', self.array_axes_names),
        ('indexer', self.indexing_axes_names + self.indexer_other_axes_names),
    ]:
        if len(set(group)) != len(group):
            # TODO: report which axis is duplicated
            raise EinopsError(f'{group_name} pattern ({group}) contains a duplicated axis')

    axis_groups = [
        self.result_axes_names,
        self.array_axes_names,
        self.indexing_axes_names,
        self.indexer_other_axes_names,
    ]

    all_axes = set()
    for group in axis_groups:
        all_axes.update(group)

    self.indexer_axes = []
    self.batch_axes = []
    self.result_and_index_axes = []
    self.result_and_array_axes = []

    for axis in all_axes:
        presence = tuple(axis in g for g in axis_groups)
        # want match-case here. sweet dreams
        if presence == (False, True, True, False):
            self.indexer_axes.append(axis)
        elif presence[2]:
            raise EinopsError(f'Wrong usage of indexer variable {axis}')
        elif presence == (True, True, False, True):
            self.batch_axes.append(axis)
        elif presence == (True, False, False, True):
            self.result_and_index_axes.append(axis)
        elif presence == (True, True, False, False):
            self.result_and_array_axes.append(axis)
        else:
            # TODO better categorization of wrong usage patterns
            raise EinopsError(f'{axis} is used incorrectly in {pattern}')

    assert set(self.indexer_axes) == set(self.indexing_axes_names)
    # order of these variables matters, since we can't lose mapping here
    self.indexer_axes = self.indexing_axes_names

    self.array_composition = CompositionDecomposition(
        decomposed_shape=self.array_axes_names,
        composed_shape=[self.batch_axes + self.indexer_axes, self.result_and_array_axes],
    )

    self.index_composition = CompositionDecomposition(
        decomposed_shape=self.indexer_other_axes_names,
        # single axis after composition
        composed_shape=[self.batch_axes + self.result_and_index_axes],
    )

    self.result_composition = CompositionDecomposition(
        decomposed_shape=self.result_axes_names,
        composed_shape=[self.batch_axes + self.result_and_index_axes, self.result_and_array_axes],
    )
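# How axes from the docstring's example pattern are categorized
# ('b t c <- b hsel wsel c, [hsel, wsel] b t'):
#   indexer_axes          == ['hsel', 'wsel']  # in array and inside [...]
#   batch_axes            == ['b']             # in result, array, and indexer
#   result_and_index_axes == ['t']             # in result and indexer only
#   result_and_array_axes == ['c']             # in result and array only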
def _report_axes(axes: set, report_message: str):
    if len(axes) > 0:
        raise EinopsError(report_message.format(axes))
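# Usage sketch: a no-op for an empty set, raises otherwise (set display order may vary), e.g.
# _report_axes({'h', 'w'}, 'Axes {} are not used in pattern')
# -> EinopsError: Axes {'h', 'w'} are not used in pattern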
def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths):
    """
    EinMix - Einstein summation with automated tensor management and axis packing/unpacking.

    EinMix is an advanced tool, helpful tutorial:
    https://github.com/arogozhnikov/einops/blob/master/docs/3-einmix-layer.ipynb

    Imagine taking einsum with two arguments: one is the input, the other is a tensor with weights
    >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight)

    This layer manages the weights for you, and the syntax highlights the separate role of the weight matrix
    >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out')
    But otherwise it is the same einsum under the hood.

    Simple linear layer with bias term (you have one like that in your framework)
    >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20)
    There is no restriction to mix only the last axis. Let's mix along height
    >>> EinMix('h w c -> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32)
    Channel-wise multiplication (like one used in normalizations)
    >>> EinMix('t b c -> t b c', weight_shape='c', c=128)
    Separate dense layer within each head, no connection between different heads
    >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...)

    ... ah yes, you need to specify all dimensions of weight shape/bias shape in parameters.

    Use cases:
    - when the channel dimension is not last, use EinMix, not transposition
    - patch/segment embeddings
    - when only within-group connections are needed, to reduce the number of weights and computations
    - perfect as a part of sequential models
    - next-gen MLPs (follow the tutorial to learn more)

    Uniform He initialization is applied to the weight tensor and accounts for the number of elements mixed.

    Parameters
    :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output
    :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer
    :param bias_shape: axes of bias added to output.
    :param axes_lengths: dimensions of weight tensor
    """
    super().__init__()
    self.pattern = pattern
    self.weight_shape = weight_shape
    self.bias_shape = bias_shape
    self.axes_lengths = axes_lengths

    left_pattern, right_pattern = pattern.split('->')
    left = ParsedExpression(left_pattern)
    right = ParsedExpression(right_pattern)
    weight = ParsedExpression(weight_shape)
    _report_axes(
        set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}),
        'Unrecognized identifiers on the right side of EinMix {}')

    if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis:
        raise EinopsError('Ellipsis is not supported in EinMix (right now)')
    if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]):
        raise EinopsError('Anonymous axes (numbers) are not allowed in EinMix')
    if '(' in weight_shape or ')' in weight_shape:
        raise EinopsError(f'Parenthesis is not allowed in weight shape: {weight_shape}')

    pre_reshape_pattern = None
    pre_reshape_lengths = None
    post_reshape_pattern = None
    if any(len(group) != 1 for group in left.composition):
        names = []
        for group in left.composition:
            names += group
        composition = ' '.join(names)
        pre_reshape_pattern = f'{left_pattern}->{composition}'
        pre_reshape_lengths = {name: length for name, length in self.axes_lengths.items() if name in names}
    if any(len(group) != 1 for group in right.composition):
        names = []
        for group in right.composition:
            names += group
        composition = ' '.join(names)
        post_reshape_pattern = f'{composition}->{right_pattern}'

    self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {})

    for axis in weight.identifiers:
        if axis not in axes_lengths:
            raise EinopsError('Dimension {} of weight should be specified'.format(axis))
    _report_axes(
        set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}),
        'Axes {} are not used in pattern',
    )
    _report_axes(
        set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}),
        'Weight axes {} are redundant')
    if len(weight.identifiers) == 0:
        warnings.warn('EinMix: weight has no dimensions (means multiplication by a number)')

    _weight_shape = [axes_lengths[axis] for axis, in weight.composition]
    # single output element is a combination of fan_in input elements
    _fan_in = _product([axes_lengths[axis] for axis, in weight.composition if axis not in right.identifiers])
    if bias_shape is not None:
        if not isinstance(bias_shape, str):
            raise EinopsError('bias shape should be string specifying which axes bias depends on')
        bias = ParsedExpression(bias_shape)
        _report_axes(
            set.difference(bias.identifiers, right.identifiers),
            'Bias axes {} not present in output')
        _report_axes(
            set.difference(bias.identifiers, set(axes_lengths)),
            'Sizes not provided for bias axes {}',
        )

        _bias_shape = []
        for axes in right.composition:
            for axis in axes:
                if axis in bias.identifiers:
                    _bias_shape.append(axes_lengths[axis])
                else:
                    _bias_shape.append(1)
    else:
        _bias_shape = None
        _bias_input_size = None

    weight_bound = (3 / _fan_in) ** 0.5
    bias_bound = (1 / _fan_in) ** 0.5
    self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound)

    # rewrite einsum expression with single-letter latin identifiers so that
    # expression will be understood by any framework
    mapping2letters = {*left.identifiers, *right.identifiers, *weight.identifiers}
    mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapping2letters)}

    def write_flat(axes: list):
        return ''.join(mapping2letters[axis] for axis in axes)

    self.einsum_pattern: str = '{},{}->{}'.format(
        write_flat(left.flat_axes_order()),
        write_flat(weight.flat_axes_order()),
        write_flat(right.flat_axes_order()),
    )
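# A minimal usage sketch (assumes the PyTorch variant of this layer,
# importable as `from einops.layers.torch import EinMix`):
# >>> import torch
# >>> from einops.layers.torch import EinMix
# >>> mix = EinMix('b t (h d) -> b t (h dout)', weight_shape='h d dout', h=8, d=64, dout=32)
# >>> mix(torch.randn(2, 10, 8 * 64)).shape
# torch.Size([2, 10, 256])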