Example #1
0
 def testDatasetName(self):
     """Known dataset names are accepted; unknown ones raise AssertionError."""
     # These names are requested without expecting any exception.
     for known_name in ('cifar10', 'imagenet'):
         data.get_datasets(dataset_name=known_name)
     # Any other name is expected to be rejected with an AssertionError.
     for unknown_name in ('cifar', 'shvc'):
         with self.assertRaises(AssertionError):
             data.get_datasets(dataset_name=unknown_name)
Example #2
0
 def testCUB200(self):
     """Checks the shapes of the first training batch of `cub200`."""
     batch_size = 4
     n_val = 15
     n_eval = 12
     all_splits = data.get_datasets('cub200',
                                    eval_size=n_eval,
                                    val_size=n_val,
                                    num_parallel_calls=1,
                                    shuffle_size=1,
                                    batch_size=batch_size)
     # Five values are returned; only the training split is inspected.
     dataset_train, _, _, _, _ = all_splits
     images, labels = next(iter(dataset_train))
     self.assertEqual(images.shape, [batch_size, 224, 224, 3])
     self.assertEqual(labels.shape, [batch_size])
Example #3
0
    def testDefaultArgs(self):
        """Checks shapes, value ranges and batch counts with default args."""
        expected_bs = 32
        expected_val_size = expected_eval_size = 1000
        image_dims = [32, 32, 3]

        def check_value_range(images):
            # Values are expected to stay within [-1, 1].
            self.assertLessEqual(tf.reduce_max(images), 1.0)
            self.assertGreaterEqual(tf.reduce_min(images), -1.)

        def check_exhausted(iterator):
            # Asking for one more batch must end the iteration.
            with self.assertRaises(StopIteration):
                next(iterator)

        (dataset_train, dataset_test, subset_val, subset_test,
         subset_val2) = data.get_datasets(shuffle_size=1)

        # Training split.
        train_x, train_y = next(iter(dataset_train))
        self.assertEqual(train_x.shape, [expected_bs] + image_dims)
        self.assertEqual(train_y.shape, [expected_bs])
        check_value_range(train_x)

        # Test split.
        test_x, test_y = next(iter(dataset_test))
        check_value_range(test_x)
        self.assertEqual(test_x.shape, [expected_bs] + image_dims)
        self.assertEqual(test_y.shape, [expected_bs])

        # First validation subset.
        val_iterator = iter(subset_val)
        val_x, val_y = next(val_iterator)
        check_value_range(val_x)
        self.assertEqual(val_x.shape, [expected_val_size] + image_dims)
        self.assertEqual(val_y.shape, [expected_val_size])
        # Since chunk_size=None, it should only have one batch.
        check_exhausted(val_iterator)

        # Second validation subset.
        val2_iterator = iter(subset_val2)
        val2_x, val2_y = next(val2_iterator)
        check_value_range(val2_x)
        self.assertEqual(val2_x.shape, [expected_eval_size] + image_dims)
        self.assertEqual(val2_y.shape, [expected_eval_size])
        # Check that the two validation subsets are disjoint.
        self.assertNotAllClose(val_x, val2_x)
        # Since chunk_size=None, it should only have one batch.
        check_exhausted(val2_iterator)

        # Test subset.
        eval_iterator = iter(subset_test)
        eval_x, eval_y = next(eval_iterator)
        check_value_range(eval_x)
        self.assertEqual(eval_x.shape, [expected_eval_size] + image_dims)
        self.assertEqual(eval_y.shape, [expected_eval_size])
        check_exhausted(eval_iterator)
