def trial_greedy(inputs, output, size_dict, random_strength=0.1,
                 temperature=1.0, rel_temperature=True, costmod=1,
                 usesizes=True):
    """Run one greedy trial and return the resulting contraction tree.

    The greedy search is driven by ``ssa_greedy_optimize`` using:

    * a size dict obtained by passing ``size_dict`` and ``random_strength``
      through ``jitter_dict``,
    * ``thermal_chooser`` with ``temperature`` / ``rel_temperature`` bound
      as the ``choose_fn``,
    * ``cost_memory_removed_mod`` with ``costmod`` / ``usesizes`` bound as
      the ``cost_fn``.

    The returned ``ContractionTree`` is built from the found SSA path using
    the original ``size_dict``, not the jittered one.
    """
    chooser = functools.partial(
        thermal_chooser,
        temperature=temperature,
        rel_temperature=rel_temperature,
    )
    coster = functools.partial(
        cost_memory_removed_mod,
        costmod=costmod,
        usesizes=usesizes,
    )

    # only the path search sees the jittered sizes
    noisy_sizes = jitter_dict(size_dict, random_strength)
    found_path = ssa_greedy_optimize(
        inputs, output, noisy_sizes, choose_fn=chooser, cost_fn=coster
    )

    return ContractionTree.from_path(
        inputs, output, size_dict, ssa_path=found_path
    )
def greconf_rc(inputs, output, size_dict, memory_limit=None):
    """Greedy-reconf path: build one greedy path, then refine it with a
    single round of subtree reconfiguration.

    The greedy SSA path from ``ssa_greedy_optimize`` is turned into a
    ``ContractionTree``, which is reconfigured in place with
    ``subtree_size=6`` and ``minimize='combo'``. The value returned is
    whatever ``tree.get_path()`` yields afterwards.

    ``memory_limit`` is accepted but not used by this function.
    """
    reconf_opts = {'subtree_size': 6, 'minimize': 'combo'}

    greedy_path = ssa_greedy_optimize(inputs, output, size_dict)
    tree = ContractionTree.from_path(
        inputs, output, size_dict, ssa_path=greedy_path
    )
    tree.subtree_reconfigure_(**reconf_opts)
    return tree.get_path()
def trial_greedy(inputs, output, size_dict, random_strength=0.1,
                 temperature=1.0, rel_temperature=True, costmod=1):
    """Run one greedy trial and report the tree together with its scores.

    The greedy search (``ssa_greedy_optimize``) is given a size dict produced
    by ``jitter_dict(size_dict, random_strength)``, a ``thermal_chooser``
    with ``temperature`` / ``rel_temperature`` bound, and a
    ``cost_memory_removed_mod`` with ``costmod`` bound. The contraction tree
    itself is built with the original ``size_dict``.

    Returns a dict with the keys ``'tree'``, ``'ssa_path'``, ``'flops'``
    (``tree.total_flops()``) and ``'size'`` (``tree.max_size()``).
    """
    chooser = functools.partial(
        thermal_chooser,
        temperature=temperature,
        rel_temperature=rel_temperature,
    )
    coster = functools.partial(cost_memory_removed_mod, costmod=costmod)

    # only the path search sees the jittered sizes
    noisy_sizes = jitter_dict(size_dict, random_strength)
    ssa_path = ssa_greedy_optimize(
        inputs, output, noisy_sizes, choose_fn=chooser, cost_fn=coster
    )
    ctree = ContractionTree.from_path(
        inputs, output, size_dict, ssa_path=ssa_path
    )

    flops = ctree.total_flops()
    size = ctree.max_size()
    return {
        'tree': ctree,
        'ssa_path': ssa_path,
        'flops': flops,
        'size': size,
    }
