def from_info(cls, info, **kwargs):
    """Generate a set of contraction costs from a ``PathInfo`` object.

    Every input term is first recorded as a zero-flop 'contraction' with
    no involved indices. After that, one entry is added per item of
    ``info.contraction_list``, built from that item's equation string
    (``c[2]``).

    Parameters
    ----------
    info
        Object exposing ``size_dict``, ``input_subscripts`` and
        ``contraction_list``. Presumably an opt_einsum ``PathInfo``;
        confirm against callers.
    kwargs
        Accepted but unused.

    Returns
    -------
    ``cls(cs, size_dict)`` where ``cs`` is a list of dicts with keys
    ``'involved'``, ``'legs'``, ``'size'`` and ``'flops'``.
    """
    cs = []
    # Copy so the new object never aliases (or mutates) the size dict
    # owned by ``info``. This also matches ``from_pathinfo``.
    size_dict = info.size_dict.copy()

    # add all the input 'contractions'
    for term in info.input_subscripts.split(','):
        cs.append({
            'involved': oset(),
            'legs': oset(term),
            'size': compute_size_by_dict(term, size_dict),
            'flops': 0,
        })

    for c in info.contraction_list:
        # c[2] is the equation string of this step, e.g. 'ab,bc->ac'
        eq = c[2]
        lhs, rhs = eq.split('->')
        legs = oset(rhs)
        # every index that appears on any operand of this step
        involved = oset.union(*map(oset, lhs.split(',')))

        cs.append({
            'involved': involved,
            'legs': legs,
            'size': compute_size_by_dict(legs, size_dict),
            'flops': flop_count(involved, c[1], 2, size_dict),
        })

    return cls(cs, size_dict)
def from_pathinfo(cls, pathinfo, **kwargs):
    """Build contraction costs from a ``PathInfo`` object.

    The result holds one zero-flop ``Contraction`` per input term (with an
    empty ``involved`` set), followed by one ``Contraction`` per entry of
    ``pathinfo.contraction_list``. ``kwargs`` is accepted but ignored.
    The size dictionary is copied, so the returned object does not share
    it with ``pathinfo``.
    """
    size_dict = pathinfo.size_dict.copy()

    # the inputs themselves, treated as free (zero-flop) 'contractions'
    costs = [
        Contraction(
            involved=frozenset(),
            legs=frozenset(term),
            size=compute_size_by_dict(term, size_dict),
            flops=0,
        )
        for term in pathinfo.input_subscripts.split(',')
    ]

    for step in pathinfo.contraction_list:
        # step[2] holds this step's equation string; step[1] is passed
        # straight through to flop_count
        operands, result = step[2].split('->')
        out_legs = frozenset(result)
        operand_legs = [frozenset(part) for part in operands.split(',')]
        all_legs = frozenset.union(*operand_legs)

        costs.append(Contraction(
            involved=all_legs,
            legs=out_legs,
            size=compute_size_by_dict(out_legs, size_dict),
            flops=flop_count(all_legs, step[1], 2, size_dict),
        ))

    return cls(costs, size_dict)
def get_size(self, node):
    """Return the tensor size of ``node``.

    If ``self.info[node]['size']`` is already present it is returned
    directly. Otherwise the size is computed with ``compute_size_by_dict``
    from ``self.get_legs(node)`` and ``self.size_dict``, then stored back
    into ``self.info[node]['size']`` so later calls reuse it.
    """
    try:
        return self.info[node]['size']
    except KeyError:
        pass  # not cached yet: fall through and compute it

    computed = compute_size_by_dict(self.get_legs(node), self.size_dict)
    self.info[node]['size'] = computed
    return computed
def calc_node_weight_float(term, size_dict, scale='linear'):
    """Return the weight of a node with indices ``term``.

    Parameters
    ----------
    term : iterable
        The indices of the node.
    size_dict : dict
        Mapping of index to dimension size.
    scale : {'linear', 'log', 'exp', 'const', None, False}, optional
        How to turn the node's size (as given by ``compute_size_by_dict``)
        into a weight:

        - ``'const'``, ``None`` or ``False``: always ``1.0``;
        - ``'linear'``: the size itself;
        - ``'log'``: ``log2`` of the size;
        - ``'exp'``: ``2 ** size``.

        Any other value falls back to the linear weight.
    """
    if scale in ('const', None, False):
        return 1.0

    w = compute_size_by_dict(term, size_dict)

    # No scaling or jitter is applied here: 'linear' (and any unrecognised
    # scale) returns the raw size unchanged.
    if scale == 'log':
        return math.log2(w)
    if scale == 'exp':
        # NOTE(review): 2**size grows extremely quickly; presumably only
        # meant for small sizes. Confirm with callers.
        return 2**w
    return w
