Example #1
    def initialize_session(self, sess: tf_compat.Session):
        """
        Initialize the mask variables for pruning.

        :param sess: the session to use for initializing
        """
        super().initialize_session(sess)
        masks = [op_vars.mask for op_vars in self._prune_op_vars]

        if masks:
            sess.run(tf_compat.variables_initializer(masks))
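
Usage note: the method above initializes only the pruning mask variables via tf_compat.variables_initializer rather than re-running the global initializer, so an already-trained model in the session is left untouched. Below is a minimal, self-contained sketch of the same pattern, assuming tf_compat resolves to tensorflow.compat.v1; the mask variables here are stand-ins for the modifier's pruning masks:

import tensorflow.compat.v1 as tf_compat

tf_compat.disable_eager_execution()

weights = tf_compat.get_variable("weights", shape=[64, 64])
# Non-trainable mask variables, analogous to the pruning masks above
masks = [
    tf_compat.get_variable(
        "mask_%d" % i,
        shape=[64, 64],
        initializer=tf_compat.ones_initializer(),
        trainable=False,
    )
    for i in range(2)
]

with tf_compat.Session() as sess:
    # Initialize only the masks; weights stays uninitialized,
    # mirroring initialize_session above
    sess.run(tf_compat.variables_initializer(masks))
    print(sess.run(masks[0]).mean())  # -> 1.0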
Example #2
    def create_metrics(
        self,
        net_outputs: Union[tf_compat.Tensor, Dict[str, tf_compat.Tensor]],
        labels: Union[tf_compat.Tensor, Dict[str, tf_compat.Tensor]],
        params: Dict[str, Any],
    ) -> Tuple[
        Dict[str, Tuple[tf_compat.Tensor, tf_compat.Operation]],
        Dict[str, tf_compat.Operation],
    ]:
        """
        Create metrics for evaluation.

        :param net_outputs: output tensors of the model graph
        :param labels: ground truth labels
        :param params: the model function params
        :return: a tuple of (dictionary of metrics, dictionary of the ops that
            reset each metric's running variables)
        """
        metrics = params.get("metrics", [])

        metrics_dict = {}
        metrics_initializers_dict = {}
        with tf_compat.name_scope("metrics"):
            for metric in metrics:
                if metric == "accuracy":
                    labels_argmax = tf_compat.argmax(labels, 1)
                    net_outputs_argmax = tf_compat.argmax(net_outputs, 1)
                    metrics_dict["accuracy"] = tf_compat.metrics.accuracy(
                        labels_argmax,
                        net_outputs_argmax,
                        name="accuracy_metric",
                    )
                    # Collect the local total/count variables that back the
                    # accuracy metric so they can be reset between evaluations
                    running_vars = tf_compat.get_collection(
                        tf_compat.GraphKeys.LOCAL_VARIABLES,
                        scope="metrics/accuracy_metric",
                    )
                    running_vars_initializer = tf_compat.variables_initializer(
                        var_list=running_vars
                    )
                    metrics_initializers_dict[metric] = running_vars_initializer
                else:
                    raise ValueError("Unsupported metric: {}".format(metric))

        return (metrics_dict, metrics_initializers_dict)
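
The accuracy metric above is backed by TF1 running total/count local variables, and the returned initializer ops exist to reset those accumulators between evaluation runs. A self-contained sketch of that reset pattern, again assuming tf_compat resolves to tensorflow.compat.v1:

import tensorflow.compat.v1 as tf_compat

tf_compat.disable_eager_execution()

labels = tf_compat.placeholder(tf_compat.int64, [None])
preds = tf_compat.placeholder(tf_compat.int64, [None])

with tf_compat.name_scope("metrics"):
    acc, acc_update = tf_compat.metrics.accuracy(
        labels, preds, name="accuracy_metric"
    )
    # The running total/count variables land in LOCAL_VARIABLES
    running_vars = tf_compat.get_collection(
        tf_compat.GraphKeys.LOCAL_VARIABLES, scope="metrics/accuracy_metric"
    )
    reset_op = tf_compat.variables_initializer(running_vars)

with tf_compat.Session() as sess:
    sess.run(reset_op)  # zero the accumulators before the first run
    sess.run(acc_update, {labels: [0, 1, 1], preds: [0, 1, 0]})
    print(sess.run(acc))  # ~0.667
    sess.run(reset_op)  # accumulators back to zero for a fresh evaluation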
Example #3
def pruning_loss_sens_one_shot(
    op_vars: List[SparsePruningOpVars],
    loss_tensor: tf_compat.Tensor,
    steps_per_measurement: int,
    add_ops_creator: Callable[[int], List[tf_compat.Tensor]] = None,
    feed_dict_creator: Callable[[int], Dict[str, tf_compat.Tensor]] = None,
    sess: tf_compat.Session = None,
    sparsity_levels: List[float] = default_pruning_sparsities_loss(False),
    show_progress: bool = True,
) -> PruningLossSensitivityAnalysis:
    """
    Run a one-shot sensitivity analysis for kernel sparsity.
    It does not retrain; instead, the model is run in eval mode.
    It moves operation by operation, measuring the sensitivity of each at the
    given sparsity levels and resetting each layer before moving to the next.
    Subsequent sparsity checks for layers and levels will be much faster.

    Note: this should be run once a session has been created and
    the variables have been created for the model.

    Note: the graph should be recreated for later training, as this creates
    extra ops in the graph that should not be reused when continuing in the system.

    :param op_vars: the created pruning op vars from ks_loss_sensitivity_op_vars
    :param loss_tensor: the loss tensor in the model to measure for the sensitivity
    :param steps_per_measurement: the number of session.run calls to run through
        for each sparsity level on each layer
    :param add_ops_creator: a callback to create an op/tens list to be run through
        the session for each measurement. Called for each measurement
    :param feed_dict_creator: a callback to create a feed dict to be run through
        the session for each measurement. Called for each measurement
    :param sess: the session to use
    :param sparsity_levels: the sparsity levels to check for each layer to calculate
        sensitivity
    :param show_progress: track progress of the runs if True
    :return: the sensitivity results for every op that is prunable
    """

    if not sess:
        sess = tf_compat.get_default_session()

    analysis = PruningLossSensitivityAnalysis()
    sess.run(
        tf_compat.variables_initializer([var.op_vars.mask for var in op_vars])
    )
    bar = (
        auto.tqdm(
            desc="KS Analysis",
            total=len(op_vars) * len(sparsity_levels) * steps_per_measurement,
        )
        if show_progress
        else None
    )

    for op_index, sparse_op_vars in enumerate(op_vars):
        for sparsity_level in sparsity_levels:
            sess.run(
                sparse_op_vars.op_vars.update,
                feed_dict={sparse_op_vars.sparsity: sparsity_level},
            )

            for step in range(steps_per_measurement):
                ops = [loss_tensor]
                add_ops = add_ops_creator(step) if add_ops_creator else None
                feed_dict = feed_dict_creator(step) if feed_dict_creator else None

                if add_ops:
                    ops.extend(add_ops)

                values = sess.run(ops, feed_dict=feed_dict)
                loss = values[0].item()
                analysis.add_result(
                    None,
                    sparse_op_vars.op_vars.op_input.name,
                    op_index,
                    sparsity_level,
                    loss,
                    baseline=sparsity_level < 1e-9,
                )

                if bar is not None:
                    bar.update(1)

        # Reset this op's mask back to dense before measuring the next op
        sess.run(
            sparse_op_vars.op_vars.update,
            feed_dict={sparse_op_vars.sparsity: 0.0},
        )

    if bar is not None:
        bar.close()

    return analysis
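
A usage sketch for the function above. The model tensors (input_ph, label_ph, loss) and the data batches are hypothetical stand-ins, and ks_loss_sensitivity_op_vars is the companion helper named in the docstring; treat the exact call signatures as assumptions:

import numpy as np

# input_ph, label_ph, loss = build_model(...)  # hypothetical model tensors
batches = [
    (np.random.rand(8, 224, 224, 3), np.random.randint(0, 10, size=(8,)))
    for _ in range(10)
]

def feed_dict_creator(step):
    # Cycle through pre-batched data, one batch per measurement step
    images, label_vals = batches[step % len(batches)]
    return {input_ph: images, label_ph: label_vals}

op_vars = ks_loss_sensitivity_op_vars()  # creates SparsePruningOpVars per prunable op
with tf_compat.Session() as sess:
    sess.run(tf_compat.global_variables_initializer())
    analysis = pruning_loss_sens_one_shot(
        op_vars=op_vars,
        loss_tensor=loss,
        steps_per_measurement=len(batches),
        feed_dict_creator=feed_dict_creator,
        sess=sess,
    )
    # analysis is the PruningLossSensitivityAnalysis populated through
    # add_result above, one entry per op and sparsity level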
Example #4
    def create_ops(
        self,
        steps_per_epoch: int,
        global_step: tf_compat.Tensor,
        graph: tf_compat.Graph,
    ) -> Tuple[List[Union[tf_compat.Tensor, tf_compat.Operation]], Dict[str, Any]]:
        """
        Create the sparsity ops to modify the training graph according to the settings
        for the current instance.

        :param steps_per_epoch: the number of steps (batches) per training epoch
        :param global_step: the global step used while training
        :param graph: the graph to be modified
        :return: a tuple (list of ops, dict of named ops / tensors)
            to be run or used for modifying the training process.
        """
        mod_ops, mod_extras = super().create_ops(
            steps_per_epoch, global_step, graph
        )
        start_step, end_step = self.start_end_steps(
            steps_per_epoch, after_optim=True
        )
        update_frequency_step = self.update_frequency_steps(steps_per_epoch)
        params = (
            self._params
            if self._params != ALL_TOKEN
            # Have ALL_TOKEN match all variable names for now
            else [
                clean_tensor_name(var.name)
                for _, var in get_ops_and_inputs_by_name_or_regex(["re:.*"], graph)
            ]
        )

        with graph.as_default():
            (
                update_op,
                prune_op_vars,
                update_ready,
                sparsity,
            ) = get_or_create_ks_scheduled_graph_ops(
                graph,
                global_step,
                params,
                start_step,
                end_step,
                update_frequency_step,
                self._init_sparsity,
                self._final_sparsity,
                self.exponent,
                self._leave_enabled,
                self.ks_group,
                self._mask_creator,
            )

            if self.log_types == ALL_TOKEN or "tensorboard" in self.log_types:
                mod_extras[EXTRAS_KEY_SUMMARIES] = create_summaries_pruning(
                    prune_op_vars)

        mod_ops.append(update_op)
        self._prune_op_vars = prune_op_vars
        self._update_ready = update_ready
        self._sparsity = sparsity

        # Create and cache the mask initializers to be run
        # through initialize_session. When using the estimator,
        # initialization happens as part of the init_fn of the
        # training scaffold object, at which point the graph can
        # no longer be changed (hence the creation and caching here)
        masks = [op_vars.mask for op_vars in self._prune_op_vars]
        self._mask_initializer = (
            tf_compat.variables_initializer(masks) if masks else None
        )

        return mod_ops, mod_extras
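
A sketch of how create_ops fits into a training loop. The modifier instance, train_op, num_epochs, and feed_dict_for are hypothetical stand-ins; the point is that the returned mod_ops (containing the pruning update op) run alongside each training step, and initialize_session (Example #1) runs the cached mask initializer once the session exists:

global_step = tf_compat.train.get_or_create_global_step()
graph = tf_compat.get_default_graph()
steps_per_epoch = 1000  # assumed; derive from your dataset size

mod_ops, mod_extras = modifier.create_ops(steps_per_epoch, global_step, graph)

with tf_compat.Session(graph=graph) as sess:
    sess.run(tf_compat.global_variables_initializer())
    modifier.initialize_session(sess)  # runs the cached mask initializer
    for step in range(steps_per_epoch * num_epochs):
        # The pruning update op fires on its schedule between start_step
        # and end_step, moving the masks toward the target sparsity
        sess.run([train_op] + mod_ops, feed_dict=feed_dict_for(step))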