Example #4
0
 def testVgg(self, mock_pp_i):
     """Checks the first `imagenet_vgg` training batch with a mocked hook."""
     # The injected mock behaves as an identity function.
     mock_pp_i.side_effect = lambda a: a
     batch_size = 4
     n_val = 15
     n_eval = 12
     all_splits = data.get_datasets('imagenet_vgg',
                                    eval_size=n_eval,
                                    val_size=n_val,
                                    num_parallel_calls=1,
                                    shuffle_size=1,
                                    batch_size=batch_size)
     # Five values are returned; only the training split is inspected.
     dataset_train, _, _, _, _ = all_splits
     images, labels = next(iter(dataset_train))
     self.assertEqual(images.shape, [batch_size, 224, 224, 3])
     self.assertEqual(labels.shape, [batch_size])
     # Due to data augmentation the max can be slightly bigger than 1.0.
     self.assertLessEqual(tf.reduce_max(images), 1.5)
     self.assertGreaterEqual(tf.reduce_min(images), -0.5)
     self.assertTrue(mock_pp_i.called)
     # NOTE(review): the second argument of `assertTrue` is the failure
     # message, so this only checks that `call_count` is truthy and never
     # compares it with the batch size. Confirm the expected number of calls
     # before turning this into an `assertEqual`.
     self.assertTrue(mock_pp_i.call_count, batch_size)
Example #5
0
 def testCustomImagenetArgs(self):
     """Checks batch and chunk shapes of `imagenet` with custom sizes."""
     batch_size, n_val, n_eval, chunk = 4, 15, 12, 5
     image_dims = [224, 224, 3]
     all_splits = data.get_datasets('imagenet',
                                    eval_size=n_eval,
                                    val_size=n_val,
                                    batch_size=batch_size,
                                    shuffle_size=1,
                                    chunk_size=chunk)
     dataset_train, dataset_test, subset_val, subset_test, _ = all_splits

     # Train and test splits are batched with `batch_size`.
     images, labels = next(iter(dataset_train))
     self.assertEqual(images.shape, [batch_size] + image_dims)
     self.assertEqual(labels.shape, [batch_size])
     images, labels = next(iter(dataset_test))
     self.assertEqual(images.shape, [batch_size] + image_dims)
     self.assertEqual(labels.shape, [batch_size])

     # The validation subset is served in chunks of `chunk` samples.
     val_iterator = iter(subset_val)
     images, labels = next(val_iterator)
     self.assertEqual(images.shape, [chunk] + image_dims)
     self.assertEqual(labels.shape, [chunk])
     # Let us consume all remaining batches.
     remaining_val_batches = (n_val - 1) // chunk
     for _ in range(remaining_val_batches):
         images, labels = next(val_iterator)
     # Since we iterated over all batches, we should get an exception.
     with self.assertRaises(StopIteration):
         next(val_iterator)

     # The test subset is chunked too; its last chunk holds the remainder.
     eval_iterator = iter(subset_test)
     images, labels = next(eval_iterator)
     self.assertEqual(images.shape, [chunk] + image_dims)
     self.assertEqual(labels.shape, [chunk])
     remaining_eval_batches = (n_eval - 1) // chunk
     for _ in range(remaining_eval_batches):
         images, labels = next(eval_iterator)
     remainder = n_eval % chunk
     self.assertEqual(images.shape, [remainder] + image_dims)
     self.assertEqual(labels.shape, [remainder])
     with self.assertRaises(StopIteration):
         next(eval_iterator)
