예제 #1
0
파일: slicer.py 프로젝트: jcmgray/cotengra
    def from_info(cls, info, **kwargs):
        """Generate a set of contraction costs from a ``PathInfo`` object.

        One entry is produced per input term (with zero flops), followed by
        one entry per step of ``info.contraction_list``. Each entry is a dict
        with the keys ``'involved'``, ``'legs'``, ``'size'`` and ``'flops'``.

        Parameters
        ----------
        info : PathInfo
            Object exposing ``size_dict``, ``input_subscripts`` and
            ``contraction_list`` (presumably an ``opt_einsum`` ``PathInfo``).
        kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        The result of ``cls(contractions, size_dict)``.
        """
        cs = []
        # Copy so the new object never shares (and possibly mutates) the
        # caller's ``info.size_dict``.
        size_dict = info.size_dict.copy()

        # add all the input 'contractions' -- leaves, which cost no flops
        for term in info.input_subscripts.split(','):
            cs.append({
                'involved': oset(),
                'legs': oset(term),
                'size': compute_size_by_dict(term, size_dict),
                'flops': 0,
            })

        for c in info.contraction_list:
            # c[2] is the einsum equation of this step, e.g. 'ab,bc->ac'
            lhs, rhs = c[2].split('->')
            terms = lhs.split(',')
            legs = oset(rhs)
            involved = oset.union(*map(oset, terms))

            cs.append({
                'involved': involved,
                'legs': legs,
                'size': compute_size_by_dict(legs, size_dict),
                # Use the real number of operands instead of assuming a
                # pairwise step; c[1] is passed as flop_count's 'inner'
                # argument (presumably the indices summed over -- confirm).
                'flops': flop_count(involved, c[1], len(terms), size_dict),
            })

        return cls(cs, size_dict)
예제 #2
0
    def from_pathinfo(cls, pathinfo, **kwargs):
        """Generate a set of contraction costs from a ``PathInfo`` object.

        One ``Contraction`` is produced per input term (with zero flops),
        followed by one per step of ``pathinfo.contraction_list``.

        Parameters
        ----------
        pathinfo : PathInfo
            Object exposing ``size_dict``, ``input_subscripts`` and
            ``contraction_list`` (presumably an ``opt_einsum`` ``PathInfo``).
        kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        The result of ``cls(contractions, size_dict)``.
        """
        cs = []
        # private copy so the new object cannot mutate the caller's dict
        size_dict = pathinfo.size_dict.copy()

        # add all the input 'contractions' -- leaves, which cost no flops
        for term in pathinfo.input_subscripts.split(','):
            cs.append(Contraction(
                involved=frozenset(),
                legs=frozenset(term),
                size=compute_size_by_dict(term, size_dict),
                flops=0,
            ))

        for c in pathinfo.contraction_list:
            # c[2] is the einsum equation of this step, e.g. 'ab,bc->ac'
            lhs, rhs = c[2].split('->')
            terms = lhs.split(',')
            legs = frozenset(rhs)
            involved = frozenset.union(*map(frozenset, terms))

            cs.append(Contraction(
                involved=involved,
                legs=legs,
                size=compute_size_by_dict(legs, size_dict),
                # Use the real number of operands instead of assuming a
                # pairwise step; c[1] is passed as flop_count's 'inner'
                # argument (presumably the indices summed over -- confirm).
                flops=flop_count(involved, c[1], len(terms), size_dict),
            ))

        return cls(cs, size_dict)
예제 #3
0
파일: core.py 프로젝트: yangyuan16/cotengra
 def get_flops(self, node):
     """Return the FLOP cost of the pairwise contraction that yields ``node``.

     The value is memoised in ``self.info[node]['flops']``. A node holding a
     single element (``len(node) == 1``) is assigned a cost of 0.
     """
     try:
         return self.info[node]['flops']
     except KeyError:
         pass

     # cache miss -- compute, store, then return
     if len(node) == 1:
         cost = 0
     else:
         cost = flop_count(
             self.get_involved(node),
             self.get_removed(node),
             2,
             self.size_dict,
         )
     self.info[node]['flops'] = cost
     return cost