Example #1
def train_function(pruning_method, loss, output_dir, use_tpu):
    """Training script for resnet model.

  Args:
   pruning_method: string indicating pruning method used to compress model.
   loss: tensor float32 of the cross entropy + regularization losses.
   output_dir: string tensor indicating the directory to save summaries.
   use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.
  """

    global_step = tf.train.get_global_step()

    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    learning_rate = lr_schedule(current_epoch)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=FLAGS.momentum,
                                           use_nesterov=True)

    if use_tpu:
        # use CrossShardOptimizer when using TPU.
        optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
    elif FLAGS.num_workers > 0:
        # Wrap the optimizer *before* building the train op so that gradient
        # aggregation across workers actually takes effect.
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=FLAGS.num_workers,
            total_num_replicas=FLAGS.num_workers)
        # is_chief=True; the returned hook must be passed to the training
        # session/estimator for synchronization to work (not shown in this
        # excerpt).
        sync_replicas_hook = optimizer.make_session_run_hook(True)

    # UPDATE_OPS needs to be added as a dependency due to batch norm
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops), tf.name_scope('train'):
        train_op = optimizer.minimize(loss, global_step)

    metrics = {
        'global_step': tf.train.get_or_create_global_step(),
        'loss': loss,
        'learning_rate': learning_rate,
        'current_epoch': current_epoch
    }

    if pruning_method == 'threshold':
        # construct the necessary hparams string from the FLAGS
        hparams_string = ('begin_pruning_step={0},'
                          'sparsity_function_begin_step={0},'
                          'end_pruning_step={1},'
                          'sparsity_function_end_step={1},'
                          'target_sparsity={2},'
                          'pruning_frequency={3},'
                          'threshold_decay=0,'
                          'use_tpu={4}'.format(
                              FLAGS.sparsity_begin_step,
                              FLAGS.sparsity_end_step,
                              FLAGS.end_sparsity,
                              FLAGS.pruning_frequency,
                              FLAGS.use_tpu,
                          ))
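
        # For illustration (these flag values are assumed, not defaults of this
        # script): with sparsity_begin_step=40000, sparsity_end_step=80000,
        # end_sparsity=0.9, pruning_frequency=2000, and use_tpu=True, the
        # resulting hparams_string would be:
        #   'begin_pruning_step=40000,sparsity_function_begin_step=40000,'
        #   'end_pruning_step=80000,sparsity_function_end_step=80000,'
        #   'target_sparsity=0.9,pruning_frequency=2000,threshold_decay=0,'
        #   'use_tpu=True'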

        # Parse pruning hyperparameters
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

        # The first layer has so few parameters that we don't need to prune it,
        # and pruning it at higher sparsity levels has very negative effects.
        # Both per-layer overrides share the same weight_sparsity_map hparam, so
        # collect them first and set the hparam once.
        weight_sparsity_map = []
        if FLAGS.prune_first_layer and FLAGS.first_layer_sparsity >= 0.:
            weight_sparsity_map.append(
                'resnet_model/initial_conv:%f' % FLAGS.first_layer_sparsity)
        if FLAGS.prune_last_layer and FLAGS.last_layer_sparsity >= 0.:
            weight_sparsity_map.append(
                'resnet_model/final_dense:%f' % FLAGS.last_layer_sparsity)
        if weight_sparsity_map:
            pruning_hparams.set_hparam('weight_sparsity_map', weight_sparsity_map)

        # Create a pruning object using the pruning hyperparameters
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # We override the train op to also update the mask.
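        # (conditional_mask_update_op only recomputes the masks when global_step
        # lies inside [begin_pruning_step, end_pruning_step] and lands on a
        # pruning_frequency boundary; otherwise it is effectively a no-op.)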
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()

        masks = pruning.get_masks()
        metrics.update(utils.mask_summaries(masks))
    elif pruning_method == 'scratch':
        masks = pruning.get_masks()
        # Make sure the masks have the sparsity we expect and that it does not
        # change during training.
        metrics.update(utils.mask_summaries(masks))
    elif pruning_method == 'variational_dropout':
        masks = utils.add_vd_pruning_summaries(
            threshold=FLAGS.log_alpha_threshold)
        metrics.update(masks)
    elif pruning_method == 'l0_regularization':
        summaries = utils.add_l0_summaries()
        metrics.update(summaries)
    elif pruning_method == 'baseline':
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' % pruning_method)

    host_call = (functools.partial(utils.host_call_fn,
                                   output_dir), utils.format_tensors(metrics))

    return host_call, train_op
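
# --- Hedged usage sketch (an assumption, not part of the original script) ---
# The (host_call, train_op) pair returned above is typically consumed inside a
# TPUEstimator model_fn. `build_model` and `params['output_dir']` are
# hypothetical stand-ins for the actual forward pass and configuration;
# `FLAGS`, `contrib_tpu`, and `train_function` are the names already used in
# the example above.
def model_fn(features, labels, mode, params):
    logits = build_model(features,
                         is_training=(mode == tf.estimator.ModeKeys.TRAIN))
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    host_call, train_op = train_function(
        pruning_method=FLAGS.pruning_method,
        loss=loss,
        output_dir=params['output_dir'],
        use_tpu=FLAGS.use_tpu)
    return contrib_tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, host_call=host_call)
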
  def testMultipleConvMaskAdded(self, pruning_method):
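    """Checks per-layer pruning variables and regularization losses per method."""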

    tf.reset_default_graph()
    g = tf.Graph()
    with g.as_default():
      number_of_layers = 5

      kernel_size = [3, 3]
      base_depth = 4
      depth_step = 7

      input_tensor = tf.ones((8, self.height, self.width, base_depth))

      top_layer = input_tensor

      for ix in range(number_of_layers):
        units = base_depth + (ix + 1) * depth_step
        top_layer = pruning_layers.sparse_conv2d(
            x=top_layer,
            units=units,
            kernel_size=kernel_size,
            is_training=False,
            sparsity_technique=pruning_method)

      if pruning_method == 'variational_dropout':
        theta_logsigma2 = tf.get_collection(
            vd.layers.THETA_LOGSIGMA2_COLLECTION)
        self.assertLen(theta_logsigma2, number_of_layers)

        utils.add_vd_pruning_summaries(theta_logsigma2, threshold=3.0)

        dkl_loss_1 = utils.variational_dropout_dkl_loss(
            reg_scalar=1,
            start_reg_ramp_up=0,
            end_reg_ramp_up=1000,
            warm_up=False,
            use_tpu=False)
        dkl_loss_1 = tf.reshape(dkl_loss_1, [1])

        dkl_loss_2 = utils.variational_dropout_dkl_loss(
            reg_scalar=5,
            start_reg_ramp_up=0,
            end_reg_ramp_up=1000,
            warm_up=False,
            use_tpu=False)
        dkl_loss_2 = tf.reshape(dkl_loss_2, [1])

        for ix in range(number_of_layers):
          self.assertListEqual(theta_logsigma2[ix][0].get_shape().as_list(), [
              kernel_size[0], kernel_size[1], base_depth + ix * depth_step,
              base_depth + (ix + 1) * depth_step
          ])

        init_op = tf.global_variables_initializer()

        with self.test_session() as sess:
          sess.run(init_op)
          loss_1, loss_2 = sess.run([dkl_loss_1, dkl_loss_2])
          self.assertGreater(loss_2, loss_1)
      elif pruning_method == 'l0_regularization':
        theta_logalpha = tf.get_collection(
            l0.layers.THETA_LOGALPHA_COLLECTION)
        self.assertLen(theta_logalpha, number_of_layers)

        utils.add_l0_summaries(theta_logalpha)

        l0_norm_loss_1 = utils.l0_regularization_loss(
            reg_scalar=1,
            start_reg_ramp_up=0,
            end_reg_ramp_up=1000,
            warm_up=False,
            use_tpu=False)
        l0_norm_loss_1 = tf.reshape(l0_norm_loss_1, [1])

        l0_norm_loss_2 = utils.l0_regularization_loss(
            reg_scalar=5,
            start_reg_ramp_up=0,
            end_reg_ramp_up=1000,
            warm_up=False,
            use_tpu=False)
        l0_norm_loss_2 = tf.reshape(l0_norm_loss_2, [1])

        for ix in range(number_of_layers):
          self.assertListEqual(theta_logalpha[ix][0].get_shape().as_list(), [
              kernel_size[0], kernel_size[1], base_depth + ix * depth_step,
              base_depth + (ix + 1) * depth_step
          ])

        init_op = tf.global_variables_initializer()

        with self.test_session() as sess:
          sess.run(init_op)
          loss_1, loss_2 = sess.run([l0_norm_loss_1, l0_norm_loss_2])
          self.assertGreater(loss_2, loss_1)
      else:
        mask = tf.get_collection(core.MASK_COLLECTION)
        for ix in range(number_of_layers):
          self.assertListEqual(mask[ix].get_shape().as_list(), [
              kernel_size[0], kernel_size[1], base_depth + ix * depth_step,
              base_depth + (ix + 1) * depth_step
          ])
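
  # --- Hedged note, not part of the original test ---
  # The pruning_method argument suggests this test is driven by a parameterized
  # test runner; one plausible (assumed) wiring with absl would be:
  #
  #   @parameterized.parameters('threshold', 'variational_dropout',
  #                             'l0_regularization')
  #   def testMultipleConvMaskAdded(self, pruning_method):
  #     ...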