Example #6
0
def prune_layer_and_eval(dataset_name='cifar10',
                         pruning_methods=(('rand', True), ),
                         model_dir=gin.REQUIRED,
                         l_name=gin.REQUIRED,
                         n_exps=gin.REQUIRED,
                         val_size=gin.REQUIRED,
                         max_pruning_fraction=gin.REQUIRED):
    """Prunes one layer at several sparsity levels and records loss/accuracy.

    The latest checkpoint found in `model_dir` is loaded and pruning
    measurements for `l_name` are collected on the first validation subset.
    For every `(scoring, is_bp)` pair, `n_exps` pruning counts evenly spaced
    in `[0, n_units * max_pruning_fraction]` are each applied to a fresh clone
    of the model, which is then evaluated on `subset_test` and `subset_val2`.
    The four result lists (train loss, train accuracy, test loss, test
    accuracy) are pickled under `FLAGS.outdir`; a pair whose output file
    already exists is skipped.

    Args:
      dataset_name: str, either 'cifar10' or 'imagenet'.
      pruning_methods: iterable of tuples, (scoring, is_bp).
        scoring: str, one of ['norm', 'mrs', 'rs', 'rand', 'abs_mrs',
          'abs_rs'].
          'norm': unitscorers.norm_score
          '{abs_}mrs': `compute_mean_replacement_saliency=True` for
            `TaylorScorer` layer. If the {abs_} prefix exists the absolute
            value is used before aggregation, i.e. is_abs=True.
          '{abs_}rs': `compute_removal_saliency=True` for `TaylorScorer`
            layer. If the {abs_} prefix exists the absolute value is used
            before aggregation, i.e. is_abs=True.
          'rand': unitscorers.random_score
        is_bp: bool, if True, the mean value of the units is propagated to
          the next layer prior to the pruning (Mean Replacement Pruning).
          `bp` stands for bias_propagation.
      model_dir: str, Path to the checkpoint directory.
      l_name: str, a valid layer name from the model loaded.
      n_exps: int, number of pruning experiments to be made. This number is
        used to generate pruning counts for different experiments.
      val_size: int, size for the first dataset, passed to `get_datasets`.
      max_pruning_fraction: float, max sparsity for pruning. Multiplying this
        number with the total number of units gives the upper limit for the
        pruning_count.

    Raises:
      ValueError: when a scoring function key is not valid.
      OSError: when there is no checkpoint found.
    """
    logging.info('Looking checkpoint at: %s', model_dir)
    latest_cpkt = tf.train.latest_checkpoint(model_dir)
    if not latest_cpkt:
        raise OSError('No checkpoint found in %s' % model_dir)
    logging.info('Using latest checkpoint at %s', latest_cpkt)

    # Materialize once: the argument may be a one-shot iterator and it is
    # traversed twice (validation first, pruning afterwards).
    pruning_methods = list(pruning_methods)
    # Fail fast on a bad scoring key, before any expensive model work.
    for scoring, _ in pruning_methods:
        if scoring not in pruner.ALL_SCORING_FUNCTIONS:
            raise ValueError('%s is not one of %s' %
                             (scoring, pruner.ALL_SCORING_FUNCTIONS))

    model = model_load.get_model(load_path=latest_cpkt,
                                 dataset_name=dataset_name)
    datasets = data.get_datasets(dataset_name=dataset_name, val_size=val_size)
    _, _, subset_val, subset_test, subset_val2 = datasets
    input_shapes = {l_name: getattr(model, l_name + '_ts').xshape}
    layers2prune = [l_name]
    measurements = pruner.get_pruning_measurements(model, subset_val,
                                                   layers2prune)
    (all_scores, mean_values, _) = measurements

    # The pruning counts depend only on the layer, not on the scoring method.
    # `dimension_value` accepts both a `Dimension` object and a plain int.
    n_units = tf.compat.dimension_value(input_shapes[l_name][-1])
    n_unit_pruned_max = n_units * max_pruning_fraction
    c_slice = np.linspace(0, n_unit_pruned_max, n_exps, dtype=np.int32)

    for scoring, is_bp in pruning_methods:
        scores = all_scores[scoring]
        d_path = os.path.join(
            FLAGS.outdir,
            '%d-%s-%s-%s.pickle' % (val_size, l_name, scoring, str(is_bp)))
        logging.info(d_path)
        if tf.io.gfile.exists(d_path):
            logging.warning('File %s exists, skipping.', d_path)
            continue

        ls_train_loss = []
        ls_train_acc = []
        ls_test_loss = []
        ls_test_acc = []
        logging.info('Layer:%s, n_units:%d, c_slice:%s', l_name, n_units,
                     str(c_slice))
        for pruning_count in c_slice:
            # Cast from np.int32 to int.
            pruning_count = int(pruning_count)
            # Every experiment starts from an unpruned copy of the model.
            copied_model = model.clone()
            # The amount is given as a count, so no fraction is passed.
            pruning_factor = None
            pruner.prune_model_with_scores(copied_model, scores, is_bp,
                                           layers2prune, pruning_factor,
                                           pruning_count, mean_values,
                                           input_shapes)
            test_loss, test_acc, _ = train_utils.cross_entropy_loss(
                copied_model, subset_test, calculate_accuracy=True)
            train_loss, train_acc, _ = train_utils.cross_entropy_loss(
                copied_model, subset_val2, calculate_accuracy=True)
            logging.info('is_bp: %s, n: %d, test_loss%f, train_loss:%f',
                         str(is_bp), pruning_count, test_loss, train_loss)
            ls_train_loss.append(train_loss.numpy())
            ls_test_loss.append(test_loss.numpy())
            ls_test_acc.append(test_acc.numpy())
            ls_train_acc.append(train_acc.numpy())
        utils.pickle_object(
            (ls_train_loss, ls_train_acc, ls_test_loss, ls_test_acc), d_path)
