Example #1
0
def trial_greedy(inputs,
                 output,
                 size_dict,
                 random_strength=0.1,
                 temperature=1.0,
                 rel_temperature=True,
                 costmod=1,
                 usesizes=True):
    """Run one randomized greedy contraction search and return its tree.

    The index sizes are first passed through ``jitter_dict`` and only that
    perturbed copy is given to the greedy optimizer. The final tree is built
    with the original, unperturbed ``size_dict``.

    Parameters
    ----------
    inputs, output, size_dict
        Contraction specification, forwarded to ``ssa_greedy_optimize`` and
        ``ContractionTree.from_path``.
    random_strength : float, optional
        Strength argument forwarded to ``jitter_dict``.
    temperature, rel_temperature : optional
        Forwarded to ``thermal_chooser`` (the greedy ``choose_fn``).
    costmod, usesizes : optional
        Forwarded to ``cost_memory_removed_mod`` (the greedy ``cost_fn``).

    Returns
    -------
    ContractionTree
        Tree built from the greedy SSA path using the original sizes.
    """
    # The perturbed sizes only steer the greedy search.
    perturbed_sizes = jitter_dict(size_dict, random_strength)

    greedy_options = dict(
        cost_fn=functools.partial(
            cost_memory_removed_mod, costmod=costmod, usesizes=usesizes),
        choose_fn=functools.partial(
            thermal_chooser,
            temperature=temperature,
            rel_temperature=rel_temperature),
    )
    path = ssa_greedy_optimize(
        inputs, output, perturbed_sizes, **greedy_options)

    return ContractionTree.from_path(
        inputs, output, size_dict, ssa_path=path)
Example #2
0
def greconf_rc(inputs, output, size_dict, memory_limit=None):
    """Greedy path refined by one round of subtree reconfiguration.

    A single greedy SSA path is computed with ``ssa_greedy_optimize``'s
    default options. It is turned into a ``ContractionTree`` and improved in
    place by ``subtree_reconfigure_`` with ``subtree_size=6`` and the
    ``'combo'`` objective.

    Parameters
    ----------
    inputs, output, size_dict
        Contraction specification.
    memory_limit : optional
        Accepted but not used here. NOTE(review): presumably kept so the
        signature matches other path-optimizer functions; confirm with
        callers.

    Returns
    -------
    The result of ``tree.get_path()`` after reconfiguration.
    """
    tree = ContractionTree.from_path(
        inputs, output, size_dict,
        ssa_path=ssa_greedy_optimize(inputs, output, size_dict),
    )
    # Cheap local refinement of the greedy tree (mutates ``tree``).
    tree.subtree_reconfigure_(subtree_size=6, minimize='combo')
    return tree.get_path()
Example #3
0
def trial_greedy(inputs, output, size_dict,
                 random_strength=0.1,
                 temperature=1.0,
                 rel_temperature=True,
                 costmod=1):
    """Run one randomized greedy search and summarize the resulting tree.

    The greedy optimizer sees sizes perturbed by ``jitter_dict``. The tree
    and its statistics use the original ``size_dict``.

    Parameters
    ----------
    inputs, output, size_dict
        Contraction specification.
    random_strength : float, optional
        Strength argument forwarded to ``jitter_dict``.
    temperature, rel_temperature : optional
        Forwarded to ``thermal_chooser`` (the greedy ``choose_fn``).
    costmod : optional
        Forwarded to ``cost_memory_removed_mod`` (the greedy ``cost_fn``).

    Returns
    -------
    dict
        Keys ``'tree'`` (the ``ContractionTree``), ``'ssa_path'``,
        ``'flops'`` (``tree.total_flops()``) and ``'size'``
        (``tree.max_size()``).
    """
    search_sizes = jitter_dict(size_dict, random_strength)

    coster = functools.partial(cost_memory_removed_mod, costmod=costmod)
    chooser = functools.partial(thermal_chooser,
                                temperature=temperature,
                                rel_temperature=rel_temperature)

    ssa_path = ssa_greedy_optimize(inputs, output, search_sizes,
                                   choose_fn=chooser, cost_fn=coster)
    tree = ContractionTree.from_path(inputs, output, size_dict,
                                     ssa_path=ssa_path)

    # Built incrementally; key order and call order match a dict literal.
    summary = {'tree': tree, 'ssa_path': ssa_path}
    summary['flops'] = tree.total_flops()
    summary['size'] = tree.max_size()
    return summary
