예제 #1
0
 def scaffold_fn():
   """Scaffold factory handed to the estimator for variable initialization.

   Calls `utils.initialize_parameters_from_ckpt` with `FLAGS.load_mask_dir`
   and `MASK_SUFFIX`, then, only when `FLAGS.initial_value_checkpoint` is
   non-empty, calls it again with that checkpoint and `PARAM_SUFFIXES`.
   Both calls receive `FLAGS.output_dir` as their second argument. (The
   suffix argument presumably selects which variables get initialized —
   confirm against utils.)

   Returns:
     A `tf.train.Scaffold` built with default arguments.
   """
   out_dir = FLAGS.output_dir
   utils.initialize_parameters_from_ckpt(
       FLAGS.load_mask_dir, out_dir, MASK_SUFFIX)
   initial_ckpt = FLAGS.initial_value_checkpoint
   if initial_ckpt:
     # Optional second pass: parameter values from a separate checkpoint.
     utils.initialize_parameters_from_ckpt(
         initial_ckpt, out_dir, PARAM_SUFFIXES)
   return tf.train.Scaffold()
예제 #2
0
 def scaffold_fn():
   """Scaffold factory handed to the estimator for variable initialization.

   When `FLAGS.initial_value_checkpoint` is non-empty, first calls
   `utils.initialize_parameters_from_ckpt` with it, `FLAGS.output_dir` and
   `PARAM_SUFFIXES`. Then builds a mask-initialization object with
   `sparse_utils.get_mask_init_fn` from the masks returned by
   `pruning.get_masks()`, `FLAGS.mask_init_method`, `FLAGS.end_sparsity` and
   `CUSTOM_SPARSITY_MAP`. The returned Scaffold runs that object with
   `session.run` in its `init_fn` (presumably an assign op — confirm in
   sparse_utils).

   Returns:
     A `tf.train.Scaffold` whose `init_fn` runs the mask initializer.
   """
   initial_ckpt = FLAGS.initial_value_checkpoint
   if initial_ckpt:
     utils.initialize_parameters_from_ckpt(
         initial_ckpt, FLAGS.output_dir, PARAM_SUFFIXES)

   mask_init_op = sparse_utils.get_mask_init_fn(
       pruning.get_masks(),
       FLAGS.mask_init_method,
       FLAGS.end_sparsity,
       CUSTOM_SPARSITY_MAP)

   def init_fn(scaffold, session):
     """Runs the mask initializer; the scaffold argument is ignored."""
     del scaffold  # Part of the init_fn signature, not needed here.
     session.run(mask_init_op)

   return tf.train.Scaffold(init_fn=init_fn)
예제 #3
0
def wide_resnet_w_pruning(features, labels, mode, params):
    """The model_fn for a wide ResNet trained with pruning / sparse training.

    Args:
      features: A float32 batch of images, or a dict holding that batch under
        the 'feature' key.
      labels: An int32 batch of labels.
      mode: A `tf.estimator.ModeKeys` value; PREDICT, TRAIN and EVAL are
        supported.
      params: Dictionary of parameters passed to the model. Must contain
        'train_dir' and 'training_method'; these values are used for the
        model, the training op, and checkpoint/mask initialization alike.

    Returns:
      An EstimatorSpec for the model.

    Raises:
      ValueError: if mode is not PREDICT, TRAIN or EVAL, or if the training
        method requires a sparsity mask but neither FLAGS.load_mask_dir nor
        FLAGS.mask_init_method can provide one.
    """
    if isinstance(features, dict):
        features = features['feature']

    train_dir = params['train_dir']
    training_method = params['training_method']

    global_step, accuracy, top_5_accuracy, logits = build_model(
        mode=mode,
        images=features,
        labels=labels,
        training_method=training_method,
        num_classes=FLAGS.num_classes,
        depth=FLAGS.resnet_depth,
        width=FLAGS.resnet_width)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    with tf.name_scope('computing_cross_entropy_loss'):
        # tf.losses.* adds the loss to the LOSSES collection, which is what
        # get_total_loss sums below; the local is only used for the summary.
        entropy_loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                              logits=logits)
        tf.summary.scalar('cross_entropy_loss', entropy_loss)

    with tf.name_scope('computing_total_loss'):
        total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

    if mode == tf.estimator.ModeKeys.TRAIN:
        hooks, eval_metrics, train_op = train_fn(training_method, global_step,
                                                 total_loss, train_dir,
                                                 accuracy, top_5_accuracy)
    elif mode == tf.estimator.ModeKeys.EVAL:
        hooks = None
        train_op = None
        with tf.name_scope('summaries'):
            eval_metrics = create_eval_metrics(labels, logits)
    else:
        raise ValueError('mode not recognized as training or eval.')

    # If given, load initial parameter values. Use the train_dir from params
    # (the one the model is trained with), not FLAGS.train_dir, so both agree.
    if FLAGS.initial_value_checkpoint:
        tf.logging.info('Loading inital values from: %s',
                        FLAGS.initial_value_checkpoint)
        utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint,
                                              train_dir, PARAM_SUFFIXES)

    scaffold = _build_mask_scaffold(training_method, train_dir)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      training_hooks=hooks,
                                      loss=total_loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metrics,
                                      scaffold=scaffold)


def _build_mask_scaffold(training_method, train_dir):
    """Returns a Scaffold that loads or initializes masks, or None if dense.

    Args:
      training_method: The training method the model is built with.
      train_dir: Directory passed to utils.initialize_parameters_from_ckpt
        together with FLAGS.load_mask_dir.

    Returns:
      A `tf.train.Scaffold` when masks are loaded from FLAGS.load_mask_dir or
      initialized with FLAGS.mask_init_method; None for the methods that start
      dense ('snip', 'baseline', 'prune').

    Raises:
      ValueError: if `training_method` is not one of the dense methods but no
        mask source applies (e.g. 'scratch' without FLAGS.load_mask_dir).
    """
    dense_methods = ('snip', 'baseline', 'prune')

    # Load masks from a checkpoint directory.
    if FLAGS.load_mask_dir and training_method not in dense_methods:
        tf.logging.info('Loading masks from %s', FLAGS.load_mask_dir)
        utils.initialize_parameters_from_ckpt(FLAGS.load_mask_dir,
                                              train_dir, MASK_SUFFIX)
        return tf.train.Scaffold()

    # Initialize masks in-graph. 'scratch' is excluded here, so it can only
    # obtain a mask from FLAGS.load_mask_dir above.
    if (FLAGS.mask_init_method
            and training_method not in dense_methods + ('scratch',)):
        tf.logging.info('Initializing masks using method: %s',
                        FLAGS.mask_init_method)
        all_masks = pruning.get_masks()
        assigner = sparse_utils.get_mask_init_fn(all_masks,
                                                 FLAGS.mask_init_method,
                                                 FLAGS.end_sparsity, {})

        def init_fn(scaffold, session):
            """A callable for restoring variable from a checkpoint."""
            del scaffold  # Unused.
            session.run(assigner)

        return tf.train.Scaffold(init_fn=init_fn)

    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would then silently fall through to dense training.
    if training_method not in dense_methods:
        raise ValueError(
            'No mask available for training_method=%r: set FLAGS.load_mask_dir'
            ' (or FLAGS.mask_init_method for methods other than scratch).'
            % (training_method,))
    tf.logging.info('No mask is set, starting dense.')
    return None