def _train_op(loss, params):
    """Build the training op for ``loss`` plus a rescaled L2 penalty.

    An L2 term over every tensor in ``MASKED_WEIGHT_COLLECTION`` is added to
    ``loss``. The term is multiplied by ``params.l2_pen`` and by a factor
    computed from the fraction of entries in those tensors that are currently
    non-zero. The combined loss is recorded as the ``pruned_total_loss``
    scalar summary and then handed to ``make_dns_train_op``.

    Args:
        loss: Loss tensor to regularize and minimize.
        params: Object exposing ``l2_pen``, the base L2 coefficient.

    Returns:
        Whatever ``make_dns_train_op`` returns for the combined loss.
    """
    step = tf.train.get_or_create_global_step()
    momentum_opt = tf.train.MomentumOptimizer(learning_rate=1e-4,
                                              momentum=0.9)
    masked_weights = tf.get_collection(MASKED_WEIGHT_COLLECTION)

    # Aggregate counts over all collected tensors.
    nonzero_count = tf.add_n([tf.count_nonzero(w) for w in masked_weights])
    entry_count = tf.add_n([tf.size(w) for w in masked_weights])

    # Fraction of entries that are non-zero, capped at 0.99 so that the
    # denominator of the scale factor below can never reach zero.
    nonzero_fraction = tf.minimum(
        tf.cast(nonzero_count, tf.float32) / tf.cast(entry_count, tf.float32),
        0.99)

    # Scale factor is 1 / sqrt(1 - nonzero_fraction): about 10 when (nearly)
    # every entry is non-zero, falling towards 1 as entries become zero.
    # NOTE(review): this quantity was previously named "prune amount" although
    # it measures the non-zero share -- confirm the direction is intended.
    zero_fraction = 1 - nonzero_fraction
    penalty_scale = 1 / tf.sqrt(zero_fraction)

    l2_coefficient = penalty_scale * params.l2_pen
    l2_total = tf.add_n([tf.nn.l2_loss(w) for w in masked_weights])
    regularization = tf.identity(l2_coefficient * l2_total,
                                 name='pruned_regularization_loss')

    total_loss = loss + regularization
    tf.summary.scalar('pruned_total_loss', total_loss)

    thresh_fn = _make_thresh_fn(20000, 200)
    prob_thresh = _prob_piecewise()
    return make_dns_train_op(total_loss,
                             optimizer=momentum_opt,
                             thresh_fn_or_scale=thresh_fn,
                             prob_thresh=prob_thresh,
                             global_step=step)
def train_op_fn(loss, params):
    """Assemble the train op using a percentile threshold over WEIGHTS.

    Uses a momentum optimizer (momentum 0.9) whose learning rate is
    ``_learning_rate_inverse``, the probability threshold returned by
    ``_prob_piecewise()``, and a ``percentile_thresh_fn`` driven by
    ``_make_target_sparsity(0.0, 0.60, 0.15, depth_multiplier)``.
    ``depth_multiplier`` and ``_learning_rate_inverse`` are resolved from a
    scope outside this function.

    Args:
        loss: Loss tensor to minimize.
        params: Not read by this function.

    Returns:
        Whatever ``make_dns_train_op`` returns.
    """
    piecewise_prob = _prob_piecewise()
    momentum_opt = tf.train.MomentumOptimizer(
        learning_rate=_learning_rate_inverse, momentum=0.9)
    weight_vars = tf.get_collection(tf.GraphKeys.WEIGHTS)

    sparsity_target = _make_target_sparsity(0.0, 0.60, 0.15, depth_multiplier)
    percentile_fn = percentile_thresh_fn(sparsity_target,
                                         target_iterations=200000,
                                         update_steps=2000)

    # Fetched last, right before building the op, as the final argument.
    step = tf.train.get_or_create_global_step()
    return make_dns_train_op(loss,
                             optimizer=momentum_opt,
                             prob_thresh=piecewise_prob,
                             thresh_fn_or_scale=percentile_fn,
                             variables=weight_vars,
                             global_step=step)
def train_op_fn(loss, params):
    """Assemble the train op with an inverse-time-decayed prob threshold.

    The probability threshold starts at 1.0 and follows
    ``tf.train.inverse_time_decay`` on the global step (``decay_steps=100``,
    ``decay_rate=0.2``, staircase). The magnitude threshold comes from
    ``percentile_thresh_fn(target_sparsity, target_iterations=15000,
    update_steps=500)``. ``target_sparsity`` and ``_learning_rate`` are
    resolved from a scope outside this function.

    Args:
        loss: Loss tensor to minimize.
        params: Not read by this function.

    Returns:
        Whatever ``make_dns_train_op`` returns.
    """
    # One lookup serves both the decay schedule and the train op.
    step = tf.train.get_or_create_global_step()

    decayed_prob = tf.train.inverse_time_decay(1.0,
                                               step,
                                               decay_steps=100,
                                               decay_rate=0.2,
                                               staircase=True,
                                               name='prob_thresh')
    momentum_opt = tf.train.MomentumOptimizer(learning_rate=_learning_rate,
                                              momentum=0.9)
    weight_vars = tf.get_collection(tf.GraphKeys.WEIGHTS)
    percentile_fn = percentile_thresh_fn(target_sparsity,
                                         target_iterations=15000,
                                         update_steps=500)

    return make_dns_train_op(loss,
                             optimizer=momentum_opt,
                             prob_thresh=decayed_prob,
                             thresh_fn_or_scale=percentile_fn,
                             variables=weight_vars,
                             global_step=step)