Example #4
0
    def get_ssa_path(self, inputs, output, size_dict):
        """Build an SSA contraction path by greedily growing a region.

        Starting from an initial ``region`` of hypergraph nodes, neighbouring
        nodes are absorbed one at a time, highest ``_sorter`` score first.
        Each absorbed node records the region node it will be merged into.
        The recorded merges are reversed, so later-absorbed nodes are
        contracted into their parents first. They are then converted to SSA
        form.

        Side effects: stores the hypergraph on ``self.hg`` and the result of
        ``self.hg.simple_centrality()`` on ``self.cents``.

        Parameters
        ----------
        inputs : sequence
            Per-node index collections, indexed by node id. ``len(inputs[i])``
            feeds the ``coeff_ndim`` score term, and a subset is passed to
            ``ssa_greedy_optimize`` when the start region has more than two
            nodes.
        output : sequence
            Output indices. If non-empty, the start region is the
            hypergraph's output nodes and ``self.start`` is not consulted.
        size_dict : dict
            Index sizes, forwarded to ``get_hypergraph`` and
            ``ssa_greedy_optimize``.

        Returns
        -------
        list[tuple[int, int]]
            SSA path. Nodes ``0 .. self.hg.get_num_nodes() - 1`` start with
            their own index as id, and each contraction yields the next id.
        """
        self.hg = get_hypergraph(inputs, output, size_dict, accel='auto')
        self.cents = self.hg.simple_centrality()

        # Centrality plus a tiny random jitter, so ties between equally
        # central nodes are broken randomly when choosing a start node.
        def region_choose_sorter(node):
            return self.cents[node] + 1e-6 * random.random()

        # Initial region: the output nodes if there is an output; otherwise
        # the single most/least central node ('max'/'min'); otherwise
        # ``self.start`` itself is treated as an iterable of node ids.
        if output:
            region = oset(self.hg.output_nodes())
        elif self.start == 'max':
            region = oset([max(self.cents.keys(), key=region_choose_sorter)])
        elif self.start == 'min':
            region = oset([min(self.cents.keys(), key=region_choose_sorter)])
        else:
            region = oset(self.start)

        # candidates: nodes adjacent to the region that are not yet absorbed.
        # merges: node id -> node id it will be contracted into.
        candidates = []
        merges = {}
        # NOTE(review): presumably each node's distance from the initial
        # region, indexed by node id -- confirm against ``simple_distance``.
        distances = self.hg.simple_distance(list(region), p=self.distance_p)
        # Accumulated bond weight between each candidate and the region.
        connectivity = collections.defaultdict(lambda: 0)

        if len(region) == 1:
            # single start node: nothing to contract within the region
            seq = []
        elif len(region) == 2:
            seq = [tuple(region)]
        else:
            # span will have multiple starting points, contract these
            o_nodes = list(region)
            o_inputs = [inputs[i] for i in o_nodes]
            o_ssa_path = ssa_greedy_optimize(o_inputs, output, size_dict)
            seq = []
            for pi, pj in o_ssa_path:
                merges[o_nodes[pi]] = o_nodes[pj]
                seq.append((o_nodes[pi], o_nodes[pj]))
                # The new SSA id from this contraction refers to the node
                # that absorbed it (pj), so later pairs map back to real ids.
                o_nodes.append(o_nodes[pj])
            # Reversed here and again at the end of the method, so these
            # start-region contractions run last, in their greedy order.
            seq.reverse()

        def _check_candidate(i_surface, i_neighbor):
            # Register ``i_neighbor`` (adjacent to region node ``i_surface``)
            # as a candidate, or possibly re-target its merge if it already
            # is one; then add this bond to its connectivity.
            if (i_neighbor in region):
                return

            if i_neighbor in merges:
                i_current = merges[i_neighbor]

                if self.distance_steal == "abs":
                    # re-target to i_surface if its ``distances`` value is
                    # smaller than that of the current target
                    if distances[i_surface] < distances[i_current]:
                        merges[i_neighbor] = i_surface

                elif self.distance_steal == 'rel':
                    # re-target if i_surface's distance differs *more* from
                    # the neighbour's own distance than the current target's
                    # NOTE(review): confirm '>' (not '<') is intended here.
                    old_diff = abs(distances[i_current] -
                                   distances[i_neighbor])
                    new_diff = abs(distances[i_surface] -
                                   distances[i_neighbor])
                    if new_diff > old_diff:
                        merges[i_neighbor] = i_surface
                # any other ``distance_steal`` value: keep the first target
            else:
                merges[i_neighbor] = i_surface
                candidates.append(i_neighbor)

            if self.weight_bonds:
                connectivity[i_neighbor] += math.log2(
                    self.hg.bond_size(i_surface, i_neighbor))
            else:
                connectivity[i_neighbor] += 1

        def _sorter(i):
            # Score candidate ``i``; the highest score is absorbed next.
            # C: connectivity to the region, N: len(inputs[i]),
            # D: distances[i], L: centrality, T: temperature-scaled gumbel()
            # noise (contributes 0 when temperature <= 0; gumbel() is still
            # called), I: -i, a deterministic tie-breaker favouring lower ids
            # when it appears in ``score_perm``.
            scores = {
                'C': self.coeff_connectivity * connectivity[i],
                'N': self.coeff_ndim * len(inputs[i]),
                'D': self.coeff_distance * distances[i],
                'L': self.coeff_next_centrality * self.cents[i],
                'T': max(0.0, self.temperature) * gumbel(),
                'I': -i,
            }
            # Empty ``score_perm``: sum of all terms except 'I'. Otherwise a
            # tuple in ``score_perm`` order, compared lexicographically.
            if self.score_perm == '':
                return sum(scores[o] for o in 'CNDLT')
            c = tuple(scores[o] for o in self.score_perm)
            return c

        # Seed the candidate list with the neighbours of the start region.
        for i in region:
            for j in self.hg.neighbors(i):
                _check_candidate(i, j)

        # Grow the region one node at a time. Candidates are re-scored and
        # re-sorted every step since connectivity changes and the noise term
        # is redrawn; the stable ascending sort plus pop() takes the maximum.
        while candidates:
            candidates.sort(key=_sorter)
            i_surface = candidates.pop()
            region.add(i_surface)
            for i_next in self.hg.neighbors(i_surface):
                _check_candidate(i_surface, i_next)
            seq.append((i_surface, merges[i_surface]))
        # Later-absorbed nodes are now contracted into their parents before
        # those parents are themselves contracted.
        seq.reverse()

        # Convert node-pair merges to SSA ids: every node starts with its own
        # index, and each contraction's result gets the next fresh id, which
        # is then tracked under the surviving node ``j``.
        ssapath = []
        ssa = self.hg.get_num_nodes()
        node2ssa = {i: i for i in range(ssa)}
        for i, j in seq:
            ssapath.append((node2ssa[i], node2ssa[j]))
            node2ssa[j] = ssa
            ssa += 1

        return ssapath