def prune_and_finetune_model(pruning_schedule=gin.REQUIRED,
                             model_dir=gin.REQUIRED,
                             dataset_name='cifar10',
                             checkpoint_interval=5,
                             log_interval=5,
                             n_finetune=100,
                             epochs=20,
                             lr=1e-4,
                             momentum=0.9):
    """Loads and prunes the model layer by layer according to the schedule.

    The model is finetuned between pruning tasks (as given in the schedule).
    Training state is checkpointed under `FLAGS.outdir` and, when a checkpoint
    exists there, training resumes from it.

    Args:
      pruning_schedule: list<str, float>, where the str is a valid layer name
        and the float is the pruning fraction of that layer. Layers are pruned
        in the order they are given and `n_finetune` steps taken in between.
      model_dir: str, Path to the checkpoint directory.
      dataset_name: str, either 'cifar10' or 'imagenet'.
      checkpoint_interval: int, number of epochs between two checkpoints.
      log_interval: int, number of steps between two logging events.
      n_finetune: int, number of steps between two pruning steps.
      epochs: int, total number of epochs to run.
      lr: float, learning rate for the fine-tuning steps.
      momentum: float, momentum multiplier for the fine-tuning steps.

    Raises:
      ValueError: when the n_finetune is not positive.
      OSError: when there is no checkpoint found.
    """
    if n_finetune <= 0:
        raise ValueError('n_finetune must be positive: given %d' % n_finetune)
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr, momentum=momentum)
    logging.info('Using outdir: %s', model_dir)
    latest_cpkt = tf.train.latest_checkpoint(model_dir)
    if not latest_cpkt:
        raise OSError('No checkpoint found in %s' % model_dir)
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    model = model_load.get_model(load_path=latest_cpkt,
                                 dataset_name=dataset_name)
    logging.info('Model init-config: %s', model.init_config)
    logging.info('Model forward chain: %s', str(model.forward_chain))
    datasets = data.get_datasets(dataset_name=dataset_name)
    dataset_train, dataset_test, subset_val, subset_test, subset_val2 = datasets

    unit_pruner = pruner.UnitPruner(model, subset_val)
    # The optimizer's own iteration counter doubles as the global step.
    step_counter = optimizer.iterations
    tf.summary.experimental.set_step(step_counter)
    current_epoch = tf.Variable(1)
    current_layer_index = tf.Variable(0)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     current_epoch=current_epoch,
                                     current_layer_index=current_layer_index)
    latest_cpkt = tf.train.latest_checkpoint(FLAGS.outdir)
    if latest_cpkt:
        logging.info('Using latest checkpoint at %s', latest_cpkt)
        # Restore variables on creation if a checkpoint exists.
        checkpoint.restore(latest_cpkt)
        logging.info('Resuming with epoch: %d', current_epoch.numpy())
    c_epoch = current_epoch.numpy()
    c_layer_index = current_layer_index.numpy()
    # Starting from the first batch, we perform pruning every `n_finetune` step.
    # Layers pruned one by one according to the pruning schedule given.

    while c_epoch <= epochs:
        logging.info('Starting Epoch: %d', c_epoch)
        for (x, y) in dataset_train:
            # Every `n_finetune` step perform pruning.
            if (step_counter.numpy() % n_finetune == 0
                    and c_layer_index < len(pruning_schedule)):
                logging.info('Pruning at iteration: %d', step_counter.numpy())
                l_name, pruning_factor = pruning_schedule[c_layer_index]
                unit_pruner.prune_layer(l_name, pruning_factor=pruning_factor)

                train_utils.log_loss_acc(model, subset_val2, subset_test)
                train_utils.log_sparsity(model)
                # Re-init optimizer and therefore remove previous momentum.
                # The new optimizer inherits the step count so that the
                # pruning/logging cadence and the summary step keep advancing.
                completed_steps = step_counter.numpy()
                optimizer = tf.keras.optimizers.SGD(learning_rate=lr,
                                                    momentum=momentum)
                optimizer.iterations.assign(completed_steps)
                step_counter = optimizer.iterations
                tf.summary.experimental.set_step(step_counter)
                c_layer_index += 1
                current_layer_index.assign(c_layer_index)
            else:
                if step_counter.numpy() % log_interval == 0:
                    logging.info('Iteration: %d', step_counter.numpy())
                    train_utils.log_loss_acc(model, subset_val2, subset_test)
                    train_utils.log_sparsity(model)
            with tf.GradientTape() as tape:
                loss_train, _, _ = cross_entropy_loss(model, (x, y),
                                                      training=True)
            grads = tape.gradient(loss_train, model.variables)
            # Updating the model; this also increments `step_counter`.
            optimizer.apply_gradients(list(zip(grads, model.variables)))
            if step_counter.numpy() % log_interval == 0:
                tf.summary.scalar('loss_train', loss_train)
                tf.summary.image('x', x, max_outputs=1)
        # End of an epoch.
        c_epoch += 1
        current_epoch.assign(c_epoch)
        # Save every n OR after last epoch.
        if (c_epoch - 1) % checkpoint_interval == 0 or c_epoch > epochs:
            # Re-init checkpoint to ensure the masks are captured. The reason for
            # this is that the masks are initially not generated. It also makes
            # the checkpoint track the optimizer currently in use.
            checkpoint = tf.train.Checkpoint(
                optimizer=optimizer,
                model=model,
                current_epoch=current_epoch,
                current_layer_index=current_layer_index)
            logging.info('Checkpoint after epoch: %d', c_epoch - 1)
            checkpoint.save(
                os.path.join(FLAGS.outdir, 'ckpt-%d' % (c_epoch - 1)))

    # Test model
    test_loss, test_acc, n_samples = cross_entropy_loss(
        model, dataset_test, calculate_accuracy=True)
    tf.summary.scalar('test_loss_all', test_loss)
    tf.summary.scalar('test_acc_all', test_acc)
    logging.info(
        'Overall_test_loss: %.4f, Overall_test_acc: %.4f, n_samples: %d',
        test_loss, test_acc, n_samples)
