def testMaskedLSTMCell(self):
  """Checks that one MaskedLSTMCell step registers exactly one masked kernel.

  Runs a single cell step on random inputs/state and verifies that each
  pruning collection (masks, masked weights, thresholds, weights) holds one
  entry, and that masks and weights have shape (2 * dim, 4 * dim).
  """
  expected_num_masks = 1
  # Presumably the kernel maps concat([inputs, h]) (2 * dim) onto the four
  # LSTM gates (4 * dim) -- confirm against MaskedLSTMCell.
  expected_num_rows = 2 * self.dim
  expected_num_cols = 4 * self.dim
  # `test_session` is deprecated on tf.test.TestCase; `cached_session` is the
  # supported replacement.
  with self.cached_session():
    inputs = variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))
    c = variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))
    h = variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))
    state = tf_rnn_cells.LSTMStateTuple(c, h)
    lstm_cell = rnn_cells.MaskedLSTMCell(self.dim)
    lstm_cell(inputs, state)
    self.assertEqual(len(pruning.get_masks()), expected_num_masks)
    self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
    self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
    self.assertEqual(len(pruning.get_weights()), expected_num_masks)

    for mask in pruning.get_masks():
      self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
    for weight in pruning.get_weights():
      self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))
def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                 freq_iter=2):
  """Setups a trivial training procedure for sparse training.

  Builds one masked fully-connected layer fed with ones and trained by
  SparseRigLOptimizer. The layer output is multiplied by
  `range(n_out) * global_step` before summing, so the gradient of the loss
  w.r.t. the layer output equals that scale vector.

  Args:
    n_inp: number of inputs of the masked layer.
    n_out: number of outputs of the masked layer.
    drop_frac: value passed as `drop_fraction` to SparseRigLOptimizer.
    start_iter: passed positionally to the optimizer (presumably the first
      mask-update step -- confirm against SparseRigLOptimizer).
    end_iter: passed positionally (presumably the last mask-update step).
    freq_iter: passed positionally (presumably the mask-update frequency).

  Returns:
    Tuple (sess, train_op, masked_grad, expected_gradient): an initialized
    session, the training op, the optimizer's masked-gradient tensor for the
    layer weight, and the tensor holding the gradient expected for it.
  """
  tf.reset_default_graph()
  optim = tf.train.GradientDescentOptimizer(1e-3)
  global_step = tf.train.get_or_create_global_step()
  sparse_optim = sparse_optimizers.SparseRigLOptimizer(
      optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
  x = tf.ones((1, n_inp))
  y = layers.masked_fully_connected(x, n_out, activation_fn=None)
  # Multiplying the output with range of constants to have constant but
  # different gradients at the masked weights. We also multiply the loss with
  # global_step to increase the gradient linearly with time.
  scale_vector = (
      tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape) *
      tf.cast(global_step, dtype=y.dtype))
  y = y * scale_vector
  loss = tf.reduce_sum(y)
  # `global_step` was already created above; fetching it again would only
  # return the same variable.
  train_op = sparse_optim.minimize(loss, global_step)
  weight = pruning.get_weights()[0]
  # The input is all ones, so d(loss)/d(weight) is the scale vector repeated
  # along the input dimension.
  expected_gradient = tf.broadcast_to(scale_vector, weight.shape)
  masked_grad = sparse_optim._weight2masked_grads[weight.name]

  # Init
  sess = tf.Session()
  init = tf.global_variables_initializer()
  sess.run(init)
  return sess, train_op, masked_grad, expected_gradient
