Ejemplo n.º 1
0
 def _spec_from_threshold(self, net_sens, threshold, excludes=None):
     """Build a PruningSpec from sensitivity results and a threshold.

     Args:
       net_sens: Sensitivity results, forwarded unchanged to
         `pruning_lib.prunable_groups_by_threshold`.
       threshold: Threshold forwarded unchanged to
         `pruning_lib.prunable_groups_by_threshold`.
       excludes: Optional value forwarded as the `excludes` argument of
         `pruning_lib.prunable_groups_by_threshold`. When omitted (None), a
         new empty list is passed.

     Returns:
       A `pruning_lib.PruningSpec` with every group returned by
       `pruning_lib.prunable_groups_by_threshold` added to it.
     """
     # Create a fresh list per call. A shared `[]` default would carry any
     # mutation made by the callee over into later calls.
     if excludes is None:
         excludes = []
     groups = pruning_lib.prunable_groups_by_threshold(
         net_sens, threshold, excludes)
     spec = pruning_lib.PruningSpec()
     for group in groups:
         spec.add_group(group)
     return spec
Ejemplo n.º 2
0
    def prune(self, threshold=None, sparsity=None):
        """Prune the graph using either a sensitivity threshold or a sparsity.

        Only one strategy is applied. When `threshold` is given it takes
        precedence and `sparsity` is ignored.

        Args:
          threshold: If not None, per-group sparsities are obtained from
            `pruning_lib.get_sparsity_by_threshold`. That call uses the
            sensitivity results read from `pruning_lib.sens_path(self._graph)`.
          sparsity: Used only when `threshold` is None. This single value is
            applied to every group returned by
            `pruning_lib.group_nodes(self._graph)`.

        Returns:
          A `PruningModule` built from the `(pruned_model, pruning_info)`
          pair returned by `self._prune(self._graph, spec)`.

        Raises:
          ValueError: If both `threshold` and `sparsity` are None.
          RuntimeError: If `threshold` is given but no sensitivity file
            exists at `pruning_lib.sens_path(self._graph)`, i.e. `ana()` has
            not been run.
        """
        # Compare against None rather than relying on truthiness. Otherwise an
        # explicit 0 / 0.0 would be ignored or reported as missing.
        if threshold is None and sparsity is None:
            raise ValueError(
                "At least one of 'sparsity' or 'threshold' must be set")

        spec = pruning_lib.PruningSpec()
        if threshold is not None:
            sens_path = pruning_lib.sens_path(self._graph)
            if not os.path.exists(sens_path):
                raise RuntimeError("Must call ana() before running prune.")
            net_sens = pruning_lib.read_sens(sens_path)

            # TODO(yuwang): Support excludes: important to detection net.
            net_sparsity = pruning_lib.get_sparsity_by_threshold(
                net_sens, threshold)
            logging.vlog(
                1, 'NetSparsity: \n{}'.format('\n'.join(
                    str(group) for group in net_sparsity)))
            for group_sparsity in net_sparsity:
                spec.add_group(group_sparsity)
        else:
            # Uniform strategy: the same sparsity for every node group.
            for group in pruning_lib.group_nodes(self._graph):
                spec.add_group(pruning_lib.GroupSparsity(group, sparsity))

        pruned_model, pruning_info = self._prune(self._graph, spec)
        return PruningModule(pruned_model, pruning_info)
Ejemplo n.º 3
0
 def spec(self, step):
     """Return the pruning spec to use at `step`.

     `self._group_and_exp(step)` yields a group index and an exponent. A
     negative group index produces an empty spec. Otherwise the node group
     `self._groups[group_idx]` is added with sparsity `(exp + 1) * 0.1`.
     """
     result = pruning_lib.PruningSpec()
     group_idx, exp = self._group_and_exp(step)
     if group_idx < 0:
         # Nothing selected for this step: hand back the empty spec.
         return result
     selected_nodes = self._groups[group_idx]
     ratio = (exp + 1) * 0.1
     result.add_group(pruning_lib.GroupSparsity(selected_nodes, ratio))
     return result
Ejemplo n.º 4
0
 def spec(self, step):
   """Return the pruning spec to use at `step`.

   Steps that are not positive get an empty spec, which serves as the
   baseline. For later steps, `self._eval_plan(step)` supplies a group index
   and a sparsity, and that node group is added as a `PrunableGroup`.
   """
   result = pruning_lib.PruningSpec()
   if step <= 0:
     # Empty spec for baseline.
     return result
   group_idx, sparsity = self._eval_plan(step)
   selected_nodes = self._groups[group_idx]
   result.add_group(pruning_lib.PrunableGroup(selected_nodes, sparsity))
   return result