Пример #1
0
def _train_op(loss, params):
    """Build a DNS train op whose loss carries a density-scaled L2 penalty.

    The L2 norm of every weight in ``MASKED_WEIGHT_COLLECTION`` is added to
    ``loss``. It is scaled by ``params.l2_pen`` and by ``1 / sqrt(1 - r)``,
    where ``r`` is the ratio of nonzero entries to total entries across those
    weights, clipped to at most 0.99.

    Args:
        loss: scalar loss tensor to minimize.
        params: hyper-parameter object; only ``params.l2_pen`` is read here.

    Returns:
        The op produced by ``make_dns_train_op``.
    """
    # Created up front so the later schedule helpers find an existing step.
    global_step = tf.train.get_or_create_global_step()

    masked = tf.get_collection(MASKED_WEIGHT_COLLECTION)

    # NOTE(review): nonzero / total is the fraction of weights still alive
    # (density), not the fraction pruned, although the original code named it
    # `prune_amount`. A fully dense network gives exactly 1.0, so the 0.99
    # clip is what keeps `1 - alive_ratio` away from zero. Confirm which
    # quantity was intended.
    nonzero_count = tf.cast(
        tf.add_n([tf.count_nonzero(w) for w in masked]), tf.float32)
    element_count = tf.cast(
        tf.add_n([tf.size(w) for w in masked]), tf.float32)
    alive_ratio = tf.minimum(nonzero_count / element_count, 0.99)
    penalty_scale = 1 / tf.sqrt(1 - alive_ratio)

    l2_sum = tf.add_n([tf.nn.l2_loss(w) for w in masked])
    regularization_loss = tf.identity(
        penalty_scale * params.l2_pen * l2_sum,
        name='pruned_regularization_loss')

    total_loss = loss + regularization_loss
    tf.summary.scalar('pruned_total_loss', total_loss)

    # Threshold function is built before the probability schedule, matching
    # the original evaluation order.
    thresh_fn = _make_thresh_fn(20000, 200)
    prob_schedule = _prob_piecewise()

    return make_dns_train_op(
        total_loss,
        optimizer=tf.train.MomentumOptimizer(learning_rate=1e-4, momentum=0.9),
        thresh_fn_or_scale=thresh_fn,
        prob_thresh=prob_schedule,
        global_step=global_step)
    def train_op_fn(loss, params):
        """Create the DNS training op using a percentile threshold function.

        Args:
            loss: scalar loss tensor to minimize.
            params: accepted for interface compatibility; not read here.

        Returns:
            The op returned by ``make_dns_train_op``.
        """
        threshold_prob = _prob_piecewise()

        # `_learning_rate_inverse` is a free name from an enclosing scope not
        # shown here (presumably a schedule/callable -- confirm).
        momentum_opt = tf.train.MomentumOptimizer(
            learning_rate=_learning_rate_inverse, momentum=0.9)
        prunable = tf.get_collection(tf.GraphKeys.WEIGHTS)

        # `depth_multiplier` is likewise resolved from an enclosing scope.
        sparsity_target = _make_target_sparsity(
            0.0, 0.60, 0.15, depth_multiplier)
        percentile_fn = percentile_thresh_fn(
            sparsity_target, target_iterations=200000, update_steps=2000)

        step = tf.train.get_or_create_global_step()
        return make_dns_train_op(
            loss,
            optimizer=momentum_opt,
            prob_thresh=threshold_prob,
            thresh_fn_or_scale=percentile_fn,
            variables=prunable,
            global_step=step)
Пример #3
0
    def train_op_fn(loss, params):
        """Create the DNS training op with a decaying probability threshold.

        Args:
            loss: scalar loss tensor to minimize.
            params: accepted for interface compatibility; not read here.

        Returns:
            The op returned by ``make_dns_train_op``.
        """
        # get_or_create returns the same variable on every call, so one
        # lookup serves both the decay schedule and the train op.
        step = tf.train.get_or_create_global_step()

        # Per the tf.train.inverse_time_decay docs, staircase mode yields
        # 1.0 / (1 + 0.2 * floor(step / 100)).
        decayed_prob = tf.train.inverse_time_decay(
            1.0, step, decay_steps=100, decay_rate=0.2,
            staircase=True, name='prob_thresh')

        # `_learning_rate` and `target_sparsity` are free names resolved from
        # an enclosing scope not shown here.
        momentum_opt = tf.train.MomentumOptimizer(
            learning_rate=_learning_rate, momentum=0.9)
        prunable = tf.get_collection(tf.GraphKeys.WEIGHTS)
        percentile_fn = percentile_thresh_fn(
            target_sparsity, target_iterations=15000, update_steps=500)

        return make_dns_train_op(
            loss,
            optimizer=momentum_opt,
            prob_thresh=decayed_prob,
            thresh_fn_or_scale=percentile_fn,
            variables=prunable,
            global_step=step)