Example #8
0
def train_model(dataset_name='cifar10',
                checkpoint_every_n_epoch=5,
                log_interval=1000,
                epochs=10,
                lr=1e-2,
                lr_drop_iter=1500,
                lr_decay=0.5,
                momentum=0.9,
                seed=8):
  """Trains the model with regular logging and checkpoints.

  Checkpoints are written to `FLAGS.outdir`; when one already exists there,
  training resumes from it.

  Args:
    dataset_name: str, one of 'cifar10', 'imagenet', 'cub200' or
      'imagenet_vgg'.
    checkpoint_every_n_epoch: int, number of epochs between two checkpoints.
    log_interval: int, number of steps between two logging events.
    epochs: int, epoch to train with.
    lr: float, initial learning rate.
    lr_drop_iter: int, iterations between two consecutive lr drops.
    lr_decay: float, multiplier for step learning rate reduction.
    momentum: float, momentum multiplier for the optimizer.
    seed: int, random seed to be set to produce reproducible experiments.

  Raises:
    AssertionError: when the args doesn't match the specs.
  """
  # Raised explicitly (rather than via `assert`) so that the check is not
  # stripped when Python runs with `-O`; the exception type is unchanged.
  if dataset_name not in ('cifar10', 'imagenet', 'cub200', 'imagenet_vgg'):
    raise AssertionError('Unknown dataset_name: %r' % (dataset_name,))
  tf.random.set_seed(seed)
  # The model is configured through gin parameters.
  model = model_load.get_model(dataset_name=dataset_name)

  # Uncomment the following if your model can be wrapped in `tf.function`.
  # model.call = tf.function(model.call)
  datasets = data.get_datasets(dataset_name=dataset_name)
  dataset_train, dataset_test, _, subset_test, subset_val = datasets
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      lr, decay_steps=lr_drop_iter, decay_rate=lr_decay, staircase=True)
  optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=momentum)
  # The optimizer's own iteration counter doubles as the global step.
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))

  current_epoch = tf.Variable(1)
  # Create checkpoint object TODO check whether you need ckpt-prefix.
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer,
      model=model,
      current_epoch=current_epoch)
  # `max_to_keep=None` keeps every checkpoint that is written.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, directory=FLAGS.outdir, max_to_keep=None)

  latest_cpkt = checkpoint_manager.latest_checkpoint
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with epoch: %d', current_epoch.numpy())
  c_epoch = current_epoch.numpy()
  while c_epoch <= epochs:
    logging.info('Starting Epoch:%d', c_epoch)
    for (x, y) in dataset_train:
      if step_counter.numpy() % log_interval == 0:
        train_utils.log_loss_acc(model, subset_val, subset_test)
        tf.summary.image('x', x, max_outputs=1)
        logging.info('Iteration:%d', step_counter.numpy())
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model; this also increments `step_counter`.
      optimizer.apply_gradients(zip(grads, model.variables))
      tf.summary.scalar('loss_train', loss_train)
      # Query the schedule directly instead of going through the optimizer
      # attribute, which is not guaranteed to be callable.
      tf.summary.scalar('lr', lr_schedule(step_counter))
    # End of an epoch.
    c_epoch += 1
    current_epoch.assign(c_epoch)
    # Save every n OR after last epoch.
    if (c_epoch - 1) % checkpoint_every_n_epoch == 0 or c_epoch > epochs:
      logging.info('Checkpoint after epoch: %d', c_epoch - 1)
      checkpoint_manager.save(checkpoint_number=c_epoch - 1)
  test_loss, test_acc, n_samples = cross_entropy_loss(
      model, dataset_test, calculate_accuracy=True)
  tf.summary.scalar('test_loss_all', test_loss)
  tf.summary.scalar('test_acc_all', test_acc)
  logging.info('Overall_test_loss:%.4f, Overall_test_acc:%.4f, n_samples:%d',
               test_loss, test_acc, n_samples)
