Example #1
    def metric_fn(labels, logits, cross_loss, reg_loss):
      """Calculate eval metrics."""
      logging.info('In metric function')
      eval_metrics = {}
      predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      eval_metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
      eval_metrics['cross_loss'] = tf.metrics.mean(cross_loss)
      eval_metrics['reg_loss'] = tf.metrics.mean(reg_loss)
      eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
          labels=labels, predictions=predictions)

      # If evaluating once, let's also calculate sparsities.
      if FLAGS.mode == 'eval_once':
        sparsity_summaries = utils.mask_summaries(pruning.get_masks())
        # We call tf.metrics.mean on each scalar to create (tensor, update_op)
        # pairs, as required by the eval metrics dictionary.
        sparsity_summaries = {k: tf.metrics.mean(v)
                              for k, v in sparsity_summaries.items()}
        eval_metrics.update(sparsity_summaries)
      return eval_metrics
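
This metric_fn is typically defined as a closure inside a TPUEstimator model_fn. Below is a minimal sketch of how it could be wired into the eval spec; the enclosing model_fn, its mode handling, and the loss/logits/labels tensors are assumptions, not part of the example above.

    if mode == tf.estimator.ModeKeys.EVAL:
      return contrib_tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          # TPUEstimator expects a (metric_fn, tensor_list) pair; the tensors
          # are passed to metric_fn as positional arguments on the host.
          eval_metrics=(metric_fn, [labels, logits, cross_loss, reg_loss]))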
Example #2
def train_function(training_method, loss, cross_loss, reg_loss, output_dir,
                   use_tpu):
    """Training script for resnet model.

  Args:
   training_method: string indicating pruning method used to compress model.
   loss: tensor float32 of the cross entropy + regularization losses.
   cross_loss: tensor, only cross entropy loss, passed for logging.
   reg_loss: tensor, only regularization loss, passed for logging.
   output_dir: string tensor indicating the directory to save summaries.
   use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.
  """

    global_step = tf.train.get_global_step()

    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    learning_rate = lr_schedule(current_epoch)
    if FLAGS.use_adam:
        # With Adam we don't use the step-decay schedule; instead we scale the
        # base learning rate linearly with the batch size (reference batch 256).
        learning_rate = FLAGS.base_learning_rate * (FLAGS.train_batch_size /
                                                    256.0)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    else:
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=FLAGS.momentum,
                                               use_nesterov=True)

    if use_tpu:
        # Use CrossShardOptimizer to aggregate gradients across TPU shards.
        optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    if training_method == 'set':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseSETOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            stateless_seed_offset=FLAGS.seed)
    elif training_method == 'static':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseStaticOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            stateless_seed_offset=FLAGS.seed)
    elif training_method == 'momentum':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseMomentumOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            momentum=FLAGS.s_momentum,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            grow_init=FLAGS.grow_init,
            stateless_seed_offset=FLAGS.seed,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            use_tpu=use_tpu)
    elif training_method == 'rigl':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseRigLOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            stateless_seed_offset=FLAGS.seed,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            initial_acc_scale=FLAGS.rigl_acc_scale,
            use_tpu=use_tpu)
    elif training_method == 'snip':
        optimizer = sparse_optimizers.SparseSnipOptimizer(
            optimizer,
            mask_init_method=FLAGS.mask_init_method,
            custom_sparsity_map=CUSTOM_SPARSITY_MAP,
            default_sparsity=FLAGS.end_sparsity,
            use_tpu=use_tpu)
    elif training_method in ('scratch', 'baseline'):
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' % training_method)
    # UPDATE_OPS needs to be added as a dependency of the train op so that
    # batch norm moving statistics are updated.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops), tf.name_scope('train'):
        train_op = optimizer.minimize(loss, global_step)

    metrics = {
        'global_step': tf.train.get_or_create_global_step(),
        'loss': loss,
        'cross_loss': cross_loss,
        'reg_loss': reg_loss,
        'learning_rate': learning_rate,
        'current_epoch': current_epoch,
    }

    # Log drop_fraction when using dynamic sparse training.
    if training_method in ('set', 'momentum', 'rigl', 'static'):
        metrics['drop_fraction'] = optimizer.drop_fraction

    # Log some statistics from a single weight-mask pair; useful for debugging.
    test_var = pruning.get_weights()[0]
    test_var_mask = pruning.get_masks()[0]
    metrics.update({
        'fw_nz_weight': tf.count_nonzero(test_var),
        'fw_nz_mask': tf.count_nonzero(test_var_mask),
        'fw_l1_weight': tf.reduce_sum(tf.abs(test_var))
    })

    masks = pruning.get_masks()
    global_sparsity = sparse_utils.calculate_sparsity(masks)
    metrics['global_sparsity'] = global_sparsity
    metrics.update(
        utils.mask_summaries(masks[:4] + masks[-1:],
                             with_img=FLAGS.log_mask_imgs_each_iteration))

    host_call = (functools.partial(utils.host_call_fn,
                                   output_dir), utils.format_tensors(metrics))

    return host_call, train_op
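
A minimal sketch of how the returned (host_call, train_op) pair could be consumed inside a TPUEstimator model_fn; the wrapper below, its tensors, and the FLAGS.output_dir flag name are assumptions for illustration.

def model_fn(features, labels, mode, params):  # hypothetical wrapper
  # ... build the network, loss, cross_loss and reg_loss here ...
  host_call, train_op = train_function(FLAGS.training_method, loss, cross_loss,
                                       reg_loss, FLAGS.output_dir, FLAGS.use_tpu)
  return contrib_tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      # host_call is a (fn, tensors) pair; the fn runs on the host and writes
      # the tensors collected in `metrics` as summaries to output_dir.
      host_call=host_call)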
Example #3
def train_fn(training_method, global_step, total_loss, train_dir, accuracy,
             top_5_accuracy):
  """Training script for resnet model.

  Args:
   training_method: specifies the method used to sparsify networks.
   global_step: the current step of training/eval.
   total_loss: tensor float32 of the cross entropy + regularization losses.
   train_dir: string specifying the directory where summaries are saved.
   accuracy: tensor float32 batch classification accuracy.
   top_5_accuracy: tensor float32 batch classification accuracy (top_5 classes).

  Returns:
    hooks: summary tensors to be computed at each training step.
    eval_metrics: set to None during training.
    train_op: the optimization term.
  """
  # Learning rate roughly drops 5x every 30k steps (0.1, 0.02, 0.004, 0.0008).
  boundaries = [30000, 60000, 90000]
  if FLAGS.training_steps_multiplier != 1.0:
    multiplier = FLAGS.training_steps_multiplier
    boundaries = [int(x * multiplier) for x in boundaries]
    tf.logging.info(
        'Learning rate boundaries are updated with multiplier: %.2f', multiplier)

  learning_rate = tf.train.piecewise_constant(
      global_step,
      boundaries,
      values=[0.1 / (5.**i) for i in range(len(boundaries) + 1)],
      name='lr_schedule')

  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=FLAGS.momentum, use_nesterov=True)

  if training_method == 'set':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseSETOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'static':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseStaticOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'momentum':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseMomentumOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False)
  elif training_method == 'rigl':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseRigLOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False)
  elif training_method == 'snip':
    optimizer = sparse_optimizers.SparseSnipOptimizer(
        optimizer, mask_init_method=FLAGS.mask_init_method,
        default_sparsity=FLAGS.end_sparsity, use_tpu=False)
  elif training_method in ('scratch', 'baseline', 'prune'):
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % training_method)
  # Create the training op, running UPDATE_OPS first so that batch norm
  # statistics are updated.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss, global_step)

  if training_method == 'prune':
    # Construct the pruning hparams string from the FLAGS.
    hparams_string = ('begin_pruning_step={0},'
                      'sparsity_function_begin_step={0},'
                      'end_pruning_step={1},'
                      'sparsity_function_end_step={1},'
                      'target_sparsity={2},'
                      'pruning_frequency={3},'
                      'threshold_decay=0,'
                      'use_tpu={4}'.format(
                          FLAGS.sparsity_begin_step,
                          FLAGS.sparsity_end_step,
                          FLAGS.end_sparsity,
                          FLAGS.pruning_frequency,
                          False,
                      ))
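    # For example, with hypothetical flag values sparsity_begin_step=40000,
    # sparsity_end_step=100000, end_sparsity=0.9 and pruning_frequency=2000,
    # hparams_string evaluates to:
    #   'begin_pruning_step=40000,sparsity_function_begin_step=40000,'
    #   'end_pruning_step=100000,sparsity_function_end_step=100000,'
    #   'target_sparsity=0.9,pruning_frequency=2000,threshold_decay=0,'
    #   'use_tpu=False'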
    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    tf.logging.info('starting mask update op')

    # We override the train op to also update the mask.
    with tf.control_dependencies([train_op]):
      train_op = pruning_obj.conditional_mask_update_op()

  masks = pruning.get_masks()
  mask_metrics = utils.mask_summaries(masks)
  for name, tensor in mask_metrics.items():
    tf.summary.scalar(name, tensor)

  tf.summary.scalar('learning_rate', learning_rate)
  tf.summary.scalar('accuracy', accuracy)
  tf.summary.scalar('total_loss', total_loss)
  tf.summary.scalar('top_5_accuracy', top_5_accuracy)
  # Log drop_fraction when using dynamic sparse training.
  if training_method in ('set', 'momentum', 'rigl', 'static'):
    tf.summary.scalar('drop_fraction', optimizer.drop_fraction)

  summary_op = tf.summary.merge_all()
  summary_hook = tf.train.SummarySaverHook(
      save_secs=300, output_dir=train_dir, summary_op=summary_op)
  hooks = [summary_hook]
  eval_metrics = None

  return hooks, eval_metrics, train_op
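
A minimal sketch of how the returned (hooks, eval_metrics, train_op) triple could feed a standard Estimator model_fn in training mode; the wrapper below, its tensors, and the FLAGS.train_dir flag name are assumptions for illustration.

def model_fn(features, labels, mode, params):  # hypothetical wrapper
  # ... build the network, total_loss, accuracy and top_5_accuracy here ...
  global_step = tf.train.get_or_create_global_step()
  hooks, eval_metrics, train_op = train_fn(
      FLAGS.training_method, global_step, total_loss, FLAGS.train_dir,
      accuracy, top_5_accuracy)
  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=total_loss,
      train_op=train_op,
      # The SummarySaverHook built in train_fn periodically writes the merged
      # summaries to train_dir; eval_metrics is None in training mode.
      training_hooks=hooks)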