def _thresh_fn(variable, mask):
    """Threshold `variable` using a percentile schedule specific to it.

    The schedule target is looked up in `target_sparsities` under the name of
    the variable's op. One minus that entry is passed as the first argument to
    `percentile_thresh_fn`, together with the enclosing `target_iterations`
    and `update_steps` and a fixed `thresh_lower_scale` of 1. The resulting
    function is then applied to `(variable, mask)`.

    NOTE(review): the `1 -` suggests `target_sparsities` may actually store
    densities (fraction kept). Confirm against `percentile_thresh_fn`'s
    expected first argument.

    Args:
      variable: the variable being pruned. Only `variable.op.name` is read
        here; the variable itself is forwarded.
      mask: forwarded unchanged to the per-variable threshold function.

    Returns:
      Whatever the function built by `percentile_thresh_fn` returns.
    """
    op_name = variable.op.name
    schedule_target = 1 - target_sparsities[op_name]
    per_variable_fn = percentile_thresh_fn(
        schedule_target,
        target_iterations,
        update_steps,
        thresh_lower_scale=1,
    )
    return per_variable_fn(variable, mask)
def train_op_fn(loss, params):
    """Build a DNS training op using a piecewise prob_thresh schedule.

    The op combines momentum SGD (momentum 0.9, learning rate
    `_learning_rate_inverse`) with the `make_dns_train_op` pruning step over
    every tensor in the `tf.GraphKeys.WEIGHTS` collection. The threshold
    function comes from `percentile_thresh_fn` with a target built by
    `_make_target_sparsity(0.0, 0.60, 0.15, depth_multiplier)`, using
    200000 target iterations and an update every 2000 steps.

    Args:
      loss: the loss passed through to `make_dns_train_op`.
      params: not read by this function. Presumably it is accepted to match
        a caller-imposed callback signature (TODO confirm).

    Returns:
      The op returned by `make_dns_train_op`.
    """
    # Call order below mirrors the original so TF graph construction is unchanged.
    threshold_prob = _prob_piecewise()
    momentum_opt = tf.train.MomentumOptimizer(
        learning_rate=_learning_rate_inverse,
        momentum=0.9,
    )
    prunable = tf.get_collection(tf.GraphKeys.WEIGHTS)
    sparsity_schedule = _make_target_sparsity(0.0, 0.60, 0.15, depth_multiplier)
    percentile_fn = percentile_thresh_fn(
        sparsity_schedule,
        target_iterations=200000,
        update_steps=2000,
    )
    step = tf.train.get_or_create_global_step()
    return make_dns_train_op(
        loss,
        optimizer=momentum_opt,
        prob_thresh=threshold_prob,
        thresh_fn_or_scale=percentile_fn,
        variables=prunable,
        global_step=step,
    )
def train_op_fn(loss, params):
    """Build a DNS training op whose prob_thresh decays with the global step.

    `prob_thresh` starts at 1.0 and follows `tf.train.inverse_time_decay`
    with decay_steps=100, decay_rate=0.2 and staircase=True. Optimization
    uses momentum SGD (momentum 0.9, learning rate `_learning_rate`) over
    the `tf.GraphKeys.WEIGHTS` collection. Thresholds come from
    `percentile_thresh_fn(target_sparsity, target_iterations=15000,
    update_steps=500)`.

    Args:
      loss: the loss passed through to `make_dns_train_op`.
      params: not read by this function. Presumably it is accepted to match
        a caller-imposed callback signature (TODO confirm).

    Returns:
      The op returned by `make_dns_train_op`.
    """
    # get_or_create returns the same variable on every call in a graph, so
    # fetching it once and reusing it is equivalent to the two original calls.
    step = tf.train.get_or_create_global_step()
    decayed_prob = tf.train.inverse_time_decay(
        1.0,
        step,
        decay_steps=100,
        decay_rate=0.2,
        staircase=True,
        name='prob_thresh',
    )
    momentum_opt = tf.train.MomentumOptimizer(
        learning_rate=_learning_rate,
        momentum=0.9,
    )
    prunable = tf.get_collection(tf.GraphKeys.WEIGHTS)
    percentile_fn = percentile_thresh_fn(
        target_sparsity,
        target_iterations=15000,
        update_steps=500,
    )
    return make_dns_train_op(
        loss,
        optimizer=momentum_opt,
        prob_thresh=decayed_prob,
        thresh_fn_or_scale=percentile_fn,
        variables=prunable,
        global_step=step,
    )