def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                 freq_iter=2):
  """Setups a trivial training procedure for sparse training.

  Builds one masked fully-connected layer on uniform random input, trained by
  SparseSETOptimizer on the mean of the layer output. After variable
  initialization the layer mask is overwritten with a random 0/1 pattern.

  Args:
    n_inp: number of inputs of the masked layer.
    n_out: number of outputs of the masked layer.
    drop_frac: value passed as `drop_fraction` to SparseSETOptimizer.
    start_iter: passed positionally to the optimizer (presumably the first
      mask-update step -- confirm against SparseSETOptimizer).
    end_iter: passed positionally (presumably the last mask-update step).
    freq_iter: passed positionally (presumably the mask-update frequency).

  Returns:
    Tuple (sess, train_op, mask, weight, global_step): an initialized session
    with the random mask applied, the training op, the layer's mask and
    weight variables, and the global step variable.
  """
  tf.reset_default_graph()
  optim = tf.train.GradientDescentOptimizer(0.1)
  sparse_optim = sparse_optimizers.SparseSETOptimizer(
      optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
  x = tf.random.uniform((1, n_inp))
  y = layers.masked_fully_connected(x, n_out, activation_fn=None)
  global_step = tf.train.get_or_create_global_step()
  weight = pruning.get_weights()[0]
  # There is one masked layer to be trained.
  mask = pruning.get_masks()[0]
  # Each mask entry is independently 0 or 1 with probability 1/2, so around
  # half of the values of the mask are set to zero by `mask_update`.
  mask_update = tf.assign(
      mask,
      tf.constant(
          np.random.choice([0, 1], size=(n_inp, n_out), p=[1. / 2, 1. / 2]),
          dtype=tf.float32))
  loss = tf.reduce_mean(y)
  # `global_step` was already created above; fetching it again would only
  # return the same variable.
  train_op = sparse_optim.minimize(loss, global_step)

  # Init
  sess = tf.Session()
  init = tf.global_variables_initializer()
  sess.run(init)
  # Run after initialization so the random pattern replaces the mask's
  # initial value.
  sess.run([mask_update])
  return sess, train_op, mask, weight, global_step
def _setup_graph(self, default_sparsity, mask_init_method,
                 custom_sparsity_map, n_inp=3, n_out=5):
  """Setups a trivial training procedure for sparse training.

  Builds one masked fully-connected layer fed with the row [1, ..., n_inp]
  and trained by SparseSnipOptimizer on sum(y * s), where `s` is a random
  vector drawn uniformly from [-0.5, 0.5).

  Args:
    default_sparsity: passed positionally to SparseSnipOptimizer.
    mask_init_method: passed positionally to SparseSnipOptimizer.
    custom_sparsity_map: forwarded as `custom_sparsity_map`.
    n_inp: number of inputs of the masked layer.
    n_out: number of outputs of the masked layer.

  Returns:
    Tuple (sess, train_op, expected_grads, sparse_optim, mask, weights):
    an initialized session, the training op, the numpy array of expected
    weight gradients, the SNIP optimizer, and the layer's mask and weight.
  """
  tf.reset_default_graph()
  base_optimizer = tf.train.GradientDescentOptimizer(1e-3)
  snip_optimizer = sparse_optimizers.SparseSnipOptimizer(
      base_optimizer, default_sparsity, mask_init_method,
      custom_sparsity_map=custom_sparsity_map)

  input_row = np.arange(1, n_inp + 1)
  output_scale = np.random.uniform(size=(n_out,)) - 0.5
  # The loss is a simple sum of y * output_scale, so dL/dy == output_scale.
  # With a single input row, dL/dW is the outer product of that row and
  # output_scale.
  expected_grads = np.outer(input_row, output_scale)

  x = tf.reshape(tf.constant(input_row, dtype=tf.float32), (1, n_inp))
  y = layers.masked_fully_connected(x, n_out, activation_fn=None)
  loss = tf.reduce_sum(y * tf.constant(output_scale, dtype=tf.float32))
  global_step = tf.train.get_or_create_global_step()
  train_op = snip_optimizer.minimize(loss, global_step)

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())
  mask = pruning.get_masks()[0]
  weights = pruning.get_weights()[0]
  return sess, train_op, expected_grads, snip_optimizer, mask, weights
def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                 freq_iter=2, momentum=0.5):
  """Setups a trivial training procedure for sparse training.

  Builds one masked fully-connected layer fed with ones and trained by
  SparseMomentumOptimizer. Each output unit is scaled by its index before
  summing into the loss.

  Args:
    n_inp: number of inputs of the masked layer.
    n_out: number of outputs of the masked layer.
    drop_frac: value passed as `drop_fraction` to the optimizer.
    start_iter: passed positionally to SparseMomentumOptimizer.
    end_iter: passed positionally to SparseMomentumOptimizer.
    freq_iter: passed positionally to SparseMomentumOptimizer.
    momentum: value passed as `momentum` to the optimizer.

  Returns:
    Tuple (sess, train_op, masked_grad_ema): an initialized session, the
    training op, and the optimizer's moving average of the layer weight's
    masked gradient.
  """
  tf.reset_default_graph()
  base_optimizer = tf.train.GradientDescentOptimizer(0.1)
  momentum_optimizer = sparse_optimizers.SparseMomentumOptimizer(
      base_optimizer, start_iter, end_iter, freq_iter,
      drop_fraction=drop_frac, momentum=momentum)
  layer_out = layers.masked_fully_connected(
      tf.ones((1, n_inp)), n_out, activation_fn=None)

  # Scale each output by a distinct constant (0, 1, 2, ...) so the masked
  # weights receive constant but different gradients.
  unit_index = tf.cast(tf.range(tf.size(layer_out)), dtype=layer_out.dtype)
  scaled_out = layer_out * tf.reshape(unit_index, layer_out.shape)
  loss = tf.reduce_sum(scaled_out)

  step = tf.train.get_or_create_global_step()
  train_op = momentum_optimizer.minimize(loss, step)
  kernel = pruning.get_weights()[0]
  kernel_masked_grad = momentum_optimizer._weight2masked_grads[kernel.name]
  grad_ema = momentum_optimizer._ema_grads.average(kernel_masked_grad)

  session = tf.Session()
  session.run(tf.global_variables_initializer())
  return session, train_op, grad_ema
