Example #1
 def testDefaultBatch(self):
   # Model returns logits for 42 samples and 10 classes.
   n_sample, chunk_size, n_out = 42, 10, 10
   model = self._create_mock_model(n_out=n_out)
   x_all = tf.ones((n_sample, 2))
   y_all = tf.ones((n_sample,), dtype=tf.int32)
   d = tf.data.Dataset.from_tensor_slices((x_all, y_all))
   loss, acc, total_samples = train_utils.cross_entropy_loss(
       model, d.batch(chunk_size))
   self.assertEqual(total_samples, n_sample)
   self.assertIsNone(acc)
   all_logits = []
   for i in range(0, n_sample, chunk_size):
     c_size = min(n_sample, i + chunk_size) - i
     all_logits.append(self.get_logits(c_size, n_out))
   logits = tf.concat(all_logits, 0)
   cce = tf.keras.losses.SparseCategoricalCrossentropy()
   true_loss = cce(y_all, logits)
   self.assertAllClose(loss, true_loss, atol=1e-4)
   self.assertEqual(model.call_count, 5)
   model.get_layer_keys.assert_not_called()
   train_utils.cross_entropy_loss(model, d.batch(chunk_size),
                                  aggregate_values=True)
   model.get_layer_keys.assert_called_once()
   model.conv_1.reset_saved_values.assert_called_once()
   model.conv_2.reset_saved_values.assert_called_once()
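
For reference, a minimal sketch (not the library's implementation) of the aggregation behavior this test pins down: per-chunk mean losses are combined into a size-weighted average, the sample count is accumulated, and accuracy stays `None` unless it is requested.

import tensorflow as tf

def batched_cross_entropy(model, batched_dataset):
  # Size-weighted aggregation so a partial final chunk (e.g. 42 samples
  # in chunks of 10) contributes proportionally to the overall loss.
  cce = tf.keras.losses.SparseCategoricalCrossentropy()
  total_loss, total_samples = 0.0, 0
  for x, y in batched_dataset:
    n = int(tf.shape(x)[0])
    total_loss += float(cce(y, model(x))) * n
    total_samples += n
  # Accuracy is None unless explicitly requested, mirroring the test.
  return total_loss / total_samples, None, total_samples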
Example #2
def get_pruning_measurements(model, subset_val, layers2prune):
    """Returns 6 different pruning scores and some other measurements.

  Args:
    model: tf.keras.Model, model to be pruned.
    subset_val:  tf.data.Dataset, to calculate data-dependent scoring functions.
    layers2prune: list<str>, of layers which will be scored.
  Returns:
    dict, scores with 6 keys `mrs`, `abs_mrs`, `rs`, `abs_rs`, `norm`, `rand`.
      and score dictionary values. Each Score dictionary has scorings for each
      layer provided in `layers2prune`.
    dict, mean unit activations for each layer in layers2prune.
    dict, l2norms, average square activations for each unit.
      see `deadunits.layers.TaylorScorer` for further details.
  """
    # Run once to get mrs, rs and the mean values calculated.
    train_utils.cross_entropy_loss(model,
                                   subset_val,
                                   training=False,
                                   compute_mean_replacement_saliency=True,
                                   compute_removal_saliency=True,
                                   is_abs=True,
                                   aggregate_values=True,
                                   run_gradient=True)
    scores = collections.defaultdict(dict)
    mean_values = {}
    l2_norms = {}
    for l_name in layers2prune:
        l_ts = getattr(model, l_name + '_ts')
        scores['abs_mrs'][l_name] = l_ts.get_saved_values('mrs')
        scores['abs_rs'][l_name] = l_ts.get_saved_values('rs')
        mean_values[l_name] = l_ts.get_saved_values('mean')
        l2_norms[l_name] = l_ts.get_saved_values('l2norm')

    # Run again to get the scores without abs.
    train_utils.cross_entropy_loss(model,
                                   subset_val,
                                   training=False,
                                   compute_mean_replacement_saliency=True,
                                   compute_removal_saliency=True,
                                   is_abs=False,
                                   aggregate_values=True,
                                   run_gradient=True)

    for l_name in layers2prune:
        l_ts = getattr(model, l_name + '_ts')
        l = getattr(model, l_name)
        scores['mrs'][l_name] = l_ts.get_saved_values('mrs')
        scores['rs'][l_name] = l_ts.get_saved_values('rs')
        # The reason we calculate these here is to reduce the amount of code.
        # `weights[0]` is the weight, whereas `weights[1]` is the bias.
        scores['norm'][l_name] = unitscorers.norm_score(
            l.get_layer().weights[0])
        scores['rand'][l_name] = unitscorers.random_score(
            l.get_layer().weights[0])
    return (scores, mean_values, l2_norms)
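
A hedged usage sketch; `model`, `subset_val` and the layer names stand in for a real model, validation subset and prunable layers.

scores, mean_values, l2_norms = get_pruning_measurements(
    model, subset_val, ['conv_1', 'conv_2'])
print(sorted(scores))  # ['abs_mrs', 'abs_rs', 'mrs', 'norm', 'rand', 'rs']
print(scores['norm']['conv_1'])  # Per-unit norm scores for `conv_1`.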
Example #3
 def testSingleBatch(self):
   # Model returns logits for 8 samples and 10 classes.
   n_sample, n_out = 8, 10
   model = self._create_mock_model(n_out=n_out)
   x = tf.ones((n_sample, 2))
   y = tf.ones((n_sample,), dtype=tf.int32)
   loss, acc, total_samples = train_utils.cross_entropy_loss(
       model, (x, y), calculate_accuracy=True)
   model2 = self._create_mock_model(n_out=n_out)
   d = tf.data.Dataset.from_tensor_slices((x, y))
   loss2, acc2, total_samples2 = train_utils.cross_entropy_loss(
       model2, d.batch(n_sample), calculate_accuracy=True)
   self.assertEqual(total_samples2, total_samples)
   self.assertAllClose(acc, acc2)
   self.assertAllClose(loss, loss2)
Example #4
 def testDefaultAccuracy(self):
   # Model returns logits for 8 samples and 10 classes.
   n_sample, n_out = 8, 10
   model = self._create_mock_model(n_out=n_out)
   x = tf.ones((n_sample, 2))
   y = tf.ones((n_sample,), dtype=tf.int32)
   _, acc, total_samples = train_utils.cross_entropy_loss(
       model, (x, y), calculate_accuracy=True)
   self.assertEqual(total_samples, n_sample)
   logits = self.get_logits(n_sample, n_out)
   predictions = tf.cast(tf.argmax(logits, 1), y.dtype)
   acc_obj = tf.keras.metrics.Accuracy()
   acc_obj.update_state(tf.squeeze(y), predictions)
   true_acc = acc_obj.result().numpy()
   self.assertAllClose(acc, true_acc)
Example #5
 def testDefaultSingleBatch(self):
   # Model returns logits for 8 samples and 10 classes.
   n_sample, n_out = 8, 10
   model = self._create_mock_model(n_out=n_out)
   x = tf.ones((n_sample, 2))
   y = tf.ones((n_sample,), dtype=tf.int32)
   loss, acc, total_samples = train_utils.cross_entropy_loss(model, (x, y))
   self.assertEqual(total_samples, n_sample)
   self.assertIsNone(acc)
   logits = self.get_logits(n_sample, n_out)
   cce = tf.keras.losses.SparseCategoricalCrossentropy()
   true_loss = cce(y, logits)
   self.assertAllClose(loss, true_loss)
   model.assert_called_once_with(
       x,
       training=False,
       compute_mean_replacement_saliency=False,
       compute_removal_saliency=False,
       is_abs=True,
       aggregate_values=False)
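
The final assertion pins down the default keyword arguments of `train_utils.cross_entropy_loss`; an explicit call with those defaults spelled out is therefore equivalent to the bare call above:

loss, acc, total_samples = train_utils.cross_entropy_loss(
    model, (x, y),
    training=False,
    compute_mean_replacement_saliency=False,
    compute_removal_saliency=False,
    is_abs=True,
    aggregate_values=False)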
Example #6
def prune_layer_and_eval(dataset_name='cifar10',
                         pruning_methods=(('rand', True), ),
                         model_dir=gin.REQUIRED,
                         l_name=gin.REQUIRED,
                         n_exps=gin.REQUIRED,
                         val_size=gin.REQUIRED,
                         max_pruning_fraction=gin.REQUIRED):
    """Loads and prunes a model with various sparsity targets.

  This function assumes that the `seed_dir` exists
  Args:
    dataset_name: str, either 'cifar10' or 'imagenet'.
    pruning_methods: iterator of tuples, (scoring, is_bp) where `is_bp`
      is a boolean indicating the usage of Mean Replacement Pruning. Scoring is
      a string from ['norm', 'mrs', 'rs', 'rand', 'abs_mrs', 'rs'].
        'norm': unitscorers.norm_score
        '{abs_}mrs': `compute_mean_replacement_saliency=True` for `TaylorScorer`
          layer. if {abs_} prefix exists absolute value used before aggregation.
          i.e. is_abs=True.
        '{abs_}rs': `compute_removal_saliency=True` for `TaylorScorer` layer. if
          {abs_} prefix exists absolute value used before aggregation. i.e.
          is_abs=True.
        'rand': unitscorers.random_score is_bp; bool, if True, mean value of the
          units are propagated to the next layer prior to the pruning. `bp` for
          bias_propagation.
    model_dir: str, Path to the checkpoint directory.
    l_name: str, a valid layer name from the model loaded.
    n_exps: int, number of pruning experiments to be made. This number is used
      generate pruning counts for different experiments.
    val_size: int, size for the first dataset, passed to the `get_datasets`.
    max_pruning_fraction: float, max sparsity for pruning. Multiplying this
      number with the total number of units, we would get the upper limit for
      the pruning_count.

  Raises:
    AssertionError: when no checkpoint is found.
    ValueError: when the scoring function key is not valid.
    OSError: when there is no checkpoint found.
  """
    logging.info('Looking for checkpoint at: %s', model_dir)
    latest_cpkt = tf.train.latest_checkpoint(model_dir)
    if not latest_cpkt:
        raise OSError('No checkpoint found in %s' % model_dir)
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    model = model_load.get_model(load_path=latest_cpkt,
                                 dataset_name=dataset_name)
    datasets = data.get_datasets(dataset_name=dataset_name, val_size=val_size)
    _, _, subset_val, subset_test, subset_val2 = datasets
    input_shapes = {l_name: getattr(model, l_name + '_ts').xshape}
    layers2prune = [l_name]
    measurements = pruner.get_pruning_measurements(model, subset_val,
                                                   layers2prune)
    (all_scores, mean_values, _) = measurements
    for scoring, is_bp in pruning_methods:
        if scoring not in pruner.ALL_SCORING_FUNCTIONS:
            raise ValueError('%s is not one of %s' %
                             (scoring, pruner.ALL_SCORING_FUNCTIONS))
        scores = all_scores[scoring]
        d_path = os.path.join(
            FLAGS.outdir,
            '%d-%s-%s-%s.pickle' % (val_size, l_name, scoring, str(is_bp)))
        logging.info(d_path)
        if tf.io.gfile.exists(d_path):
            logging.warning('File %s exists, skipping.', d_path)
        else:
            ls_train_loss = []
            ls_train_acc = []
            ls_test_loss = []
            ls_test_acc = []

            # In TF2 the shape dimensions are plain ints (no `.value`).
            n_units = int(input_shapes[l_name][-1])
            n_unit_pruned_max = n_units * max_pruning_fraction
            c_slice = np.linspace(0, n_unit_pruned_max, n_exps, dtype=np.int32)
            logging.info('Layer:%s, n_units:%d, c_slice:%s', l_name, n_units,
                         str(c_slice))
            for pruning_count in c_slice:
                # Cast from np.int32 to int.
                pruning_count = int(pruning_count)
                copied_model = model.clone()
                pruning_factor = None
                pruner.prune_model_with_scores(copied_model, scores, is_bp,
                                               layers2prune, pruning_factor,
                                               pruning_count, mean_values,
                                               input_shapes)
                test_loss, test_acc, _ = train_utils.cross_entropy_loss(
                    copied_model, subset_test, calculate_accuracy=True)
                train_loss, train_acc, _ = train_utils.cross_entropy_loss(
                    copied_model, subset_val2, calculate_accuracy=True)
                logging.info('is_bp: %s, n: %d, test_loss: %f, train_loss: %f',
                             str(is_bp), pruning_count, test_loss, train_loss)
                ls_train_loss.append(train_loss.numpy())
                ls_test_loss.append(test_loss.numpy())
                ls_test_acc.append(test_acc.numpy())
                ls_train_acc.append(train_acc.numpy())
            utils.pickle_object(
                (ls_train_loss, ls_train_acc, ls_test_loss, ls_test_acc),
                d_path)
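
Since the required arguments default to `gin.REQUIRED`, this function is presumably driven from a gin config. A hypothetical binding (all paths and values below are illustrative assumptions, not the project's actual config):

import gin

gin.parse_config("""
  prune_layer_and_eval.model_dir = '/tmp/ckpts'
  prune_layer_and_eval.l_name = 'conv_1'
  prune_layer_and_eval.n_exps = 10
  prune_layer_and_eval.val_size = 1000
  prune_layer_and_eval.max_pruning_fraction = 0.5
""")
prune_layer_and_eval()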
Example #7
def prune_and_finetune_model(pruning_schedule=gin.REQUIRED,
                             model_dir=gin.REQUIRED,
                             dataset_name='cifar10',
                             checkpoint_interval=5,
                             log_interval=5,
                             n_finetune=100,
                             epochs=20,
                             lr=1e-4,
                             momentum=0.9):
    """Loads and prunes the model layer by layer according to the schedule.

  The model is finetuned between pruning tasks (as given in the schedule).

  Args:
    pruning_schedule: list<str, float>, where the str is a valid layer name and
      the float is the pruning fraction of that layer. Layers are pruned in
      the order they are given and `n_finetune` steps taken in between.
    model_dir: str, Path to the checkpoint directory.
    dataset_name: str, either 'cifar10' or 'imagenet'.
    checkpoint_interval: int, number of epochs between two checkpoints.
    log_interval: int, number of steps between two logging events.
    n_finetune: int, number of steps between two pruning steps.
    epochs: int, total number of epochs to run.
    lr: float, learning rate for the fine-tuning steps.
    momentum: float, momentum multiplier for the fine-tuning steps.
  Raises:
    ValueError: when the n_finetune is not positive.
    OSError: when there is no checkpoint found.
  """
    if n_finetune <= 0:
        raise ValueError('n_finetune must be positive: given %d' % n_finetune)
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr, momentum=momentum)
    logging.info('Using outdir: %s', model_dir)
    latest_cpkt = tf.train.latest_checkpoint(model_dir)
    if not latest_cpkt:
        raise OSError('No checkpoint found in %s' % model_dir)
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    model = model_load.get_model(load_path=latest_cpkt,
                                 dataset_name=dataset_name)
    logging.info('Model init-config: %s', model.init_config)
    logging.info('Model forward chain: %s', str(model.forward_chain))
    datasets = data.get_datasets(dataset_name=dataset_name)
    dataset_train, dataset_test, subset_val, subset_test, subset_val2 = datasets

    unit_pruner = pruner.UnitPruner(model, subset_val)
    step_counter = optimizer.iterations
    tf.summary.experimental.set_step(step_counter)
    current_epoch = tf.Variable(1)
    current_layer_index = tf.Variable(0)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     current_epoch=current_epoch,
                                     current_layer_index=current_layer_index)
    latest_cpkt = tf.train.latest_checkpoint(FLAGS.outdir)
    if latest_cpkt:
        logging.info('Using latest checkpoint at %s', latest_cpkt)
        # Restore variables on creation if a checkpoint exists.
        checkpoint.restore(latest_cpkt)
        logging.info('Resuming with epoch: %d', current_epoch.numpy())
    c_epoch = current_epoch.numpy()
    c_layer_index = current_layer_index.numpy()
    # Starting from the first batch, we perform pruning every `n_finetune`
    # steps. Layers are pruned one by one according to the given schedule.

    while c_epoch <= epochs:
        logging.info('Starting Epoch: %d', c_epoch)
        for (x, y) in dataset_train:
            # Every `n_finetune` step perform pruning.
            if (step_counter.numpy() % n_finetune == 0
                    and c_layer_index < len(pruning_schedule)):
                logging.info('Pruning at iteration: %d', step_counter.numpy())
                l_name, pruning_factor = pruning_schedule[c_layer_index]
                unit_pruner.prune_layer(l_name, pruning_factor=pruning_factor)

                train_utils.log_loss_acc(model, subset_val2, subset_test)
                train_utils.log_sparsity(model)
                # Re-init optimizer and therefore remove previous momentum.
                optimizer = tf.keras.optimizers.SGD(learning_rate=lr,
                                                    momentum=momentum)
                c_layer_index += 1
                current_layer_index.assign(c_layer_index)
            else:
                if step_counter.numpy() % log_interval == 0:
                    logging.info('Iteration: %d', step_counter.numpy())
                    train_utils.log_loss_acc(model, subset_val2, subset_test)
                    train_utils.log_sparsity(model)
            with tf.GradientTape() as tape:
                loss_train, _, _ = cross_entropy_loss(model, (x, y),
                                                      training=True)
            grads = tape.gradient(loss_train, model.variables)
            # Updating the model.
            # `tf.keras` optimizers take no `global_step`; `optimizer.iterations`
            # (our `step_counter`) is incremented automatically.
            optimizer.apply_gradients(zip(grads, model.variables))
            if step_counter.numpy() % log_interval == 0:
                tf.summary.scalar('loss_train', loss_train)
                tf.summary.image('x', x, max_outputs=1)
        # End of an epoch.
        c_epoch += 1
        current_epoch.assign(c_epoch)
        # Save every n OR after last epoch.
        if (tf.equal((current_epoch - 1) % checkpoint_interval, 0)
                or c_epoch > epochs):
            # Re-init the checkpoint object to ensure the masks are captured;
            # the masks are not generated initially.
            checkpoint = tf.train.Checkpoint(
                optimizer=optimizer,
                model=model,
                current_epoch=current_epoch,
                current_layer_index=current_layer_index)
            logging.info('Checkpoint after epoch: %d', c_epoch - 1)
            checkpoint.save(
                os.path.join(FLAGS.outdir, 'ckpt-%d' % (c_epoch - 1)))

    # Test model
    test_loss, test_acc, n_samples = cross_entropy_loss(
        model, dataset_test, calculate_accuracy=True)
    tf.summary.scalar('test_loss_all', test_loss)
    tf.summary.scalar('test_acc_all', test_acc)
    logging.info(
        'Overall_test_loss: %.4f, Overall_test_acc: %.4f, n_samples: %d',
        test_loss, test_acc, n_samples)
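
As in the previous example, the gin-configured arguments suggest a config-driven entry point; a hypothetical schedule binding (layer names and fractions are assumptions) might look like:

import gin

gin.parse_config("""
  prune_and_finetune_model.pruning_schedule = [('conv_1', 0.5), ('conv_2', 0.3)]
  prune_and_finetune_model.model_dir = '/tmp/ckpts'
""")
prune_and_finetune_model()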
Example #8
def train_model(dataset_name='cifar10',
                checkpoint_every_n_epoch=5,
                log_interval=1000,
                epochs=10,
                lr=1e-2,
                lr_drop_iter=1500,
                lr_decay=0.5,
                momentum=0.9,
                seed=8):
  """Trains the model with regular logging and checkpoints.

  Args:
    dataset_name: str, either 'cifar10' or 'imagenet'.
    checkpoint_every_n_epoch: int, number of epochs between two checkpoints.
    log_interval: int, number of steps between two logging events.
    epochs: int, epoch to train with.
    lr: float, learning rate for the fine-tuning steps.
    lr_drop_iter: int, iteration between two consequtive lr drop.
    lr_decay: float: multiplier for step learning rate reduction.
    momentum: float, momentum multiplier for the fine-tuning steps.
    seed: int, random seed to be set to produce reproducible experiments.

  Raises:
    AssertionError: when the args doesn't match the specs.
  """
  assert dataset_name in ['cifar10', 'imagenet', 'cub200', 'imagenet_vgg']
  tf.random.set_seed(seed)
  # The model is configured through gin parameters.
  model = model_load.get_model(dataset_name=dataset_name)

  # model.call = tf.function(model.call)
  datasets = data.get_datasets(dataset_name=dataset_name)
  dataset_train, dataset_test, _, subset_test, subset_val = datasets
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      lr, decay_steps=lr_drop_iter, decay_rate=lr_decay, staircase=True)
  optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=momentum)
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))

  current_epoch = tf.Variable(1)
  # Create checkpoint object. TODO: check whether the ckpt-prefix is needed.
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer,
      model=model,
      current_epoch=current_epoch)
  # Keep all checkpoints by setting `max_to_keep=None`.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, directory=FLAGS.outdir, max_to_keep=None)

  latest_cpkt = checkpoint_manager.latest_checkpoint
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with epoch: %d', current_epoch.numpy())
  c_epoch = current_epoch.numpy()
  while c_epoch <= epochs:
    logging.info('Starting Epoch:%d', c_epoch)
    for (x, y) in dataset_train:
      if step_counter.numpy() % log_interval == 0:
        train_utils.log_loss_acc(model, subset_val, subset_test)
        tf.summary.image('x', x, max_outputs=1)
        logging.info('Iteration:%d', step_counter.numpy())
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model.
      optimizer.apply_gradients(zip(grads, model.variables))
      tf.summary.scalar('loss_train', loss_train)
      tf.summary.scalar('lr', optimizer.lr(step_counter))
    # End of an epoch.
    c_epoch += 1
    current_epoch.assign(c_epoch)
    # Save every n OR after last epoch.
    if (tf.equal((current_epoch - 1) % checkpoint_every_n_epoch, 0) or
        c_epoch > epochs):
      logging.info('Checkpoint after epoch: %d', c_epoch - 1)
      checkpoint_manager.save(checkpoint_number=c_epoch - 1)
  test_loss, test_acc, n_samples = cross_entropy_loss(
      model, dataset_test, calculate_accuracy=True)
  tf.summary.scalar('test_loss_all', test_loss)
  tf.summary.scalar('test_acc_all', test_acc)
  logging.info('Overall_test_loss:%.4f, Overall_test_acc:%.4f, n_samples:%d',
               test_loss, test_acc, n_samples)
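
A quick sanity check of the staircase schedule built above, using the default hyperparameters: the learning rate is lr * lr_decay ** floor(step / lr_drop_iter).

import tensorflow as tf

sched = tf.keras.optimizers.schedules.ExponentialDecay(
    1e-2, decay_steps=1500, decay_rate=0.5, staircase=True)
# Halves every 1500 steps:
print(sched(0).numpy(), sched(1500).numpy(), sched(3000).numpy())
# -> 0.01 0.005 0.0025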
Example #9
def prune_and_finetune_model(dataset_name='imagenet_vgg',
                             flop_regularizer=0,
                             n_units_target=4000,
                             checkpoint_interval=5,
                             log_interval=5,
                             n_finetune=100,
                             lr=1e-4,
                             momentum=0.9,
                             seed=8):
    """Trains the model with regular logging and checkpoints.

  Args:
    dataset_name: str, dataset to train on.
    flop_regularizer: float, multiplier for the flop regularization. If 0, no
      regularization is made during pruning.
    n_units_target: int, number of unit to prune.
    checkpoint_interval: int, number of epochs between two checkpoints.
    log_interval: int, number of steps between two logging events.
    n_finetune: int, number of steps between two pruning steps. Starting from
      the first iteration we prune 1 unit every `n_finetune` gradient update.
    lr: float, learning rate for the fine-tuning steps.
    momentum: float, momentum multiplier for the fine-tuning steps.
    seed: int, random seed to be set to produce reproducible experiments.
  Raises:
    ValueError: when the n_finetune is not positive.
  """
    if n_finetune <= 0:
        raise ValueError('n_finetune must be positive: given %d' % n_finetune)
    tf.random.set_seed(seed)
    optimizer = tf.keras.optimizers.SGD(lr, momentum=momentum)
    # imagenet_vgg->imagenet
    dataset_basename = dataset_name.split('_')[0]
    model = model_load.get_model(dataset_name=dataset_basename)
    # Uncomment the following line if your model can be compiled with
    # `tf.function`.
    # model.call = tf.function(model.call)
    datasets = data.get_datasets(dataset_name=dataset_name)
    (dataset_train, _, subset_val, subset_test, subset_val2) = datasets
    logging.info('Model init-config: %s', model.init_config)
    logging.info('Model forward chain: %s', str(model.forward_chain))

    unit_pruner = pruner.UnitPruner(model, subset_val)
    # We prune all conv layers.
    pruning_pool = _all_vgg_conv_layers
    baselines = {
        l_name: c_flop * flop_regularizer
        for l_name, c_flop in zip(pruning_pool, _vgg16_flop_regularizition)
    }

    step_counter = optimizer.iterations
    tf.summary.experimental.set_step(step_counter)
    c_pruning_step = tf.Variable(1)
    # Create checkpoint object. TODO: check whether the ckpt-prefix is needed.
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     c_pruning_step=c_pruning_step)
    # Keep all checkpoints by setting `max_to_keep=None`.
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=FLAGS.outdir,
                                                    max_to_keep=None)

    latest_cpkt = checkpoint_manager.latest_checkpoint
    if latest_cpkt:
        logging.info('Using latest checkpoint at %s', latest_cpkt)
        # Restore variables on creation if a checkpoint exists.
        checkpoint.restore(latest_cpkt)
        logging.info('Resuming with pruning step: %d', c_pruning_step.numpy())
    pruning_step = c_pruning_step.numpy()
    while pruning_step <= n_units_target:
        for (x, y) in dataset_train:
            # Every `n_finetune` step perform pruning.
            if step_counter.numpy() % n_finetune == 0:
                if pruning_step > n_units_target:
                    # This holds after the last pruning step, once we have
                    # fine-tuned for `n_finetune` further iterations; breaking
                    # here also terminates the outer while loop.
                    break
                logging.info('Pruning Step: %d', pruning_step)
                start = time.time()
                unit_pruner.prune_one_unit(pruning_pool, baselines=baselines)
                end = time.time()
                logging.info(
                    'Train time for Pruning Step #%d (step %d): %f',
                    pruning_step, step_counter.numpy(), end - start)
                pruning_step += 1
                c_pruning_step.assign(pruning_step)
                if tf.equal((pruning_step - 1) % checkpoint_interval, 0):
                    checkpoint_manager.save()
            if step_counter.numpy() % log_interval == 0:
                train_utils.log_loss_acc(model, subset_val2, subset_test)
                train_utils.log_sparsity(model)
            with tf.GradientTape() as tape:
                loss_train, _, _ = cross_entropy_loss(model, (x, y),
                                                      training=True)
            grads = tape.gradient(loss_train, model.variables)
            # Updating the model.
            # `tf.keras` optimizers take no `global_step`; the iteration
            # counter is incremented automatically.
            optimizer.apply_gradients(zip(grads, model.variables))
            if step_counter.numpy() % log_interval == 0:
                tf.summary.scalar('loss_train', loss_train)
                tf.summary.image('x', x, max_outputs=1)
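
To illustrate the `baselines` dict built at the top of this function (the layer names and FLOP counts below are made-up assumptions): layers with more FLOPs receive a larger constant, which `prune_one_unit` later subtracts from their unit scores, biasing pruning toward expensive layers.

pruning_pool = ['conv_1', 'conv_2']   # assumed layer names
flops = [1.0e9, 0.5e9]                # assumed per-layer FLOP counts
flop_regularizer = 1e-10
baselines = {l: f * flop_regularizer for l, f in zip(pruning_pool, flops)}
print(baselines)  # {'conv_1': 0.1, 'conv_2': 0.05}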
Example #10
    def prune_layer(self,
                    layer_name,
                    pruning_factor=0.1,
                    pruning_count=None,
                    pruning_method=None,
                    is_bp=None):
        """Prunes a single layer using the given scoring function.

    Args:
      layer_name: str, layer name to prune.
      pruning_factor: float, 0 < pruning_factor < 1.
      pruning_count: int, if not None, sets the pruning_factor to None. This is
        because you can either prune a fraction or a number of units.
        pruning_count is used to determine how many units to prune per layer.
      pruning_method: str, from ['norm', 'mrs', 'rs', 'rand', 'abs_mrs', 'rs'].
        If given, overwrites the default value.
     is_bp: bool, if True Mean Replacement Pruning is used and bias propagation
       is made. If given, overwrites the default value.

    Raises:
      ValueError: if the arguments provided doesn't match specs.
    """
        if pruning_method is None:
            pruning_method = self.pruning_method
        # Explicit None check so a caller passing `is_bp=False` is not
        # silently overridden by the default.
        if is_bp is None:
            is_bp = self.is_bp
        if pruning_method not in ALL_SCORING_FUNCTIONS:
            raise ValueError('%s is not one of %s' %
                             (pruning_method, ALL_SCORING_FUNCTIONS))
        # Need to wrap up the layer in a list for
        # `pruner.prune_model_with_scores()` call.
        layers2prune = [layer_name]
        # If pruning_count exists invalidate the pruning_factor.
        if pruning_count is not None:
            if not isinstance(pruning_count, int):
                raise ValueError('pruning_count: %s should be an int' %
                                 pruning_count)
            elif pruning_count < 1:
                raise ValueError(
                    'pruning_count: %d should be greater than 1.' %
                    pruning_count)
            pruning_factor = None
        # Validate pruning_factor.
        elif pruning_factor == 0:
            return
        elif pruning_factor <= 0 or pruning_factor >= 1:
            raise ValueError('pruning_factor: %s should be in (0, 1)' %
                             pruning_factor)

        # `pruning_factor` may be None here (when `pruning_count` is given),
        # so format both with %s.
        logging.info('Pruning layer `%s` with: %s, factor: %s, count: %s',
                     layer_name, pruning_method, pruning_factor,
                     pruning_count)
        input_shapes = {
            layer_name: getattr(self.model, layer_name + '_ts').xshape
        }

        # Calculating the scoring function/mean value.
        is_abs = pruning_method.startswith('abs')
        is_mrs = pruning_method.endswith('mrs')
        is_rs = pruning_method.endswith('rs') and not is_mrs
        train_utils.cross_entropy_loss(
            self.model,
            self.subset_val,
            training=False,
            compute_mean_replacement_saliency=is_mrs,
            compute_removal_saliency=is_rs,
            is_abs=is_abs,
            aggregate_values=True,
            run_gradient=True)
        scores = {}
        mean_values = {}
        # `layer_ts` stands for TaylorScorer layer.
        layer_ts = getattr(self.model, layer_name + '_ts')
        masked_layer = getattr(self.model, layer_name)
        masked_layer.apply_masks()
        if pruning_method == 'rand':
            scores[layer_name] = unitscorers.random_score(
                masked_layer.get_layer().weights[0])
        elif pruning_method == 'norm':
            scores[layer_name] = unitscorers.norm_score(
                masked_layer.get_layer().weights[0])
        else:
            # mrs or rs.
            score_name = 'rs' if is_rs else 'mrs'
            scores[layer_name] = layer_ts.get_saved_values(score_name)
        mean_values[layer_name] = layer_ts.get_saved_values('mean')
        prune_model_with_scores(self.model, scores, is_bp, layers2prune,
                                pruning_factor, pruning_count, mean_values,
                                input_shapes)
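
A hedged usage sketch of this method; the layer name, fraction and count are illustrative assumptions:

unit_pruner = pruner.UnitPruner(model, subset_val)
# Prune 10% of the units of `conv_2` with absolute MRS and bias propagation:
unit_pruner.prune_layer('conv_2', pruning_factor=0.1,
                        pruning_method='abs_mrs', is_bp=True)
# Or prune an exact number of units, which overrides the fraction:
unit_pruner.prune_layer('conv_2', pruning_count=16, pruning_method='norm')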
Example #11
def probe_pruning(model,
                  subset_val,
                  subset_val2,
                  subset_test,
                  f_retrain,
                  baselines=(0.0, 0.0),
                  layers2prune='all',
                  n_retrain=0,
                  pruning_factor=0.1,
                  pruning_count=None,
                  pruning_methods=(('norm', True), )):
    """Prunes a copy of the network and calculates change in the loss.

  By default calculates mrs,rs and mean values.
  Args:
    model: tf.keras.Model
    subset_val: tf.data.Dataset, used for loss calculation.
    subset_val2: tf.data.Dataset, used for pruning scoring.
    subset_test: tf.data.Dataset, from test set.
    f_retrain: function, used for retraining with 2 arguments `copied_model` and
      `n_retrain`.
    baselines: tuple, <val_loss, test_loss> Baselines to subtract from loss,
    layers2prune: list or str, each elemenet `name` in the list should be a
      valid MasketLayer under model. model.name->MaskedLayer. One can also
      provide following tokens:
        `all`: searches model finds all MaskedLayers's and prunes them all.
        `firstconv`: prunes the first conv_layer in the `forward_chain`.
        `midconv`: prunes the `mid=n_conv//2` conv_layer in the `forward_chain`.
        `lastconv`: prunes the last conv_layer in the `forward_chain`.
        `firstdense`: prunes the first dense layer in the `forward_chain`.
    n_retrain: int, Number of retraining updates to perform after pruning.
      If n_retrain<=0, then nothing happens.
    pruning_factor: float, 0<pruning_factor<1
    pruning_count: int, if not None, sets the pruning_factor to None. This is
      because you can either prune a fraction or a number of units.
      pruning_count is used to determine how many units to prune per layer.
    pruning_methods: iterator of tuples, (scoring, is_bp) where `is_bp`
      is a boolean indicating the usage of Mean Replacement Pruning. Scoring is
      a string from ['norm', 'mrs', 'rs', 'rand', 'abs_mrs', 'rs'].
        'norm': unitscorers.norm_score
        '{abs_}mrs': `compute_mean_replacement_saliency=True` for `TaylorScorer`
          layer. if {abs_} prefix exists absolute value used before aggregation.
          i.e. is_abs=True.
        '{abs_}rs': `compute_removal_saliency=True` for `TaylorScorer` layer. if
          {abs_} prefix exists absolute value used before aggregation. i.e.
          is_abs=True.
        'rand': unitscorers.random_score is_bp; bool, if True, mean value of the
          units are propagated to the next layer prior to the pruning. `bp` for
          bias_propagation.

  Raises:
    AssertionError: if the arguments provided doesn't match specs.

  Returns:
    selected_units: dict, keys coming from `layers2prune` and each value is a
      tuple of (score, mask), where score is the pruning scores for each units
      and mask is the corresponding binary masks created for the given fraction.
    mean_values: dict, keys coming from `layers2prune` and each value is the
      mean activation under training batch.
  """
    # Check validity of `layers2prune` and process.
    layers2prune = process_layers2prune(layers2prune, model)
    # If pruning_count exists invalidate the pruning_factor
    if pruning_count is not None:
        assert isinstance(pruning_count, int) and pruning_count >= 1
        pruning_factor = None
    copied_model = model.clone()
    loss_val, loss_test = baselines
    selected_units = {}
    input_shapes = {
        l_name: getattr(model, l_name + '_ts').xshape
        for l_name in layers2prune
    }
    measurements = get_pruning_measurements(copied_model, subset_val2,
                                            layers2prune)
    (scores, mean_values, l2_norms) = measurements
    for scoring, is_bp in pruning_methods:
        assert (scoring in ['norm', 'abs_mrs', 'abs_rs', 'mrs', 'rs', 'rand'])
        copied_model = model.clone()
        scalar_summary_tag = 'pruning_penalty%s_%s' % ('_bp' if is_bp else '',
                                                       scoring)
        logging.info('Pruning following layers: %s, Using %s %s bias_prop.',
                     layers2prune, scoring, 'with' if is_bp else 'without')
        selected_units_scoring = prune_model_with_scores(
            copied_model, scores[scoring], is_bp, layers2prune, pruning_factor,
            pruning_count, mean_values, input_shapes)
        # There can be two passes for a given `scoring`, with and without
        # `bp`. We record the selected units only once, since they are equal.
        if scoring not in selected_units:
            selected_units[scoring] = selected_units_scoring
        if n_retrain > 0:
            f_retrain(copied_model, n_retrain)
        # Setting training=True since test mode uses running averages and
        # revives pruned units.
        loss_new, _, _ = train_utils.cross_entropy_loss(copied_model,
                                                        subset_val,
                                                        training=True)
        tf.summary.scalar(scalar_summary_tag, loss_new - loss_val)
        # Setting training=True, otherwise BatchNorm uses the accumulated mean and
        # std during forward propagation, which causes pruned units to generate
        # non-zero constants.
        loss_new, _, _ = train_utils.cross_entropy_loss(copied_model,
                                                        subset_test,
                                                        training=True)
        tf.summary.scalar(scalar_summary_tag + '_test', loss_new - loss_test)
    return selected_units, mean_values, l2_norms
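
A hedged usage sketch; the retrain function and the dataset handles below are assumptions:

def retrain_fn(copied_model, n_steps):
  pass  # e.g. run `n_steps` SGD updates on `copied_model`.

selected, means, l2_norms = probe_pruning(
    model, subset_val, subset_val2, subset_test, retrain_fn,
    layers2prune='firstconv', pruning_factor=0.2, n_retrain=10,
    pruning_methods=(('norm', True), ('abs_mrs', True)))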
Example #12
    def prune_one_unit(self,
                       pruning_pool,
                       baselines=None,
                       normalized_scores=True,
                       pruning_method=None,
                       is_bp=None):
        """Picks a layer and prunes a single unit using the scoring function.

    Args:
      pruning_pool: list, of layers that are considered for pruning.
      baselines: dict, if exists, subtracts the given constant from the scores
        of individual layers. The keys should a subset of pruning_pool.
      normalized_scores: bool, if True the scores are normalized with l2 norm.
      pruning_method: str, from ['norm', 'mrs', 'rs', 'rand', 'abs_mrs', 'rs'].
        If given, overwrites the default value.
      is_bp: bool, if True Mean Replacement Pruning is used and bias propagation
        is made. If given, overwrites the default value.

    Raises:
      AssertionError: if the arguments provided doesn't match specs.
    """
        if pruning_method is None:
            pruning_method = self.pruning_method
        # Explicit None check so a caller passing `is_bp=False` is not
        # silently overridden by the default.
        if is_bp is None:
            is_bp = self.is_bp
        if pruning_method not in ALL_SCORING_FUNCTIONS:
            raise ValueError('%s is not one of %s' %
                             (pruning_method, ALL_SCORING_FUNCTIONS))
        if baselines is None:
            baselines = {}
        logging.info('Pruning with: %s, is_bp: %s', pruning_method, is_bp)

        # Calculating the scoring function/mean value.
        is_abs = pruning_method.startswith('abs')
        is_mrs = pruning_method.endswith('mrs')
        is_rs = pruning_method.endswith('rs') and not is_mrs
        is_grad = is_mrs or is_rs
        train_utils.cross_entropy_loss(
            self.model,
            self.subset_val,
            training=False,
            compute_mean_replacement_saliency=is_mrs,
            compute_removal_saliency=is_rs,
            is_abs=is_abs,
            aggregate_values=True,
            run_gradient=is_grad)
        scores = {}
        mean_values = {}
        smallest_score = None
        smallest_l_name = None
        smallest_nprune = None

        for l_name in pruning_pool:
            l_ts = getattr(self.model, l_name + '_ts')
            l = getattr(self.model, l_name)
            mean_values[l_name] = l_ts.get_saved_values('mean')
            # Make sure the masks are applied after the last gradient update.
            # Note that this is necessary for the `norm` scorer, since it
            # doesn't call the model and therefore the masks are not applied
            # otherwise.
            l.apply_masks()
            if pruning_method == 'rand':
                scores[l_name] = unitscorers.random_score(
                    l.get_layer().weights[0])
            elif pruning_method == 'norm':
                scores[l_name] = unitscorers.norm_score(
                    l.get_layer().weights[0])
            else:
                # mrs or rs.
                score_name = 'rs' if is_rs else 'mrs'
                scores[l_name] = l_ts.get_saved_values(score_name)
            if normalized_scores:
                scores[l_name] /= tf.norm(scores[l_name])
            baseline_score = baselines.get(l_name, 0)
            if baseline_score != 0:
                # Regularizing the scores with c_flop weights.
                scores[l_name] -= baseline_score
            # If there is an existing mask we have to make sure that pruned
            # connections are indicated, so we set their scores to a very
            # small negative number (-1e10). Note that the elements of
            # `l.mask_bias` consist of zeros and ones only.
            if l.mask_bias is not None:
                # Setting the scores of the pruned units to zero.
                scores[l_name] = scores[l_name] * l.mask_bias
                # Setting the scores of the pruned units to -1e10.
                scores[l_name] += -1e10 * (1 - l.mask_bias)
                # Number of previously pruned units.
                n_pruned = tf.math.count_nonzero(l.mask_bias - 1).numpy()
                layer_smallest_score = tf.reduce_min(
                    tf.boolean_mask(scores[l_name], l.mask_bias)).numpy()
                # Do not prune the last unit.
                if tf.equal(n_pruned + 1, tf.size(l.mask_bias)):
                    continue
            else:
                n_pruned = 0
                layer_smallest_score = tf.reduce_min(scores[l_name]).numpy()

            logging.info('Layer:%s, min:%f', l_name, layer_smallest_score)
            if smallest_score is None or (layer_smallest_score <
                                          smallest_score):
                smallest_score = layer_smallest_score
                smallest_l_name = l_name
                # We want to prune one more than before.
                smallest_nprune = n_pruned + 1
        logging.info('UNIT_PRUNED, layer:%s, n_pruned:%d', smallest_l_name,
                     smallest_nprune)
        mean_values = {smallest_l_name: mean_values[smallest_l_name]}
        scores = {smallest_l_name: scores[smallest_l_name]}
        input_shapes = {
            smallest_l_name: getattr(self.model,
                                     smallest_l_name + '_ts').xshape
        }
        layers2prune = [smallest_l_name]
        prune_model_with_scores(self.model, scores, is_bp, layers2prune, None,
                                smallest_nprune, mean_values, input_shapes)
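
A tiny numeric check of the masking trick used above: previously pruned units (mask == 0) get score -1e10 so they are never selected again, while live units keep their scores.

import tensorflow as tf

scores = tf.constant([0.3, 0.1, 0.7])
mask_bias = tf.constant([1.0, 0.0, 1.0])  # unit 1 was already pruned
masked = scores * mask_bias + -1e10 * (1 - mask_bias)
print(masked.numpy())  # [ 3.e-01 -1.e+10  7.e-01]
# The min over the still-live units (as taken via `boolean_mask` above)
# therefore ignores unit 1.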