def __init__(
    self,
    eq,
    arrays,
    sliced,
    optimize='auto',
    size_dict=None,
):
    """Prepare a contraction of ``eq`` that is split over ``sliced`` indices.

    Parameters
    ----------
    eq : str
        Equation of the form ``'<inputs>-><output>'``, with input terms
        separated by commas.
    arrays : iterable
        The input arrays; presumably one per input term of ``eq``. Confirm
        against ``create_size_dict``.
    sliced : iterable of str
        Indices to slice over. Any iterable is accepted, including a
        one-shot iterator. It is materialized once into ``self.sliced``,
        ordered by first appearance in ``eq``.
    optimize : optional
        Passed to ``contract_path`` when finding the path for one slice.
    size_dict : dict, optional
        Mapping of index to dimension size. If not given, it is built with
        ``create_size_dict(self.inputs, self.arrays)``.
    """
    # basic info
    lhs, self.output = eq.split('->')
    self.inputs = lhs.split(',')
    self.arrays = tuple(arrays)
    # Materialize ``sliced`` exactly once. It may be a one-shot iterator,
    # and it is needed again below to strip the sliced indices from ``eq``.
    self.sliced = tuple(sorted(sliced, key=eq.index))
    sliced_set = frozenset(self.sliced)

    if size_dict is None:
        size_dict = create_size_dict(self.inputs, self.arrays)
    self.size_dict = size_dict

    # find which arrays are going to be sliced (changing) or not (constant)
    self.constant, self.changing = [], []
    for i, term in enumerate(self.inputs):
        if any(ix in sliced_set for ix in term):
            self.changing.append(i)
        else:
            self.constant.append(i)

    # information about the contraction of a single slice
    self.eq_sliced = "".join(c for c in eq if c not in sliced_set)
    self.sliced_sizes = tuple(self.size_dict[i] for i in self.sliced)
    self.nslices = compute_size_by_dict(self.sliced, self.size_dict)
    self.shapes_sliced = tuple(
        tuple(self.size_dict[i] for i in term)
        for term in self.eq_sliced.split('->')[0].split(',')
    )
    self.path, self.info_sliced = contract_path(
        self.eq_sliced, *self.shapes_sliced, shapes=True, optimize=optimize
    )

    # generate the (reusable) contraction expression for a single slice
    self._expr = contract_expression(
        self.eq_sliced, *self.shapes_sliced, optimize=self.path
    )
def __init__(self, inputs, output, size_dict, track_childless=False,
             track_size=False, track_flops=False):
    """Create a contraction tree with only its root node registered.

    Parameters
    ----------
    inputs : iterable of iterable
        Index terms of the inputs; each is stored as a ``frozenset``.
    output : iterable
        Output indices, stored as a ``frozenset``.
    size_dict : dict
        Mapping of index to dimension size.
    track_childless : bool, optional
        If true, maintain a ``childless`` set of dangling nodes, seeded
        with the root.
    track_size : bool, optional
        If true, initialise a running ``_size`` counter to 0.
    track_flops : bool, optional
        If true, initialise a running ``_flops`` counter to 0.
    """
    self.inputs = tuple(frozenset(term) for term in inputs)
    self.N = len(self.inputs)
    self.size_dict = size_dict

    # parent -> children mapping: the core binary tree structure
    self.children = {}
    # per-node information, keyed by node
    self.info = {}

    # The root (the collection of all inputs) can be filled in right away:
    # its legs are the output indices and its size follows from them.
    self.root = frozenset(self.inputs)
    self.output = frozenset(output)
    self.add_node(self.root)
    root_info = self.info[self.root]
    root_info['legs'] = self.output
    root_info['size'] = compute_size_by_dict(self.output, size_dict)

    # optional bookkeeping of dangling nodes / subgraphs
    self.track_childless = track_childless
    if track_childless:
        self.childless = {self.root}

    # optional running totals: total flops and largest intermediate size
    self._track_flops = track_flops
    if track_flops:
        self._flops = 0

    self._track_size = track_size
    if track_size:
        self._size = 0