Example #1
        def _calc_lr():
            # count how many milestones the global step has reached or passed
            reached = tf_compat.cast(
                tf_compat.greater_equal(global_step, milestone_steps),
                tf_compat.int64)
            updates = tf_compat.reduce_sum(reached)
            # decay the initial learning rate by gamma once per passed milestone
            mult_g = tf_compat.pow(gamma,
                                   tf_compat.cast(updates, tf_compat.float32))

            return tf_compat.multiply(init_lr, mult_g)
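For intuition, here is a plain-Python sketch of the same multi-step decay; the milestone, gamma, and init_lr values below are illustrative assumptions, not values from the snippet:

import numpy as np

# assumed example values, mirroring the tensors in the snippet above
init_lr, gamma = 0.1, 0.1
milestone_steps = np.array([30, 60, 90])
for global_step in (0, 30, 75):
    updates = int(np.sum(global_step >= milestone_steps))  # milestones reached
    print(global_step, init_lr * gamma ** updates)
# prints 0.1 at step 0, 0.01 after one milestone, 0.001 after two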
Example #2
def create_op_pruning_no_update(
    op: tf_compat.Operation,
    op_input: tf_compat.Tensor,
    ks_group: str,
    leave_enabled: bool = True,
    is_after_end_step: tf_compat.Tensor = None,
) -> PruningOpVars:
    """
    Creates the necessary variables and operators to gradually
    apply sparsity to an operators variable without returning a
    PruningOpVars.update value.

    :param op: the operation to prune to the given sparsity
    :param op_input: the parameter within the op to create a mask for
    :param ks_group: the group identifier the scope should be created under
        mask_creator
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param is_after_end_step: only should be provided if leave_enabled is False;
        tensor that is true if the current global step is after end_epoch
    :return: a named tuple containing the assignment op, mask variable,
        threshold tensor, and masked tensor
    """
    if tf_contrib_err:
        raise tf_contrib_err

    op_sgv = graph_editor.sgv(op)

    # create the necessary variables first
    with tf_compat.variable_scope(PruningScope.model(op, ks_group),
                                  reuse=tf_compat.AUTO_REUSE):
        mask = tf_compat.get_variable(
            PruningScope.VAR_MASK,
            op_input.get_shape(),
            initializer=tf_compat.ones_initializer(),
            trainable=False,
            dtype=op_input.dtype,
        )
    tf_compat.add_to_collection(
        PruningScope.collection_name(ks_group, PruningScope.VAR_MASK), mask)

    # create the masked operation and assign as the new input to the op
    with tf_compat.name_scope(
            PruningScope.model(op, ks_group, trailing_slash=True)):
        masked = tf_compat.multiply(mask, op_input, PruningScope.OP_MASKED_VAR)
        op_inp_tens = (masked if leave_enabled else tf_compat.cond(
            is_after_end_step, lambda: op_input, lambda: masked))
        op_swapped_inputs = [
            inp if inp != op_input else op_inp_tens for inp in op_sgv.inputs
        ]
        graph_editor.swap_inputs(op, op_swapped_inputs)
    tf_compat.add_to_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_MASKED_VAR),
        masked)
    return PruningOpVars(op, op_input, None, mask, masked)
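A possible call site (a sketch, not from the source; the op name, kernel input index, and group name are assumptions):

# hypothetical usage; the op name and group are assumptions
graph = tf_compat.get_default_graph()
matmul_op = graph.get_operation_by_name("dense/MatMul")
weight_input = matmul_op.inputs[1]  # the kernel input of the MatMul
op_vars = create_op_pruning_no_update(matmul_op, weight_input, "pruning_group")
# op_vars.update is None; mask assignments must be created separately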
Example #3
        def _calc_lr():
            # number of whole step_size intervals since start_step,
            # capped at max_updates once the schedule has ended
            steps = tf_compat.subtract(global_step, start_step)
            updates = tf_compat.cond(
                after,
                lambda: max_updates,
                lambda: tf_compat.cast(
                    tf_compat.floor(tf_compat.divide(steps, step_size)),
                    tf_compat.int64,
                ),
            )
            mult_g = tf_compat.pow(gamma,
                                   tf_compat.cast(updates, tf_compat.float32))

            return tf_compat.multiply(init_lr, mult_g)
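The same capped step decay in plain Python (a sketch; start_step, step_size, max_updates, gamma, and init_lr are illustrative assumptions, and min() plays the role of the `after` branch):

# assumed example values, mirroring the tensors in the snippet above
init_lr, gamma = 0.1, 0.5
start_step, step_size, max_updates = 10, 20, 3
for global_step in (10, 55, 200):
    updates = min(max_updates, (global_step - start_step) // step_size)
    print(global_step, init_lr * gamma ** updates)
# prints 0.1 at step 10, 0.025 after two updates, 0.0125 once capped at three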
Example #4
def apply_op_vars_masks(pruning_op_vars: List[PruningOpVars], ks_group: str,
                        sess: tf_compat.Session):
    """
    Apply the masks to the original op's input var so that it can be saved
    with the desired sparsity for later.

    :param pruning_op_vars: the list of named tuples containing the sparse mask
        and the op variable to apply the sparse mask to
    :param ks_group: the group to create the assign ops under
    :param sess: the session to use to run the assign
    """
    for op_vars in pruning_op_vars:
        with tf_compat.name_scope(
                PruningScope.model(op_vars.op, ks_group,
                                   PruningScope.OP_SAVE)):
            masked_var = tf_compat.multiply(op_vars.op_input, op_vars.mask)
            input_var = get_tensor_var(op_vars.op_input)
            assign = tf_compat.assign(input_var, masked_var)
            sess.run(assign)
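One way this is typically used (a sketch; the saver, session setup, group name, and checkpoint path are assumptions): bake the masks into the weights after training, then write the checkpoint so the sparsity persists:

# hypothetical save flow; names and the checkpoint path are assumptions
saver = tf_compat.train.Saver()
with tf_compat.Session() as sess:
    sess.run(tf_compat.global_variables_initializer())
    # ... training loop runs here ...
    apply_op_vars_masks(pruning_op_vars, "pruning_group", sess)
    saver.save(sess, "./model-sparse.ckpt")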
Example #5
def _update():
    # create the update ops using the target sparsity tensor
    with tf_compat.name_scope(
            PruningScope.model(
                op,
                ks_group,
                additional=PruningScope.OPS_UPDATE,
                trailing_slash=True,
            )):
        new_mask = mask_creator.create_sparsity_mask(op_var_tens, sparsity)
        weight_var = get_tensor_var(op_var_tens)
        return tf_compat.group(
            tf_compat.assign(mask,
                             new_mask,
                             name=PruningScope.OP_MASK_ASSIGN),
            tf_compat.assign(
                weight_var,
                tf_compat.multiply(new_mask, op_var_tens),
                name=PruningScope.OP_WEIGHT_UPDATE,
            ),
        )
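A closure like this is typically selected with tf_compat.cond against an update-readiness signal; below is a minimal sketch assuming an update_ready boolean tensor and a no-op counterpart (both names are assumptions):

def _no_update():
    # empty branch matching _update()'s grouped-op structure; an assumption
    return tf_compat.group()

# run the mask/weight assignments only when an update is due
update_op = tf_compat.cond(update_ready, _update, _no_update)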
Example #6
def create_ks_schedule_ops(
    global_step: tf_compat.Variable,
    begin_step: int,
    end_step: int,
    update_step_freq: int,
    init_sparsity: float,
    final_sparsity: float,
    exponent: float,
    ks_group: str,
) -> Tuple[tf_compat.Tensor, tf_compat.Tensor]:
    """
    Create a gradual schedule for model pruning (kernel sparsity).
    Creates a sparsity tensor that goes from init_sparsity til final_sparsity
    starting at begin_step and ending at end_step.
    Uses the global_step to map those.
    Additionally creates an update_ready tensor that is True if an update
    to the sparsity tensor should be run, False otherwise.

    :param global_step: the global optimizer step for the training graph
    :param begin_step: the global step to begin pruning at
    :param end_step: the global step to end pruning at
    :param update_step_freq: the number of global steps between each weight update
    :param init_sparsity: the starting value for sparsity of a
        weight tensor to be enforce
    :param final_sparsity: the end value for sparsity for a weight tensor to be enforce
    :param exponent: the exponent to use for interpolating between
        init_sparsity and final_sparsity higher values will lead to larger sparsity
        steps at the beginning vs the end ie: linear (1) vs cubic (3)
    :param ks_group: the group identifier the scope should be created under
    :return: a tuple containing the signal for update_ready and the target sparsity
    """

    # create the scheduling ops first and the sparsity ops
    with tf_compat.name_scope(
            PruningScope.general(ks_group,
                                 additional=PruningScope.OPS_SCHEDULE,
                                 trailing_slash=True)):
        sched_before = tf_compat.less(global_step, begin_step)
        sched_start = tf_compat.equal(global_step, begin_step)
        sched_end = tf_compat.equal(global_step, end_step)
        sched_active = tf_compat.logical_and(
            tf_compat.greater(global_step, begin_step),
            tf_compat.less(global_step, end_step),
        )
        sched_active_inclusive = tf_compat.logical_or(
            sched_active, tf_compat.logical_or(sched_start, sched_end))
        sched_update = tf_compat.cond(
            tf_compat.less_equal(update_step_freq, 0),
            lambda: tf_compat.constant(True),
            lambda: tf_compat.equal(
                tf_compat.mod(
                    (global_step - begin_step), update_step_freq), 0),
        )
        sched_update_ready = tf_compat.logical_or(
            tf_compat.logical_or(sched_start, sched_end), sched_update)

        percentage = tf_compat.minimum(
            1.0,
            tf_compat.maximum(
                0.0,
                tf_compat_div(
                    tf_compat.cast(global_step - begin_step,
                                   tf_compat.float32),
                    end_step - begin_step,
                ),
            ),
        )
        exp_percentage = 1 - tf_compat.pow(1 - percentage, exponent)
        calc_sparsity = (tf_compat.multiply(final_sparsity - init_sparsity,
                                            exp_percentage) + init_sparsity)

    # create the update ready tensor and sparsity tensor
    with tf_compat.name_scope(
            PruningScope.general(ks_group, trailing_slash=True)):
        update_ready = tf_compat.logical_and(
            sched_active_inclusive,
            sched_update_ready,
            name=PruningScope.OP_UPDATE_READY,
        )
        sparsity = tf_compat.case(
            [
                (sched_before, lambda: tf_compat.constant(0.0)),
                (sched_start, lambda: tf_compat.constant(init_sparsity)),
                (sched_active, lambda: calc_sparsity),
            ],
            default=lambda: tf_compat.constant(final_sparsity),
            name=PruningScope.OP_SPARSITY,
        )

    # add return state to collections
    tf_compat.add_to_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_UPDATE_READY),
        update_ready,
    )
    tf_compat.add_to_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_SPARSITY),
        sparsity)

    return update_ready, sparsity
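For intuition, a plain-Python rendering of the sparsity interpolation (a sketch; the step and sparsity values are illustrative assumptions):

# assumed example values; mirrors the percentage/exponent math above
def scheduled_sparsity(step, begin=0, end=100, init=0.05, final=0.85, exp=3.0):
    pct = min(1.0, max(0.0, (step - begin) / (end - begin)))
    return init + (final - init) * (1.0 - (1.0 - pct) ** exp)

# halfway through: pct = 0.5, 1 - 0.5**3 = 0.875, so 0.05 + 0.875 * 0.8 = 0.75
assert abs(scheduled_sparsity(50) - 0.75) < 1e-9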