Example #1
    def testMaskedLSTMCell(self):
        expected_num_masks = 1
        # The LSTM kernel stacks [inputs; hidden state] along its rows (2 * dim)
        # and computes the four gates in a single matmul (4 * dim columns).
        expected_num_rows = 2 * self.dim
        expected_num_cols = 4 * self.dim
        with self.test_session():
            inputs = variables.Variable(
                random_ops.random_normal([self.batch_size, self.dim]))
            c = variables.Variable(
                random_ops.random_normal([self.batch_size, self.dim]))
            h = variables.Variable(
                random_ops.random_normal([self.batch_size, self.dim]))
            state = tf_rnn_cells.LSTMStateTuple(c, h)
            lstm_cell = rnn_cells.MaskedLSTMCell(self.dim)
            lstm_cell(inputs, state)
            self.assertEqual(len(pruning.get_masks()), expected_num_masks)
            self.assertEqual(len(pruning.get_masked_weights()),
                             expected_num_masks)
            self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
            self.assertEqual(len(pruning.get_weights()), expected_num_masks)

            for mask in pruning.get_masks():
                self.assertEqual(mask.shape,
                                 (expected_num_rows, expected_num_cols))
            for weight in pruning.get_weights():
                self.assertEqual(weight.shape,
                                 (expected_num_rows, expected_num_cols))
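A hedged check one might add inside the session block above (not part of the original test): before any pruning step runs, the masks created by MaskedLSTMCell are expected to start fully dense, i.e. all ones. numpy is assumed to be imported as np in the test module.

            variables.global_variables_initializer().run()
            for mask in pruning.get_masks():
                # Freshly created pruning masks should contain only ones.
                self.assertAllEqual(mask.eval(),
                                    np.ones(mask.shape.as_list()))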
Example #2
    def _setup_graph(self,
                     n_inp,
                     n_out,
                     drop_frac,
                     start_iter=1,
                     end_iter=4,
                     freq_iter=2):
        """Setups a trivial training procedure for sparse training."""
        tf.reset_default_graph()
        optim = tf.train.GradientDescentOptimizer(1e-3)
        global_step = tf.train.get_or_create_global_step()
        sparse_optim = sparse_optimizers.SparseRigLOptimizer(
            optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
        x = tf.ones((1, n_inp))
        y = layers.masked_fully_connected(x, n_out, activation_fn=None)
        # Multiply the output by a range of constants so that the masked
        # weights receive constant but distinct gradients. The loss is also
        # scaled by global_step so the gradients grow linearly with time.
        scale_vector = (
            tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape) *
            tf.cast(global_step, dtype=y.dtype))
        y = y * scale_vector
        loss = tf.reduce_sum(y)
        global_step = tf.train.get_or_create_global_step()
        train_op = sparse_optim.minimize(loss, global_step)
        weight = pruning.get_weights()[0]
        expected_gradient = tf.broadcast_to(scale_vector, weight.shape)
        masked_grad = sparse_optim._weight2masked_grads[weight.name]

        # Init
        sess = tf.Session()
        init = tf.global_variables_initializer()
        sess.run(init)

        return sess, train_op, masked_grad, expected_gradient
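A minimal usage sketch, not taken from the original test suite: the helper above would typically be called from a test method, the train op run for a few steps, and the gradient RigL stores for the masked weights compared against the expected scale pattern. The test name and argument values below are illustrative.

    def testMaskedGradientShape(self):
        sess, train_op, masked_grad, expected_gradient = self._setup_graph(
            n_inp=4, n_out=3, drop_frac=0.5)
        for _ in range(2):
            sess.run(train_op)
        grad_val, expected_val = sess.run([masked_grad, expected_gradient])
        # The stored masked-weight gradient should have the weight's shape.
        self.assertEqual(grad_val.shape, expected_val.shape)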
Example #3
  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                   freq_iter=2):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseSETOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
    x = tf.random.uniform((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    global_step = tf.train.get_or_create_global_step()
    weight = pruning.get_weights()[0]
    # There is one masked layer to be trained.
    mask = pruning.get_masks()[0]
    # Around half of the values of the mask are set to zero by `mask_update`.
    mask_update = tf.assign(
        mask,
        tf.constant(
            np.random.choice([0, 1], size=(n_inp, n_out), p=[1./2, 1./2]),
            dtype=tf.float32))
    loss = tf.reduce_mean(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)

    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run([mask_update])

    return sess, train_op, mask, weight, global_step
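A hedged usage sketch, not from the original suite: SET drops a fraction of the active connections and grows the same number back, so the count of non-zero mask entries should stay constant across mask updates. The test name and argument values are illustrative.

  def testMaskNonZeroCountIsPreserved(self):
    sess, train_op, mask, weight, global_step = self._setup_graph(
        n_inp=10, n_out=10, drop_frac=0.3)
    nnz_before = np.count_nonzero(sess.run(mask))
    for _ in range(6):
      sess.run(train_op)
    nnz_after = np.count_nonzero(sess.run(mask))
    # SET preserves the per-layer sparsity level while rewiring connections.
    self.assertEqual(nnz_before, nnz_after)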
Example #4
  def _setup_graph(self, default_sparsity, mask_init_method,
                   custom_sparsity_map, n_inp=3, n_out=5):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(1e-3)
    sparse_optim = sparse_optimizers.SparseSnipOptimizer(
        optim, default_sparsity, mask_init_method,
        custom_sparsity_map=custom_sparsity_map)

    inp_values = np.arange(1, n_inp+1)
    scale_vector_values = np.random.uniform(size=(n_out,)) - 0.5
    # The weight gradient is the outer product of the input and the output
    # gradient. Since the loss is the sum of the scaled outputs, the output
    # gradient equals the scale vector.
    expected_grads = np.outer(inp_values, scale_vector_values)

    x = tf.reshape(tf.constant(inp_values, dtype=tf.float32), (1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    scale_vector = tf.constant(scale_vector_values, dtype=tf.float32)

    y = y * scale_vector
    loss = tf.reduce_sum(y)

    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)

    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    mask = pruning.get_masks()[0]
    weights = pruning.get_weights()[0]
    return sess, train_op, expected_grads, sparse_optim, mask, weights
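A hedged sketch of how the helper above might be exercised: SNIP selects connections by the magnitude of weight times gradient once, before training, so after the first training step the mask should roughly match the requested default sparsity. The test name, the 'random' mask-init method, and the assumption that the first sess.run(train_op) triggers the one-shot mask update are illustrative, not taken from the original tests.

  def testSnipMaskSparsity(self):
    default_sparsity = 0.4
    sess, train_op, _, sparse_optim, mask, weights = self._setup_graph(
        default_sparsity, 'random', custom_sparsity_map={})
    sess.run(train_op)  # Assumed to trigger the one-shot SNIP mask update.
    mask_val = sess.run(mask)
    sparsity = 1.0 - np.count_nonzero(mask_val) / float(mask_val.size)
    self.assertAlmostEqual(sparsity, default_sparsity, delta=0.1)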
Example #5
  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                   freq_iter=2, momentum=0.5):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseMomentumOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac,
        momentum=momentum)
    x = tf.ones((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    # Multiply the output by a range of constants so that the masked weights
    # receive constant but distinct gradients.
    y = y * tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape)
    loss = tf.reduce_sum(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)
    weight = pruning.get_weights()[0]
    masked_grad = sparse_optim._weight2masked_grads[weight.name]
    masked_grad_ema = sparse_optim._ema_grads.average(masked_grad)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)

    return sess, train_op, masked_grad_ema
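A hedged sketch: the momentum variant keeps an exponential moving average of the masked-weight gradients and later uses it to decide which connections to grow. The check below only confirms that the EMA tensor takes the weight's shape after a few steps; the test name and values are illustrative.

  def testMaskedGradEmaShape(self):
    n_inp, n_out = 4, 3
    sess, train_op, masked_grad_ema = self._setup_graph(
        n_inp, n_out, drop_frac=0.5, momentum=0.9)
    for _ in range(3):
      sess.run(train_op)
    ema_val = sess.run(masked_grad_ema)
    # The EMA is kept per masked weight, so it matches the weight's shape.
    self.assertEqual(ema_val.shape, (n_inp, n_out))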
Example #6
  def testMaskedLSTMCell(self):
    expected_num_masks = 1
    expected_num_rows = 2 * self.dim
    expected_num_cols = 4 * self.dim
    with self.cached_session():
      inputs = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      c = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      h = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      state = tf_rnn_cells.LSTMStateTuple(c, h)
      lstm_cell = rnn_cells.MaskedLSTMCell(self.dim)
      lstm_cell(inputs, state)
      self.assertEqual(len(pruning.get_masks()), expected_num_masks)
      self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
      self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
      self.assertEqual(len(pruning.get_weights()), expected_num_masks)

      for mask in pruning.get_masks():
        self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
      for weight in pruning.get_weights():
        self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))
Example #7
def train_function(training_method, loss, cross_loss, reg_loss, output_dir,
                   use_tpu):
    """Training script for resnet model.

  Args:
   training_method: string indicating pruning method used to compress model.
   loss: tensor float32 of the cross entropy + regularization losses.
   cross_loss: tensor, only cross entropy loss, passed for logging.
   reg_loss: tensor, only regularization loss, passed for logging.
   output_dir: string tensor indicating the directory to save summaries.
   use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.
  """

    global_step = tf.train.get_global_step()

    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    learning_rate = lr_schedule(current_epoch)
    if FLAGS.use_adam:
        # With Adam we use a flat, batch-size-scaled learning rate instead of
        # the stepwise decay schedule.
        learning_rate = FLAGS.base_learning_rate * (FLAGS.train_batch_size /
                                                    256.0)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    else:
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=FLAGS.momentum,
                                               use_nesterov=True)

    if use_tpu:
        # use CrossShardOptimizer when using TPU.
        optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    if training_method == 'set':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseSETOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            stateless_seed_offset=FLAGS.seed)
    elif training_method == 'static':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseStaticOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            stateless_seed_offset=FLAGS.seed)
    elif training_method == 'momentum':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseMomentumOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            momentum=FLAGS.s_momentum,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            grow_init=FLAGS.grow_init,
            stateless_seed_offset=FLAGS.seed,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            use_tpu=use_tpu)
    elif training_method == 'rigl':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseRigLOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            stateless_seed_offset=FLAGS.seed,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            initial_acc_scale=FLAGS.rigl_acc_scale,
            use_tpu=use_tpu)
    elif training_method == 'snip':
        optimizer = sparse_optimizers.SparseSnipOptimizer(
            optimizer,
            mask_init_method=FLAGS.mask_init_method,
            custom_sparsity_map=CUSTOM_SPARSITY_MAP,
            default_sparsity=FLAGS.end_sparsity,
            use_tpu=use_tpu)
    elif training_method in ('scratch', 'baseline'):
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' % training_method)
    # UPDATE_OPS needs to be added as a dependency due to batch norm
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops), tf.name_scope('train'):
        train_op = optimizer.minimize(loss, global_step)

    metrics = {
        'global_step': tf.train.get_or_create_global_step(),
        'loss': loss,
        'cross_loss': cross_loss,
        'reg_loss': reg_loss,
        'learning_rate': learning_rate,
        'current_epoch': current_epoch,
    }

    # Logging drop_fraction if dynamic sparse training.
    if training_method in ('set', 'momentum', 'rigl', 'static'):
        metrics['drop_fraction'] = optimizer.drop_fraction

    # Log some statistics from a single weight/mask pair.
    # This is useful for debugging.
    test_var = pruning.get_weights()[0]
    test_var_mask = pruning.get_masks()[0]
    metrics.update({
        'fw_nz_weight': tf.count_nonzero(test_var),
        'fw_nz_mask': tf.count_nonzero(test_var_mask),
        'fw_l1_weight': tf.reduce_sum(tf.abs(test_var))
    })

    masks = pruning.get_masks()
    global_sparsity = sparse_utils.calculate_sparsity(masks)
    metrics['global_sparsity'] = global_sparsity
    metrics.update(
        utils.mask_summaries(masks[:4] + masks[-1:],
                             with_img=FLAGS.log_mask_imgs_each_iteration))

    host_call = (functools.partial(utils.host_call_fn,
                                   output_dir), utils.format_tensors(metrics))

    return host_call, train_op
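A hedged sketch of how the returned pair is typically consumed inside a TPUEstimator model_fn. The model_fn, the build_losses helper, and the way output_dir is obtained are illustrative placeholders, not part of the original script; only the train_function call and the TPUEstimatorSpec fields mirror the code above.

def model_fn(features, labels, mode, params):
    # build_losses is a hypothetical helper returning the three loss tensors.
    loss, cross_loss, reg_loss = build_losses(features, labels)
    host_call, train_op = train_function(
        training_method=FLAGS.training_method, loss=loss,
        cross_loss=cross_loss, reg_loss=reg_loss,
        output_dir=params['output_dir'], use_tpu=True)
    return contrib_tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, host_call=host_call)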
Example #8
  def get_weights(self):
    return pruning.get_weights()
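For context, a hedged sketch of what the wrapper above exposes, assuming the model_pruning-style imports (tf, layers, pruning) used in the earlier examples: pruning.get_weights() returns the weight variables of every masked layer built so far, so after creating one masked layer the list holds a single variable.

x = tf.ones((1, 8))
y = layers.masked_fully_connected(x, 4, activation_fn=None)
for w in pruning.get_weights():
  print(w.name, w.shape)  # e.g. a single (8, 4) weight for the layer above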