Example #1
def resnet_model_fn_w_pruning(features, labels, mode, params):
    """The model_fn for ResNet-50 with pruning.

  Args:
    features: A float32 batch of images.
    labels: A int32 batch of labels.
    mode: Specifies whether training or evaluation.
    params: Dictionary of parameters passed to the model.

  Returns:
    A TPUEstimatorSpec for the model
  """

    width = 1. if FLAGS.width <= 0 else FLAGS.width
    if isinstance(features, dict):
        features = features['feature']

    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])

    if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    pruning_method = params['pruning_method']
    use_tpu = params['use_tpu']
    log_alpha_threshold = params['log_alpha_threshold']

    def build_network():
        """Construct the network in the graph."""
        model_pruning_method = pruning_method
        if pruning_method == 'scratch':
            model_pruning_method = 'threshold'

        network = resnet_model.resnet_v1_(
            resnet_depth=FLAGS.resnet_depth,
            num_classes=FLAGS.num_label_classes,
            # We need to construct the model with the pruning masks, but they
            # won't be updated if we're doing scratch training.
            pruning_method=model_pruning_method,
            init_method=FLAGS.init_method,
            width=width,
            prune_first_layer=FLAGS.prune_first_layer,
            prune_last_layer=FLAGS.prune_last_layer,
            data_format=FLAGS.data_format,
            end_sparsity=FLAGS.end_sparsity,
            clip_log_alpha=FLAGS.clip_log_alpha,
            log_alpha_threshold=log_alpha_threshold,
            weight_decay=FLAGS.weight_decay)
        return network(inputs=features,
                       is_training=(mode == tf.estimator.ModeKeys.TRAIN))

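    # Build the network, optionally under a bfloat16 scope (a TPU-friendly
    # half-precision format); bfloat16 logits are cast back to float32 so
    # the loss is computed at full precision.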
    if FLAGS.precision == 'bfloat16':
        with contrib_tpu.bfloat16_scope():
            logits = build_network()
        logits = tf.cast(logits, tf.float32)
    elif FLAGS.precision == 'float32':
        logits = build_network()
    else:
        raise ValueError('Unknown precision: %s' % FLAGS.precision)

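    # For prediction/serving we only need the class ids and probabilities.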
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    output_dir = params['output_dir']

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)

    # Make sure we reuse the same label smoothing parameter if we're doing
    # scratch / lottery ticket experiments; the value is recovered from a
    # fixed position in the mask directory path.
    label_smoothing = FLAGS.label_smoothing
    if FLAGS.pruning_method == 'scratch':
        label_smoothing = float(FLAGS.load_mask_dir.split('/')[15])
    loss = tf.losses.softmax_cross_entropy(logits=logits,
                                           onehot_labels=one_hot_labels,
                                           label_smoothing=label_smoothing)
    # Add regularization loss term
    loss += tf.losses.get_regularization_loss()

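    # Sparsity-inducing regularizers: both penalties are ramped up between
    # sparsity_begin_step and sparsity_end_step.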
    if pruning_method == 'variational_dropout':
        reg_loss = utils.variational_dropout_dkl_loss(
            reg_scalar=FLAGS.reg_scalar,
            start_reg_ramp_up=FLAGS.sparsity_begin_step,
            end_reg_ramp_up=FLAGS.sparsity_end_step,
            warm_up=FLAGS.is_warm_up,
            use_tpu=use_tpu)
        loss += reg_loss
        tf.losses.add_loss(reg_loss, loss_collection=tf.GraphKeys.LOSSES)
    elif pruning_method == 'l0_regularization':
        reg_loss = utils.l0_regularization_loss(
            reg_scalar=FLAGS.reg_scalar,
            start_reg_ramp_up=FLAGS.sparsity_begin_step,
            end_reg_ramp_up=FLAGS.sparsity_end_step,
            warm_up=FLAGS.is_warm_up,
            use_tpu=use_tpu)
        loss += reg_loss
        tf.losses.add_loss(reg_loss, loss_collection=tf.GraphKeys.LOSSES)

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        host_call, train_op = train_function(pruning_method, loss, output_dir,
                                             use_tpu)
    else:
        train_op = None

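    # On TPU, eval metrics are computed on the host from a
    # (metric_fn, tensors) pair rather than a dict of metric ops.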
    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Calculate eval metrics."""
            tf.logging.info('In metric function')
            eval_metrics = {}
            predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            eval_metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
            eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
                labels=labels, predictions=predictions)

            return eval_metrics

        def vd_metric_fn(labels, logits, global_sparsity):
            """Calculate eval metrics, including global sparsity."""
            eval_metrics = metric_fn(labels, logits)
            eval_metrics['global_sparsity'] = tf.metrics.mean(global_sparsity)
            return eval_metrics

        tensors = [labels, logits]
        metric_function = metric_fn

        if FLAGS.pruning_method == 'variational_dropout':
            batch_size = labels.shape[0]
            ones = tf.ones([batch_size, 1])
            mask_metrics = utils.add_vd_pruning_summaries(
                threshold=FLAGS.log_alpha_threshold)
            tensors.append(mask_metrics['global_sparsity'] * ones)
            metric_function = vd_metric_fn

        eval_metrics = (metric_function, tensors)

    # Define a custom scaffold function to enable initializing the mask from
    # an already-trained checkpoint.
    def initialize_mask_from_ckpt(ckpt_path):
        """Load mask from an existing checkpoint."""
        model_dir = FLAGS.output_dir
        already_has_ckpt = model_dir and tf.train.latest_checkpoint(
            model_dir) is not None
        if already_has_ckpt:
            tf.logging.info(
                'Training already started on this model, not loading masks '
                'from previously trained model')
            return

        reader = tf.train.NewCheckpointReader(ckpt_path)
        mask_names = reader.get_variable_to_shape_map().keys()
        mask_names = [x for x in mask_names if x.endswith('mask')]

        variable_map = {}
        for var in tf.global_variables():
            var_name = var.name.split(':')[0]
            if var_name in mask_names:
                tf.logging.info('Loading mask variable from checkpoint: %s',
                                var_name)
                variable_map[var_name] = var
            elif 'mask' in var_name:
                tf.logging.info(
                    'Cannot find mask variable in checkpoint, skipping: %s',
                    var_name)
        tf.train.init_from_checkpoint(ckpt_path, variable_map)

    def initialize_parameters_from_ckpt(ckpt_path):
        """Load parameters from an existing checkpoint."""
        model_dir = FLAGS.output_dir
        already_has_ckpt = model_dir and tf.train.latest_checkpoint(
            model_dir) is not None
        if already_has_ckpt:
            tf.logging.info(
                'Training already started on this model, not loading '
                'parameters from previously trained model')
            return

        reader = tf.train.NewCheckpointReader(ckpt_path)
        param_names = reader.get_variable_to_shape_map().keys()
        param_names = [x for x in param_names if not x.endswith('mask')]

        variable_map = {}
        for var in tf.global_variables():
            var_name = var.name.split(':')[0]
            if var_name in param_names:
                tf.logging.info(
                    'Loading parameter variable from checkpoint: %s', var_name)
                variable_map[var_name] = var
            elif 'mask' not in var_name:
                tf.logging.info(
                    'Cannot find parameter variable in checkpoint, skipping: %s',
                    var_name)
        tf.train.init_from_checkpoint(ckpt_path, variable_map)

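    # 'scratch' (lottery ticket style) training reuses the masks from a
    # previous pruning run, so a mask directory is required.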
    if FLAGS.pruning_method == 'scratch':
        if FLAGS.load_mask_dir:

            def scaffold_fn():
                initialize_mask_from_ckpt(FLAGS.load_mask_dir)
                if FLAGS.initial_value_checkpoint:
                    initialize_parameters_from_ckpt(
                        FLAGS.initial_value_checkpoint)
                return tf.train.Scaffold()
        else:
            raise ValueError(
                'Must supply a mask directory to use scratch method')
    else:
        scaffold_fn = None

    return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                        loss=loss,
                                        train_op=train_op,
                                        host_call=host_call,
                                        eval_metrics=eval_metrics,
                                        scaffold_fn=scaffold_fn)

Example #2
  def testMultipleConvMaskAdded(self, pruning_method):

    tf.reset_default_graph()
    g = tf.Graph()
    with g.as_default():
      number_of_layers = 5

      kernel_size = [3, 3]
      base_depth = 4
      depth_step = 7

      input_tensor = tf.ones((8, self.height, self.width, base_depth))

      top_layer = input_tensor

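      # Stack several sparse conv layers with linearly growing depth so each
      # layer's kernel shape is distinguishable in the assertions below.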
      for ix in range(number_of_layers):
        units = base_depth + (ix + 1) * depth_step
        top_layer = pruning_layers.sparse_conv2d(
            x=top_layer,
            units=units,
            kernel_size=kernel_size,
            is_training=False,
            sparsity_technique=pruning_method)

      if pruning_method == 'variational_dropout':
        theta_logsigma2 = tf.get_collection(
            vd.layers.THETA_LOGSIGMA2_COLLECTION)
        self.assertLen(theta_logsigma2, number_of_layers)

        utils.add_vd_pruning_summaries(theta_logsigma2, threshold=3.0)

        dkl_loss_1 = utils.variational_dropout_dkl_loss(
            reg_scalar=1,
            start_reg_ramp_up=0,
            end_reg_ramp_up=1000,
            warm_up=False,
            use_tpu=False)
        dkl_loss_1 = tf.reshape(dkl_loss_1, [1])

        dkl_loss_2 = utils.variational_dropout_dkl_loss(
            reg_scalar=5,
            start_reg_ramp_up=0,
            end_reg_ramp_up=1000,
            warm_up=False,
            use_tpu=False)
        dkl_loss_2 = tf.reshape(dkl_loss_2, [1])

        for ix in range(number_of_layers):
          self.assertListEqual(theta_logsigma2[ix][0].get_shape().as_list(), [
              kernel_size[0], kernel_size[1], base_depth + ix * depth_step,
              base_depth + (ix + 1) * depth_step
          ])

        init_op = tf.global_variables_initializer()

        with self.test_session() as sess:
          sess.run(init_op)
          loss_1, loss_2 = sess.run([dkl_loss_1, dkl_loss_2])
          self.assertGreater(loss_2, loss_1)
      elif pruning_method == 'l0_regularization':
        theta_logalpha = tf.get_collection(
            l0.layers.THETA_LOGALPHA_COLLECTION)
        self.assertLen(theta_logalpha, number_of_layers)

        utils.add_l0_summaries(theta_logalpha)

        l0_norm_loss_1 = utils.l0_regularization_loss(
            reg_scalar=1,
            start_reg_ramp_up=0,
            end_reg_ramp_up=1000,
            warm_up=False,
            use_tpu=False)
        l0_norm_loss_1 = tf.reshape(l0_norm_loss_1, [1])

        l0_norm_loss_2 = utils.l0_regularization_loss(
            reg_scalar=5,
            start_reg_ramp_up=0,
            end_reg_ramp_up=1000,
            warm_up=False,
            use_tpu=False)
        l0_norm_loss_2 = tf.reshape(l0_norm_loss_2, [1])

        for ix in range(number_of_layers):
          self.assertListEqual(theta_logalpha[ix][0].get_shape().as_list(), [
              kernel_size[0], kernel_size[1], base_depth + ix * depth_step,
              base_depth + (ix + 1) * depth_step
          ])

        init_op = tf.global_variables_initializer()

        with self.test_session() as sess:
          sess.run(init_op)
          loss_1, loss_2 = sess.run([l0_norm_loss_1, l0_norm_loss_2])
          self.assertGreater(loss_2, loss_1)
      else:
        mask = tf.get_collection(core.MASK_COLLECTION)
        for ix in range(number_of_layers):
          self.assertListEqual(mask[ix].get_shape().as_list(), [
              kernel_size[0], kernel_size[1], base_depth + ix * depth_step,
              base_depth + (ix + 1) * depth_step
          ])
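
The method takes pruning_method as an argument, so in the full test class it
is presumably driven by absl.testing.parameterized. A hypothetical skeleton
(the class name and image dimensions are assumptions, not the repository's
actual test file):

from absl.testing import parameterized
import tensorflow as tf


class SparseConv2dTest(parameterized.TestCase, tf.test.TestCase):
  height = 32
  width = 32

  @parameterized.parameters('threshold', 'variational_dropout',
                            'l0_regularization')
  def testMultipleConvMaskAdded(self, pruning_method):
    ...  # body as in the example above


if __name__ == '__main__':
  tf.test.main()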