Example #9
0
def prune_and_finetune_model(dataset_name='imagenet_vgg',
                             flop_regularizer=0,
                             n_units_target=4000,
                             checkpoint_interval=5,
                             log_interval=5,
                             n_finetune=100,
                             lr=1e-4,
                             momentum=0.9,
                             seed=8):
    """Prunes the model one unit at a time, fine-tuning between prunings.

    Every `n_finetune` optimizer steps a single unit is pruned from the pool
    of VGG conv layers until `n_units_target` units have been pruned.
    Checkpoints are written to `FLAGS.outdir`; when one already exists there,
    the run resumes from it.

    Args:
      dataset_name: str, dataset to train on.
      flop_regularizer: float, multiplier for the flop regularization. If 0, no
        regularization is made during pruning.
      n_units_target: int, number of units to prune.
      checkpoint_interval: int, number of pruning steps between two
        checkpoints.
      log_interval: int, number of steps between two logging events.
      n_finetune: int, number of steps between two pruning steps. Starting
        from the first iteration we prune 1 unit every `n_finetune` gradient
        update.
      lr: float, learning rate for the fine-tuning steps.
      momentum: float, momentum multiplier for the fine-tuning steps.
      seed: int, random seed to be set to produce reproducible experiments.

    Raises:
      ValueError: when the n_finetune is not positive.
    """
    if n_finetune <= 0:
        raise ValueError('n_finetune must be positive: given %d' % n_finetune)
    tf.random.set_seed(seed)
    optimizer = tf.keras.optimizers.SGD(lr, momentum=momentum)
    # imagenet_vgg->imagenet
    dataset_basename = dataset_name.split('_')[0]
    model = model_load.get_model(dataset_name=dataset_basename)
    # Uncomment following if your model is (defunable).
    # model.call = tf.function(model.call)
    datasets = data.get_datasets(dataset_name=dataset_name)
    (dataset_train, _, subset_val, subset_test, subset_val2) = datasets
    logging.info('Model init-config: %s', model.init_config)
    logging.info('Model forward chain: %s', str(model.forward_chain))

    unit_pruner = pruner.UnitPruner(model, subset_val)
    # We prune all conv layers.
    pruning_pool = _all_vgg_conv_layers
    # Per-layer baseline: the layer's flop cost scaled by the regularizer.
    baselines = {
        l_name: c_flop * flop_regularizer
        for l_name, c_flop in zip(pruning_pool, _vgg16_flop_regularizition)
    }

    # The optimizer's own iteration counter doubles as the global step.
    step_counter = optimizer.iterations
    tf.summary.experimental.set_step(step_counter)
    c_pruning_step = tf.Variable(1)
    # Create checkpoint object TODO check whether you need ckpt-prefix.
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     c_pruning_step=c_pruning_step)
    # `max_to_keep=None` keeps every checkpoint that is written.
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=FLAGS.outdir,
                                                    max_to_keep=None)

    latest_cpkt = checkpoint_manager.latest_checkpoint
    if latest_cpkt:
        logging.info('Using latest checkpoint at %s', latest_cpkt)
        # Restore variables on creation if a checkpoint exists.
        checkpoint.restore(latest_cpkt)
        logging.info('Resuming with pruning step: %d', c_pruning_step.numpy())
    pruning_step = c_pruning_step.numpy()
    while pruning_step <= n_units_target:
        for (x, y) in dataset_train:
            # Every `n_finetune` step perform pruning.
            if step_counter.numpy() % n_finetune == 0:
                if pruning_step > n_units_target:
                    # This holds true when we prune last time and fine tune N many
                    # iterations. We would break and the while loop above would break,
                    # too.
                    break
                logging.info('Pruning Step:%d', pruning_step)
                start = time.time()
                unit_pruner.prune_one_unit(pruning_pool, baselines=baselines)
                end = time.time()
                logging.info(
                    '\nTrain time for Pruning Step #%d (step %d): %f',
                    pruning_step, step_counter.numpy(), end - start)
                pruning_step += 1
                c_pruning_step.assign(pruning_step)
                if (pruning_step - 1) % checkpoint_interval == 0:
                    checkpoint_manager.save()
            if step_counter.numpy() % log_interval == 0:
                train_utils.log_loss_acc(model, subset_val2, subset_test)
                train_utils.log_sparsity(model)
            with tf.GradientTape() as tape:
                loss_train, _, _ = cross_entropy_loss(model, (x, y),
                                                      training=True)
            grads = tape.gradient(loss_train, model.variables)
            # Updating the model; this also increments `step_counter`.
            optimizer.apply_gradients(list(zip(grads, model.variables)))
            if step_counter.numpy() % log_interval == 0:
                tf.summary.scalar('loss_train', loss_train)
                tf.summary.image('x', x, max_outputs=1)