Example #1
    def from_info(cls, info, **kwargs):
        """Generate a set of contraction costs from a ``PathInfo`` object.
        """
        cs = []
        size_dict = info.size_dict

        # add all the input 'contractions'
        for term in info.input_subscripts.split(','):
            cs.append({
                'involved': oset(),
                'legs': oset(term),
                'size': compute_size_by_dict(term, size_dict),
                'flops': 0,
            })

        # then add the pairwise contraction performed at each step of the path
        for c in info.contraction_list:
            eq = c[2]
            lhs, rhs = eq.split('->')
            legs = oset(rhs)
            involved = oset.union(*map(oset, lhs.split(',')))

            cs.append({
                'involved': involved,
                'legs': legs,
                'size': compute_size_by_dict(legs, size_dict),
                'flops': flop_count(involved, c[1], 2, size_dict),
            })

        return cls(cs, size_dict)
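
A minimal usage sketch: the ``info`` object is the ``PathInfo`` returned by ``opt_einsum.contract_path``, which exposes the ``size_dict``, ``input_subscripts`` and ``contraction_list`` attributes read above. The owning class isn't shown in the snippet, so ``ContractionCosts`` below is a hypothetical name for it.

import numpy as np
import opt_einsum as oe

arrays = [np.ones((2, 3)), np.ones((3, 4)), np.ones((4, 5))]
path, info = oe.contract_path('ab,bc,cd->ad', *arrays)

# hypothetical owning class of the ``from_info`` classmethod above
costs = ContractionCosts.from_info(info)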
Example #2
    def from_pathinfo(cls, pathinfo, **kwargs):
        """Generate a set of contraction costs from a ``PathInfo`` object.
        """
        cs = []
        size_dict = pathinfo.size_dict.copy()

        # add all the input 'contractions'
        for term in pathinfo.input_subscripts.split(','):
            cs.append(Contraction(
                involved=frozenset(),
                legs=frozenset(term),
                size=compute_size_by_dict(term, size_dict),
                flops=0,
            ))

        # then add the pairwise contraction performed at each step of the path
        for c in pathinfo.contraction_list:
            eq = c[2]
            lhs, rhs = eq.split('->')
            legs = frozenset(rhs)
            involved = frozenset.union(*map(frozenset, lhs.split(',')))

            cs.append(Contraction(
                involved=involved,
                legs=legs,
                size=compute_size_by_dict(legs, size_dict),
                flops=flop_count(involved, c[1], 2, size_dict),
            ))

        return cls(cs, size_dict)
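
The ``Contraction`` record constructed above isn't defined in the snippet; a minimal compatible sketch, assuming only the four fields used here:

from dataclasses import dataclass


@dataclass(frozen=True)
class Contraction:
    involved: frozenset  # indices taking part in the contraction (empty for raw inputs)
    legs: frozenset      # indices remaining on the produced tensor
    size: int            # number of elements in the produced tensor
    flops: int           # floating point cost of producing it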
Example #3
    def get_size(self, node):
        """Get the tensor size of ``node``.
        """
        try:
            size = self.info[node]['size']
        except KeyError:
            size = compute_size_by_dict(self.get_legs(node), self.size_dict)
            self.info[node]['size'] = size
        return size
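
Note the lazy-caching pattern: the size is computed from ``get_legs`` and ``size_dict`` only on first request, then memoized under ``self.info[node]['size']``. The assignment in the ``except`` branch assumes ``self.info[node]`` is already a dict, presumably created by ``add_node`` (compare Example #6, where ``add_node`` is called before ``info[root]`` is populated).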
Example #4
def calc_node_weight_float(term, size_dict, scale='linear'):
    if scale in ('const', None, False):
        return 1.0

    w = compute_size_by_dict(term, size_dict)

    # 'linear' keeps the raw tensor size, 'log' and 'exp' rescale it
    if scale == 'log':
        w = math.log2(w)
    elif scale == 'exp':
        w = 2**w

    return w
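
A small usage sketch of the three scalings, assuming ``compute_size_by_dict`` is opt_einsum's helper of the same name (it multiplies together the sizes of the indices in ``term``) and that ``math`` is imported alongside the function:

import math

from opt_einsum.helpers import compute_size_by_dict

size_dict = {'a': 2, 'b': 3, 'c': 4}

# weight of an 'abc' node, whose raw size is 2 * 3 * 4 = 24
print(calc_node_weight_float('abc', size_dict, scale='linear'))  # 24
print(calc_node_weight_float('abc', size_dict, scale='log'))     # ~4.585
print(calc_node_weight_float('abc', size_dict, scale='const'))   # 1.0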
Example #5
    def __init__(
        self,
        eq,
        arrays,
        sliced,
        optimize='auto',
        size_dict=None,
    ):
        # basic info
        lhs, self.output = eq.split('->')
        self.inputs = lhs.split(',')
        self.arrays = tuple(arrays)
        self.sliced = tuple(sorted(sliced, key=eq.index))
        if size_dict is None:
            size_dict = create_size_dict(self.inputs, self.arrays)
        self.size_dict = size_dict

        # find which arrays are going to be sliced or not
        self.constant, self.changing = [], []
        for i, term in enumerate(self.inputs):
            if any(ix in self.sliced for ix in term):
                self.changing.append(i)
            else:
                self.constant.append(i)

        # information about the contraction of a single slice
        self.eq_sliced = "".join(c for c in eq if c not in sliced)
        self.sliced_sizes = tuple(self.size_dict[i] for i in self.sliced)
        self.nslices = compute_size_by_dict(self.sliced, self.size_dict)
        self.shapes_sliced = tuple(
            tuple(self.size_dict[i] for i in term)
            for term in self.eq_sliced.split('->')[0].split(',')
        )
        self.path, self.info_sliced = contract_path(
            self.eq_sliced, *self.shapes_sliced, shapes=True, optimize=optimize
        )

        # generate the contraction expression
        self._expr = contract_expression(
            self.eq_sliced, *self.shapes_sliced, optimize=self.path
        )
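
A worked example of the slicing bookkeeping above, with the index ``b`` sliced out of a small matrix-chain contraction (only quantities derived in the snippet are shown):

eq = 'ab,bc,cd->ad'
sliced = ('b',)
size_dict = {'a': 2, 'b': 3, 'c': 4, 'd': 5}

# dropping every sliced character from the equation
eq_sliced = ''.join(c for c in eq if c not in sliced)   # 'a,c,cd->ad'

# one contraction per assignment of the sliced indices
nslices = size_dict['b']                                # 3

# per-slice input shapes, read straight off eq_sliced
shapes_sliced = ((2,), (4,), (4, 5))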
Example #6
    def __init__(self,
                 inputs,
                 output,
                 size_dict,
                 track_childless=False,
                 track_size=False,
                 track_flops=False):

        self.inputs = tuple(map(frozenset, inputs))
        self.N = len(self.inputs)
        self.size_dict = size_dict

        # mapping of parents to children - the core binary tree object
        self.children = {}

        # information about all the nodes
        self.info = {}

        # ... which we can fill in already for the final / top node, i.e. the
        # collection of all input nodes
        self.root = frozenset(self.inputs)
        self.output = frozenset(output)
        self.add_node(self.root)
        self.info[self.root]['legs'] = self.output
        self.info[self.root]['size'] = compute_size_by_dict(
            self.output, size_dict)

        # whether to keep track of dangling nodes/subgraphs
        self.track_childless = track_childless
        if self.track_childless:
            # the set of dangling nodes
            self.childless = {self.root}

        # running largest intermediate size and total flops
        self._track_flops = track_flops
        if track_flops:
            self._flops = 0
        self._track_size = track_size
        if track_size:
            self._size = 0
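
A minimal construction sketch, assuming the class above is named something like ``ContractionTree`` (the name isn't shown) and that ``add_node`` registers an empty ``info`` entry for a node:

inputs = ['ab', 'bc']
output = 'ac'
size_dict = {'a': 2, 'b': 3, 'c': 4}

tree = ContractionTree(inputs, output, size_dict, track_flops=True)

# the root node is the frozenset of all (frozenset) input terms
assert tree.root == frozenset({frozenset('ab'), frozenset('bc')})

# its legs are the output indices, so its size is 2 * 4 = 8
assert tree.info[tree.root]['size'] == 8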