def from_info(cls, info, **kwargs):
    """Generate a set of contraction costs from a ``PathInfo`` object.

    One cost entry (a dict with keys ``'involved'``, ``'legs'``, ``'size'``
    and ``'flops'``) is created for every input term, followed by one entry
    for every item of ``info.contraction_list``.

    Parameters
    ----------
    info
        Object exposing ``size_dict``, ``input_subscripts`` (comma separated
        terms) and ``contraction_list`` (items whose element ``2`` is an
        equation of the form ``'lhs->rhs'``).
    **kwargs
        Accepted for call-compatibility; not used by this constructor.

    Returns
    -------
    An instance built as ``cls(costs, size_dict)``.
    """
    # Work on a private copy so that later changes made through the returned
    # object never leak back into ``info`` (mirrors ``from_pathinfo``).
    size_dict = info.size_dict.copy()

    cs = []

    # add all the input 'contractions': nothing is involved and no flops
    # are spent, they only contribute their legs and size
    for term in info.input_subscripts.split(','):
        cs.append({
            'involved': oset(),
            'legs': oset(term),
            'size': compute_size_by_dict(term, size_dict),
            'flops': 0,
        })

    # then every actual contraction step
    for c in info.contraction_list:
        eq = c[2]
        lhs, rhs = eq.split('->')
        legs = oset(rhs)
        # every index appearing on any of the terms being contracted
        involved = oset.union(*map(oset, lhs.split(',')))
        cs.append({
            'involved': involved,
            'legs': legs,
            'size': compute_size_by_dict(legs, size_dict),
            # c[1] is forwarded as the second argument of ``flop_count``
            # (presumably the removed indices -- confirm against the
            # producer of ``contraction_list``)
            'flops': flop_count(involved, c[1], 2, size_dict),
        })

    return cls(cs, size_dict)
def from_pathinfo(cls, pathinfo, **kwargs):
    """Build the contraction costs described by a ``PathInfo`` object.

    Each input term becomes a zero-flop ``Contraction`` with nothing
    involved, and each item of ``pathinfo.contraction_list`` becomes a
    ``Contraction`` whose legs come from the right-hand side of its
    ``'lhs->rhs'`` equation. ``**kwargs`` is accepted but not used.

    Returns ``cls(costs, sizes)`` where ``sizes`` is a copy of
    ``pathinfo.size_dict``.
    """
    sizes = pathinfo.size_dict.copy()

    # the inputs count as trivial 'contractions'
    costs = [
        Contraction(
            involved=frozenset(),
            legs=frozenset(inds),
            size=compute_size_by_dict(inds, sizes),
            flops=0,
        )
        for inds in pathinfo.input_subscripts.split(',')
    ]

    for step in pathinfo.contraction_list:
        inputs_part, output_part = step[2].split('->')
        out_legs = frozenset(output_part)
        # union of the characters of every comma separated input term
        touched = frozenset().union(*inputs_part.split(','))
        costs.append(Contraction(
            involved=touched,
            legs=out_legs,
            size=compute_size_by_dict(out_legs, sizes),
            flops=flop_count(touched, step[1], 2, sizes),
        ))

    return cls(costs, sizes)
def get_flops(self, node):
    """Return the FLOPs of the pairwise contraction that creates ``node``.

    A value already stored under ``self.info[node]['flops']`` is returned
    as is. Otherwise the cost is ``0`` when ``len(node) == 1``, else it is
    obtained from ``flop_count`` using ``self.get_involved(node)``,
    ``self.get_removed(node)`` and ``self.size_dict``; the result is stored
    in ``self.info[node]['flops']`` before being returned.
    """
    try:
        return self.info[node]['flops']
    except KeyError:
        # nothing cached yet: work the cost out and remember it
        if len(node) == 1:
            cost = 0
        else:
            cost = flop_count(
                self.get_involved(node),
                self.get_removed(node),
                2,
                self.size_dict,
            )
        self.info[node]['flops'] = cost
        return cost