示例#1
0
    def initialize_session(self, sess: tf_compat.Session):
        """
        Run the parent class's session setup, then run this instance's mask
        initializer op (when one is set) to initialize the pruning masks.

        :param sess: the session used for the parent setup and for running
            the mask initializer
        """
        super().initialize_session(sess)
        mask_init_op = self._mask_initializer

        if not mask_init_op:
            # falsy initializer (presumably None when no masks were created):
            # nothing to run
            return

        sess.run(mask_init_op)
示例#2
0
    def initialize_session(self, sess: tf_compat.Session):
        """
        Run the parent class's session setup, then initialize the mask
        variable of every prune op var held by this instance.

        :param sess: the session used for the parent setup and for running
            the mask variables initializer
        """
        super().initialize_session(sess)
        mask_variables = [prune_vars.mask for prune_vars in self._prune_op_vars]

        if not mask_variables:
            # no masks: skip building (and running) an empty initializer op
            return

        init_masks_op = tf_compat.variables_initializer(mask_variables)
        sess.run(init_masks_op)
示例#3
0
def _iter_magnitude_sensitivities(
    magnitudes: numpy.ndarray,
    sparsity_levels: Union[List[float], Tuple[float, ...]],
):
    """
    Lazily compute the magnitude-based loss estimate for each sparsity level.

    :param magnitudes: flattened absolute weight values, sorted ascending
    :param sparsity_levels: the sparsity levels to evaluate, in order
    :return: a generator of ``(sparsity, sparse_avg, baseline)`` tuples, one per
        entry of sparsity_levels. A level <= 1e-9 is reported as the baseline
        ``(0.0, 0.0, True)``. Otherwise sparse_avg is the mean of the sorted
        magnitudes between the previous cut point and this level's cut point,
        or the single magnitude at the cut point when it has not advanced.
    """
    count = len(magnitudes)
    start = 0

    for level in sparsity_levels:
        # cut point into the sorted magnitudes, clamped to the last element
        cut = min(round(level * count), count - 1)

        if level <= 1e-9:
            yield 0.0, 0.0, True
            continue

        if cut > start:
            avg = magnitudes[start:cut].mean().item()
            start = cut
        else:
            # cut point did not move past the previous one: use the single
            # magnitude at the cut point and step past it
            avg = magnitudes[cut].item()
            start = cut + 1

        yield level, avg, False


def pruning_loss_sens_magnitude(
    graph: tf_compat.Graph = None,
    sess: tf_compat.Session = None,
    sparsity_levels: Union[
        List[float], Tuple[float, ...]
    ] = default_pruning_sparsities_loss(True),
) -> PruningLossSensitivityAnalysis:
    """
    Approximate kernel sparsity (pruning) loss analysis using weight magnitudes.
    Every op/input pair returned by
    ``get_ops_and_inputs_by_name_or_regex(["re:.*"], graph)`` (presumably the
    prunable conv / linear params) is evaluated. Its absolute values are sorted,
    and for each sparsity level the recorded loss is the average magnitude of
    the values that would be newly pruned when moving from the previous level
    to this one.

    :param graph: the graph to look up the ops in;
        if not supplied uses get_default_graph()
    :param sess: the session used to evaluate the weights;
        if not supplied uses get_default_session()
    :param sparsity_levels: the sparsity levels to calculate the loss for, for
        each param; levels <= 1e-9 are recorded as the baseline with sparsity 0.0
    :return: the analysis results for the model
    """
    graph = graph or tf_compat.get_default_graph()
    sess = sess or tf_compat.get_default_session()

    ops_and_inputs = get_ops_and_inputs_by_name_or_regex(["re:.*"], graph)
    analysis = PruningLossSensitivityAnalysis()

    for op_index, (_, weight_tens) in enumerate(ops_and_inputs):
        flat_weight = sess.run(weight_tens).reshape(-1)
        magnitudes = numpy.sort(numpy.abs(flat_weight))

        # the generator is consumed lazily so results are added level by level
        for sparsity, sparse_avg, baseline in _iter_magnitude_sensitivities(
            magnitudes, sparsity_levels
        ):
            analysis.add_result(
                None, weight_tens.name, op_index, sparsity, sparse_avg, baseline
            )

    return analysis