def testMaskedLSTMCell(self):
  """A single MaskedLSTMCell step must register one masked kernel.

  Every pruning collection (masks, masked weights, thresholds, weights) is
  expected to contain exactly one entry. Masks and weights must have shape
  (2 * dim, 4 * dim).
  """
  expected_count = 1
  expected_shape = (2 * self.dim, 4 * self.dim)

  def _random_variable():
    # A [batch_size, dim] variable initialized from a normal distribution.
    return variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))

  with self.cached_session():
    inputs = _random_variable()
    # Arguments evaluate left to right, so `c` is built before `h`.
    state = tf_rnn_cells.LSTMStateTuple(_random_variable(), _random_variable())
    rnn_cells.MaskedLSTMCell(self.dim)(inputs, state)

    collection_getters = (pruning.get_masks, pruning.get_masked_weights,
                          pruning.get_thresholds, pruning.get_weights)
    for getter in collection_getters:
      self.assertEqual(len(getter()), expected_count)

    for mask in pruning.get_masks():
      self.assertEqual(mask.shape, expected_shape)
    for weight in pruning.get_weights():
      self.assertEqual(weight.shape, expected_shape)
def train_function(training_method, loss, cross_loss, reg_loss, output_dir,
                   use_tpu):
  """Training script for resnet model.

  Builds the base optimizer (Adam or Nesterov momentum) and optionally wraps
  it for TPU and/or a sparse-training method. It then creates the train op
  and collects scalar metrics for host-side summary logging.

  Args:
    training_method: string indicating pruning method used to compress model.
      One of 'set', 'static', 'momentum', 'rigl', 'snip', 'scratch' or
      'baseline'.
    loss: tensor float32 of the cross entropy + regularization losses.
    cross_loss: tensor, only cross entropy loss, passed for logging.
    reg_loss: tensor, only regularization loss, passed for logging.
    output_dir: string tensor indicating the directory to save summaries.
    use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.

  Raises:
    ValueError: if `training_method` is not a supported method.
  """
  global_step = tf.train.get_global_step()
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
  learning_rate = lr_schedule(current_epoch)
  if FLAGS.use_adam:
    # We don't use step decrease for the learning rate.
    learning_rate = FLAGS.base_learning_rate * (
        FLAGS.train_batch_size / 256.0)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  else:
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate,
        momentum=FLAGS.momentum,
        use_nesterov=True)

  if use_tpu:
    # use CrossShardOptimizer when using TPU.
    optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

  if training_method == 'set':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseSETOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        stateless_seed_offset=FLAGS.seed)
  elif training_method == 'static':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseStaticOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        stateless_seed_offset=FLAGS.seed)
  elif training_method == 'momentum':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseMomentumOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init,
        stateless_seed_offset=FLAGS.seed,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        use_tpu=use_tpu)
  elif training_method == 'rigl':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseRigLOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        stateless_seed_offset=FLAGS.seed,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale,
        use_tpu=use_tpu)
  elif training_method == 'snip':
    optimizer = sparse_optimizers.SparseSnipOptimizer(
        optimizer,
        mask_init_method=FLAGS.mask_init_method,
        custom_sparsity_map=CUSTOM_SPARSITY_MAP,
        default_sparsity=FLAGS.end_sparsity,
        use_tpu=use_tpu)
  elif training_method in ('scratch', 'baseline'):
    pass
  else:
    # Report the value that was actually passed in, not the global flag.
    raise ValueError('Unsupported pruning method: %s' % training_method)

  # UPDATE_OPS needs to be added as a dependency due to batch norm
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops), tf.name_scope('train'):
    train_op = optimizer.minimize(loss, global_step)

  metrics = {
      # `global_step` already exists: the tf.cast above would have failed if
      # get_global_step() had returned None.
      'global_step': global_step,
      'loss': loss,
      'cross_loss': cross_loss,
      'reg_loss': reg_loss,
      'learning_rate': learning_rate,
      'current_epoch': current_epoch,
  }

  # Logging drop_fraction if dynamic sparse training.
  if training_method in ('set', 'momentum', 'rigl', 'static'):
    metrics['drop_fraction'] = optimizer.drop_fraction

  # Let's log some statistics from a single parameter-mask couple.
  # This is useful for debugging.
  test_var = pruning.get_weights()[0]
  test_var_mask = pruning.get_masks()[0]
  metrics.update({
      'fw_nz_weight': tf.count_nonzero(test_var),
      'fw_nz_mask': tf.count_nonzero(test_var_mask),
      'fw_l1_weight': tf.reduce_sum(tf.abs(test_var))
  })

  masks = pruning.get_masks()
  global_sparsity = sparse_utils.calculate_sparsity(masks)
  metrics['global_sparsity'] = global_sparsity
  # Mask summaries for the first four masks and the last one only.
  metrics.update(
      utils.mask_summaries(
          masks[:4] + masks[-1:],
          with_img=FLAGS.log_mask_imgs_each_iteration))

  host_call = (functools.partial(utils.host_call_fn, output_dir),
               utils.format_tensors(metrics))

  return host_call, train_op
def get_weights(self):
  """Delegates to `pruning.get_weights()` and returns its result unchanged."""
  weights = pruning.get_weights()
  return weights