def get_ssa_path(self, inputs, output, size_dict):
    """Grow a region across the hypergraph node by node and return the
    corresponding SSA contraction path.

    Steps, as implemented here:

    1. Build ``self.hg`` via ``get_hypergraph`` and ``self.cents`` via its
       ``simple_centrality()``.
    2. Pick the starting region: the hypergraph's output nodes if ``output``
       is truthy, otherwise the max / min centrality node when
       ``self.start`` is ``'max'`` / ``'min'``, otherwise ``self.start``
       itself.
    3. If the region starts with more than two nodes, order their mutual
       contractions with ``ssa_greedy_optimize``.
    4. Repeatedly take the best-scoring candidate neighbor (see
       ``priority``), add it to the region, and record the pair
       ``(node, merges[node])``.
    5. Reverse the recorded sequence and relabel it into SSA ids, where each
       new intermediate takes the next id after ``hg.get_num_nodes()``.

    Returns a list of ``(ssa_id, ssa_id)`` tuples.
    """
    self.hg = get_hypergraph(inputs, output, size_dict, accel='auto')
    self.cents = self.hg.simple_centrality()
    hg, cents = self.hg, self.cents

    def jittered_centrality(node):
        # tiny random offset on top of the centrality value
        return cents[node] + 1e-6 * random.random()

    if output:
        region = oset(hg.output_nodes())
    elif self.start == 'max':
        region = oset([max(cents.keys(), key=jittered_centrality)])
    elif self.start == 'min':
        region = oset([min(cents.keys(), key=jittered_centrality)])
    else:
        region = oset(self.start)

    candidates = []
    merges = {}
    distances = hg.simple_distance(list(region), p=self.distance_p)
    connectivity = collections.defaultdict(int)

    seq = []
    n_seed = len(region)
    if n_seed == 2:
        seq.append(tuple(region))
    elif n_seed != 1:
        # region has multiple starting nodes: contract these among
        # themselves using a greedy path over just their inputs
        seed_nodes = list(region)
        seed_inputs = [inputs[node] for node in seed_nodes]
        seed_path = ssa_greedy_optimize(seed_inputs, output, size_dict)
        for left, right in seed_path:
            node, partner = seed_nodes[left], seed_nodes[right]
            merges[node] = partner
            seq.append((node, partner))
            # the new ssa id produced by this pair maps back to ``partner``
            seed_nodes.append(partner)
        # stored reversed because the whole sequence is reversed at the end
        seq.reverse()

    def consider(surface, neighbor):
        """Register ``neighbor`` as reachable from region node ``surface``."""
        if neighbor in region:
            return

        if neighbor not in merges:
            # first time seen: becomes a candidate merging into ``surface``
            merges[neighbor] = surface
            candidates.append(neighbor)
        else:
            # already a candidate: possibly re-point its merge target
            current = merges[neighbor]
            if self.distance_steal == "abs":
                if distances[surface] < distances[current]:
                    merges[neighbor] = surface
            elif self.distance_steal == 'rel':
                old_diff = abs(distances[current] - distances[neighbor])
                new_diff = abs(distances[surface] - distances[neighbor])
                if new_diff > old_diff:
                    merges[neighbor] = surface

        if self.weight_bonds:
            bond_weight = math.log2(hg.bond_size(surface, neighbor))
        else:
            bond_weight = 1
        connectivity[neighbor] += bond_weight

    def priority(node):
        """Sort key for candidates; the largest key is expanded next."""
        # every entry is evaluated on each call, including the gumbel() draw
        scores = {
            'C': self.coeff_connectivity * connectivity[node],
            'N': self.coeff_ndim * len(inputs[node]),
            'D': self.coeff_distance * distances[node],
            'L': self.coeff_next_centrality * cents[node],
            'T': max(0.0, self.temperature) * gumbel(),
            'I': -node,
        }
        perm = self.score_perm
        if perm == '':
            # no permutation given: single combined score
            return sum(scores[key] for key in 'CNDLT')
        # otherwise compare lexicographically in the requested order
        return tuple(scores[key] for key in perm)

    for seed in region:
        for nbr in hg.neighbors(seed):
            consider(seed, nbr)

    while candidates:
        candidates.sort(key=priority)
        chosen = candidates.pop()
        region.add(chosen)
        for nbr in hg.neighbors(chosen):
            consider(chosen, nbr)
        seq.append((chosen, merges[chosen]))
    seq.reverse()

    # relabel node pairs into ssa ids: after each pair the second node
    # stands for the newly created intermediate
    ssa_path = []
    next_id = hg.get_num_nodes()
    node2ssa = {node: node for node in range(next_id)}
    for node, partner in seq:
        ssa_path.append((node2ssa[node], node2ssa[partner]))
        node2ssa[partner] = next_id
        next_id += 1
    return ssa_path