示例#4
0
def apply_op_vars_masks(pruning_op_vars: List[PruningOpVars], ks_group: str,
                        sess: tf_compat.Session):
    """
    Bake each sparse mask into its op's input variable so the variable itself
    holds the masked values and can be saved with the desired sparsity.

    For every entry, an elementwise multiply of the op input by its mask and
    an assign of that product back into the input's variable are created
    under the ``PruningScope.OP_SAVE`` name scope for ``ks_group``; the assign
    is then run immediately. Note that new ops are added to the graph on
    every call.

    :param pruning_op_vars: named tuples holding the op, its input and the
        sparse mask to apply to that input
    :param ks_group: the group whose name scope the new ops are created under
    :param sess: the session used to run each assign op
    """
    for entry in pruning_op_vars:
        save_scope = PruningScope.model(entry.op, ks_group,
                                        PruningScope.OP_SAVE)

        with tf_compat.name_scope(save_scope):
            masked_value = tf_compat.multiply(entry.op_input, entry.mask)
            target_var = get_tensor_var(entry.op_input)
            sess.run(tf_compat.assign(target_var, masked_value))
示例#5
0
def pruning_loss_sens_one_shot(
    op_vars: List[SparsePruningOpVars],
    loss_tensor: tf_compat.Tensor,
    steps_per_measurement: int,
    add_ops_creator: Callable[[int], List[tf_compat.Tensor]] = None,
    feed_dict_creator: Callable[[int], Dict[str, tf_compat.Tensor]] = None,
    sess: tf_compat.Session = None,
    sparsity_levels: List[float] = default_pruning_sparsities_loss(False),
    show_progress: bool = True,
) -> PruningLossSensitivityAnalysis:
    """
    Run a one shot sensitivity analysis for kernel sparsity.
    The model is not retrained. Op by op, the op's mask is updated to each
    requested sparsity level and the loss is measured over
    steps_per_measurement session runs. The op is then reset to a sparsity of
    0.0 before moving on to the next op.
    The reset and the closing of the progress bar happen in ``finally``
    blocks. An error during a measurement therefore does not leave the current
    op pruned or the progress bar open.

    Note: this should be run once a session has been created and
    the variables have been created for the model.

    Note: this adds ops to the graph (for example the mask variables
    initializer). The graph should be recreated before continuing with later
    training.

    :param op_vars: the created pruning op vars from ks_loss_sensitivity_op_vars
    :param loss_tensor: the loss tensor in the model to measure for the sensitivity
    :param steps_per_measurement: the number of session.run calls to run through
        for each sparsity level on each layer
    :param add_ops_creator: a callback to create an op/tens list to be run through
        the session for each measurement. Called for each measurement
    :param feed_dict_creator: a callback to create a feed dict to be run through
        the session for each measurement. Called for each measurement
    :param sess: the session to use; if not supplied uses get_default_session()
    :param sparsity_levels: the sparsity levels to check for each layer to calculate
        sensitivity; a level below 1e-9 is recorded as the baseline
    :param show_progress: track progress of the runs if True
    :return: the sensitivity results for every op that is prunable
    """

    if not sess:
        sess = tf_compat.get_default_session()

    analysis = PruningLossSensitivityAnalysis()
    # initialize only the mask variables; other model variables are not reset
    sess.run(
        tf_compat.variables_initializer([var.op_vars.mask for var in op_vars]))
    bar = (auto.tqdm(
        desc="KS Analysis",
        total=len(op_vars) * len(sparsity_levels) * steps_per_measurement,
    ) if show_progress else None)

    try:
        for op_index, sparse_op_vars in enumerate(op_vars):
            try:
                for sparsity_level in sparsity_levels:
                    # recompute this op's mask for the requested sparsity level
                    sess.run(
                        sparse_op_vars.op_vars.update,
                        feed_dict={sparse_op_vars.sparsity: sparsity_level},
                    )

                    for step in range(steps_per_measurement):
                        ops = [loss_tensor]
                        add_ops = add_ops_creator(
                            step) if add_ops_creator else None
                        feed_dict = feed_dict_creator(
                            step) if feed_dict_creator else None

                        if add_ops:
                            ops.extend(add_ops)

                        values = sess.run(ops, feed_dict=feed_dict)
                        # only the loss (first fetch) is recorded; add_ops
                        # are run purely for their side effects
                        loss = values[0].item()
                        analysis.add_result(
                            None,
                            sparse_op_vars.op_vars.op_input.name,
                            op_index,
                            sparsity_level,
                            loss,
                            baseline=sparsity_level < 1e-9,
                        )

                        if bar is not None:
                            bar.update(1)
            finally:
                # always restore this op to dense (sparsity 0.0), even if a
                # measurement failed, so the model is not left pruned
                sess.run(sparse_op_vars.op_vars.update,
                         feed_dict={sparse_op_vars.sparsity: 0.0})
    finally:
        if bar is not None:
            bar.close()

    return analysis