  def testDatasetName(self):
    data.get_datasets(dataset_name='cifar10')
    data.get_datasets(dataset_name='imagenet')
    with self.assertRaises(AssertionError):
      data.get_datasets(dataset_name='cifar')
    with self.assertRaises(AssertionError):
      data.get_datasets(dataset_name='shvc')

  def testCUB200(self):
    bs = 4
    val_size = 15
    eval_size = 12
    datasets = data.get_datasets('cub200', eval_size=eval_size,
                                 val_size=val_size, num_parallel_calls=1,
                                 shuffle_size=1, batch_size=bs)
    (dataset_train, _, _, _, _) = datasets
    x, y = next(iter(dataset_train))
    self.assertEqual(x.shape, [bs, 224, 224, 3])
    self.assertEqual(y.shape, [bs])

  def testDefaultArgs(self):
    bs = 32
    val_size = eval_size = 1000
    datasets = data.get_datasets(shuffle_size=1)
    (dataset_train, dataset_test, subset_val, subset_test,
     subset_val2) = datasets
    x, y = next(iter(dataset_train))
    self.assertEqual(x.shape, [bs, 32, 32, 3])
    self.assertEqual(y.shape, [bs])
    self.assertLessEqual(tf.reduce_max(x), 1.0)
    self.assertGreaterEqual(tf.reduce_min(x), -1.0)
    x, y = next(iter(dataset_test))
    self.assertLessEqual(tf.reduce_max(x), 1.0)
    self.assertGreaterEqual(tf.reduce_min(x), -1.0)
    self.assertEqual(x.shape, [bs, 32, 32, 3])
    self.assertEqual(y.shape, [bs])
    c_iterator = iter(subset_val)
    x, y = next(c_iterator)
    self.assertLessEqual(tf.reduce_max(x), 1.0)
    self.assertGreaterEqual(tf.reduce_min(x), -1.0)
    self.assertEqual(x.shape, [val_size, 32, 32, 3])
    self.assertEqual(y.shape, [val_size])
    # Since chunk_size=None, it should only have one batch.
    with self.assertRaises(StopIteration):
      next(c_iterator)
    c_iterator = iter(subset_val2)
    x2, y2 = next(c_iterator)
    self.assertLessEqual(tf.reduce_max(x2), 1.0)
    self.assertGreaterEqual(tf.reduce_min(x2), -1.0)
    self.assertEqual(x2.shape, [eval_size, 32, 32, 3])
    self.assertEqual(y2.shape, [eval_size])
    # Check that the subsets are disjoint.
    self.assertNotAllClose(x, x2)
    # Since chunk_size=None, it should only have one batch.
    with self.assertRaises(StopIteration):
      next(c_iterator)
    c_iterator = iter(subset_test)
    x, y = next(c_iterator)
    self.assertLessEqual(tf.reduce_max(x), 1.0)
    self.assertGreaterEqual(tf.reduce_min(x), -1.0)
    self.assertEqual(x.shape, [eval_size, 32, 32, 3])
    self.assertEqual(y.shape, [eval_size])
    with self.assertRaises(StopIteration):
      next(c_iterator)

  def testVgg(self, mock_pp_i):
    mock_pp_i.side_effect = lambda a: a
    bs = 4
    val_size = 15
    eval_size = 12
    datasets = data.get_datasets('imagenet_vgg', eval_size=eval_size,
                                 val_size=val_size, num_parallel_calls=1,
                                 shuffle_size=1, batch_size=bs)
    (dataset_train, _, _, _, _) = datasets
    x, y = next(iter(dataset_train))
    self.assertEqual(x.shape, [bs, 224, 224, 3])
    self.assertEqual(y.shape, [bs])
    # Due to data augmentation the max can be slightly bigger than 1.0.
    self.assertLessEqual(tf.reduce_max(x), 1.5)
    self.assertGreaterEqual(tf.reduce_min(x), -0.5)
    self.assertTrue(mock_pp_i.called)
    self.assertEqual(mock_pp_i.call_count, bs)

  def testCustomImagenetArgs(self):
    bs = 4
    val_size = 15
    eval_size = 12
    chunk_size = 5
    datasets = data.get_datasets('imagenet', eval_size=eval_size,
                                 val_size=val_size, batch_size=bs,
                                 shuffle_size=1, chunk_size=chunk_size)
    (dataset_train, dataset_test, subset_val, subset_test, _) = datasets
    x, y = next(iter(dataset_train))
    self.assertEqual(x.shape, [bs, 224, 224, 3])
    self.assertEqual(y.shape, [bs])
    x, y = next(iter(dataset_test))
    self.assertEqual(x.shape, [bs, 224, 224, 3])
    self.assertEqual(y.shape, [bs])
    c_iterator = iter(subset_val)
    x, y = next(c_iterator)
    self.assertEqual(x.shape, [chunk_size, 224, 224, 3])
    self.assertEqual(y.shape, [chunk_size])
    # Let us consume all remaining batches.
    for _ in range((val_size - 1) // chunk_size):
      x, y = next(c_iterator)
    # Since we iterated over all batches, we should get an exception.
    with self.assertRaises(StopIteration):
      next(c_iterator)
    c_iterator = iter(subset_test)
    x, y = next(c_iterator)
    self.assertEqual(x.shape, [chunk_size, 224, 224, 3])
    self.assertEqual(y.shape, [chunk_size])
    for _ in range((eval_size - 1) // chunk_size):
      x, y = next(c_iterator)
    last_batch_shape = eval_size % chunk_size
    self.assertEqual(x.shape, [last_batch_shape, 224, 224, 3])
    self.assertEqual(y.shape, [last_batch_shape])
    with self.assertRaises(StopIteration):
      next(c_iterator)

def prune_layer_and_eval(dataset_name='cifar10',
                         pruning_methods=(('rand', True),),
                         model_dir=gin.REQUIRED,
                         l_name=gin.REQUIRED,
                         n_exps=gin.REQUIRED,
                         val_size=gin.REQUIRED,
                         max_pruning_fraction=gin.REQUIRED):
  """Loads and prunes a model with various sparsity targets.

  This function assumes that `model_dir` exists and contains a checkpoint.

  Args:
    dataset_name: str, either 'cifar10' or 'imagenet'.
    pruning_methods: iterable of (scoring, is_bp) tuples. `scoring` is one of
      ['norm', 'mrs', 'rs', 'rand', 'abs_mrs', 'abs_rs']:
        'norm': unitscorers.norm_score.
        '{abs_}mrs': `compute_mean_replacement_saliency=True` for the
          `TaylorScorer` layer. With the `abs_` prefix, absolute values are
          taken before aggregation, i.e. `is_abs=True`.
        '{abs_}rs': `compute_removal_saliency=True` for the `TaylorScorer`
          layer. With the `abs_` prefix, absolute values are taken before
          aggregation, i.e. `is_abs=True`.
        'rand': unitscorers.random_score.
      `is_bp` is a bool; if True, Mean Replacement Pruning is used: the mean
      value of each unit is propagated to the next layer prior to pruning
      (`bp` stands for bias propagation).
    model_dir: str, path to the checkpoint directory.
    l_name: str, a valid layer name from the loaded model.
    n_exps: int, number of pruning experiments to be made. This number is
      used to generate the pruning counts for the different experiments.
    val_size: int, size of the first subset, passed to `get_datasets`.
    max_pruning_fraction: float, maximum sparsity for pruning. Multiplying
      this number with the total number of units gives the upper limit for
      `pruning_count`.

  Raises:
    ValueError: when the scoring function key is not valid.
    OSError: when there is no checkpoint found.
  """
  logging.info('Looking for checkpoint at: %s', model_dir)
  latest_cpkt = tf.train.latest_checkpoint(model_dir)
  if not latest_cpkt:
    raise OSError('No checkpoint found in %s' % model_dir)
  logging.info('Using latest checkpoint at %s', latest_cpkt)
  model = model_load.get_model(load_path=latest_cpkt,
                               dataset_name=dataset_name)
  datasets = data.get_datasets(dataset_name=dataset_name, val_size=val_size)
  _, _, subset_val, subset_test, subset_val2 = datasets
  input_shapes = {l_name: getattr(model, l_name + '_ts').xshape}
  layers2prune = [l_name]
  measurements = pruner.get_pruning_measurements(model, subset_val,
                                                 layers2prune)
  (all_scores, mean_values, _) = measurements
  for scoring, is_bp in pruning_methods:
    if scoring not in pruner.ALL_SCORING_FUNCTIONS:
      raise ValueError('%s is not one of %s' %
                       (scoring, pruner.ALL_SCORING_FUNCTIONS))
    scores = all_scores[scoring]
    d_path = os.path.join(
        FLAGS.outdir,
        '%d-%s-%s-%s.pickle' % (val_size, l_name, scoring, str(is_bp)))
    logging.info(d_path)
    if tf.io.gfile.exists(d_path):
      logging.warning('File %s exists, skipping.', d_path)
    else:
      ls_train_loss = []
      ls_train_acc = []
      ls_test_loss = []
      ls_test_acc = []
      n_units = input_shapes[l_name][-1].value
      n_unit_pruned_max = n_units * max_pruning_fraction
      c_slice = np.linspace(0, n_unit_pruned_max, n_exps, dtype=np.int32)
      logging.info('Layer:%s, n_units:%d, c_slice:%s', l_name, n_units,
                   str(c_slice))
      for pruning_count in c_slice:
        # Cast from np.int32 to int.
        pruning_count = int(pruning_count)
        copied_model = model.clone()
        pruning_factor = None
        pruner.prune_model_with_scores(copied_model, scores, is_bp,
                                       layers2prune, pruning_factor,
                                       pruning_count, mean_values,
                                       input_shapes)
        test_loss, test_acc, _ = train_utils.cross_entropy_loss(
            copied_model, subset_test, calculate_accuracy=True)
        train_loss, train_acc, _ = train_utils.cross_entropy_loss(
            copied_model, subset_val2, calculate_accuracy=True)
        logging.info('is_bp: %s, n: %d, test_loss: %f, train_loss: %f',
                     str(is_bp), pruning_count, test_loss, train_loss)
        ls_train_loss.append(train_loss.numpy())
        ls_test_loss.append(test_loss.numpy())
        ls_test_acc.append(test_acc.numpy())
        ls_train_acc.append(train_acc.numpy())
      utils.pickle_object(
          (ls_train_loss, ls_train_acc, ls_test_loss, ls_test_acc), d_path)

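# A minimal usage sketch (not part of the original module); the checkpoint
# directory and layer name below are hypothetical placeholders. In practice
# these arguments are expected to be supplied through gin bindings, since
# they default to `gin.REQUIRED`.
#
#   prune_layer_and_eval(
#       dataset_name='cifar10',
#       pruning_methods=(('abs_mrs', True), ('rand', False)),
#       model_dir='/tmp/ckpts',
#       l_name='conv_1',
#       n_exps=10,
#       val_size=1000,
#       max_pruning_fraction=0.5)
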
def prune_and_finetune_model(pruning_schedule=gin.REQUIRED,
                             model_dir=gin.REQUIRED,
                             dataset_name='cifar10',
                             checkpoint_interval=5,
                             log_interval=5,
                             n_finetune=100,
                             epochs=20,
                             lr=1e-4,
                             momentum=0.9):
  """Loads and prunes the model layer by layer according to the schedule.

  The model is finetuned between pruning tasks (as given in the schedule).

  Args:
    pruning_schedule: list of (str, float) tuples, where the str is a valid
      layer name and the float is the pruning fraction for that layer. Layers
      are pruned in the order they are given, with `n_finetune` steps taken
      in between.
    model_dir: str, path to the checkpoint directory.
    dataset_name: str, either 'cifar10' or 'imagenet'.
    checkpoint_interval: int, number of epochs between two checkpoints.
    log_interval: int, number of steps between two logging events.
    n_finetune: int, number of steps between two pruning steps.
    epochs: int, total number of epochs to run.
    lr: float, learning rate for the fine-tuning steps.
    momentum: float, momentum multiplier for the fine-tuning steps.

  Raises:
    ValueError: when n_finetune is not positive.
    OSError: when there is no checkpoint found.
  """
  if n_finetune <= 0:
    raise ValueError('n_finetune must be positive: given %d' % n_finetune)
  optimizer = tf.keras.optimizers.SGD(learning_rate=lr, momentum=momentum)
  logging.info('Using outdir: %s', model_dir)
  latest_cpkt = tf.train.latest_checkpoint(model_dir)
  if not latest_cpkt:
    raise OSError('No checkpoint found in %s' % model_dir)
  logging.info('Using latest checkpoint at %s', latest_cpkt)
  model = model_load.get_model(load_path=latest_cpkt,
                               dataset_name=dataset_name)
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))
  datasets = data.get_datasets(dataset_name=dataset_name)
  dataset_train, dataset_test, subset_val, subset_test, subset_val2 = datasets
  unit_pruner = pruner.UnitPruner(model, subset_val)
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  current_epoch = tf.Variable(1)
  current_layer_index = tf.Variable(0)
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model,
                                   current_epoch=current_epoch,
                                   current_layer_index=current_layer_index)
  latest_cpkt = tf.train.latest_checkpoint(FLAGS.outdir)
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with epoch: %d', current_epoch.numpy())
  c_epoch = current_epoch.numpy()
  c_layer_index = current_layer_index.numpy()
  # Starting from the first batch, we perform pruning every `n_finetune` steps.
  # Layers are pruned one by one according to the given pruning schedule.
  while c_epoch <= epochs:
    logging.info('Starting Epoch: %d', c_epoch)
    for (x, y) in dataset_train:
      # Every `n_finetune` steps perform pruning.
      if (step_counter.numpy() % n_finetune == 0 and
          c_layer_index < len(pruning_schedule)):
        logging.info('Pruning at iteration: %d', step_counter.numpy())
        l_name, pruning_factor = pruning_schedule[c_layer_index]
        unit_pruner.prune_layer(l_name, pruning_factor=pruning_factor)
        train_utils.log_loss_acc(model, subset_val2, subset_test)
        train_utils.log_sparsity(model)
        # Re-init the optimizer and therefore remove the previous momentum,
        # carrying the iteration count over so that the pruning/logging
        # cadence is preserved.
        optimizer = tf.keras.optimizers.SGD(learning_rate=lr,
                                            momentum=momentum)
        optimizer.iterations.assign(step_counter)
        step_counter = optimizer.iterations
        tf.summary.experimental.set_step(step_counter)
        c_layer_index += 1
        current_layer_index.assign(c_layer_index)
      else:
        if step_counter.numpy() % log_interval == 0:
          logging.info('Iteration: %d', step_counter.numpy())
          train_utils.log_loss_acc(model, subset_val2, subset_test)
          train_utils.log_sparsity(model)
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model.
      optimizer.apply_gradients(zip(grads, model.variables))
      if step_counter.numpy() % log_interval == 0:
        tf.summary.scalar('loss_train', loss_train)
        tf.summary.image('x', x, max_outputs=1)
    # End of an epoch.
    c_epoch += 1
    current_epoch.assign(c_epoch)
    # Save every `checkpoint_interval` epochs OR after the last epoch.
    if (c_epoch - 1) % checkpoint_interval == 0 or c_epoch > epochs:
      # Re-init the checkpoint to ensure the masks are captured. The reason
      # for this is that the masks are not generated initially.
      checkpoint = tf.train.Checkpoint(
          optimizer=optimizer, model=model, current_epoch=current_epoch,
          current_layer_index=current_layer_index)
      logging.info('Checkpoint after epoch: %d', c_epoch - 1)
      checkpoint.save(os.path.join(FLAGS.outdir, 'ckpt-%d' % (c_epoch - 1)))
  # Test the model.
  test_loss, test_acc, n_samples = cross_entropy_loss(
      model, dataset_test, calculate_accuracy=True)
  tf.summary.scalar('test_loss_all', test_loss)
  tf.summary.scalar('test_acc_all', test_acc)
  logging.info(
      'Overall_test_loss: %.4f, Overall_test_acc: %.4f, n_samples: %d',
      test_loss, test_acc, n_samples)

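# A hedged example of the `pruning_schedule` format expected by
# `prune_and_finetune_model` above; the layer names and fractions are
# hypothetical placeholders, not values from the original experiments.
# Layers are pruned in the given order, with `n_finetune` fine-tuning steps
# in between.
_EXAMPLE_PRUNING_SCHEDULE = [
    ('conv_1', 0.3),    # First prune 30% of the units of `conv_1`.
    ('conv_2', 0.5),    # Then 50% of `conv_2`.
    ('dense_1', 0.25),  # Finally 25% of `dense_1`.
]
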
def train_model(dataset_name='cifar10',
                checkpoint_every_n_epoch=5,
                log_interval=1000,
                epochs=10,
                lr=1e-2,
                lr_drop_iter=1500,
                lr_decay=0.5,
                momentum=0.9,
                seed=8):
  """Trains the model with regular logging and checkpoints.

  Args:
    dataset_name: str, one of 'cifar10', 'imagenet', 'cub200' or
      'imagenet_vgg'.
    checkpoint_every_n_epoch: int, number of epochs between two checkpoints.
    log_interval: int, number of steps between two logging events.
    epochs: int, number of epochs to train for.
    lr: float, learning rate.
    lr_drop_iter: int, number of iterations between two consecutive lr drops.
    lr_decay: float, multiplier for the step learning rate reduction.
    momentum: float, momentum multiplier.
    seed: int, random seed, set to make experiments reproducible.

  Raises:
    AssertionError: when the arguments do not match the specs.
  """
  assert dataset_name in ['cifar10', 'imagenet', 'cub200', 'imagenet_vgg']
  tf.random.set_seed(seed)
  # The model is configured through gin parameters.
  model = model_load.get_model(dataset_name=dataset_name)
  # model.call = tf.function(model.call)
  datasets = data.get_datasets(dataset_name=dataset_name)
  dataset_train, dataset_test, _, subset_test, subset_val = datasets
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      lr, decay_steps=lr_drop_iter, decay_rate=lr_decay, staircase=True)
  optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=momentum)
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))
  current_epoch = tf.Variable(1)
  # TODO: check whether the ckpt prefix is needed.
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer, model=model, current_epoch=current_epoch)
  # Keep all checkpoints (no limit).
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, directory=FLAGS.outdir, max_to_keep=None)
  latest_cpkt = checkpoint_manager.latest_checkpoint
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with epoch: %d', current_epoch.numpy())
  c_epoch = current_epoch.numpy()
  while c_epoch <= epochs:
    logging.info('Starting Epoch: %d', c_epoch)
    for (x, y) in dataset_train:
      if step_counter.numpy() % log_interval == 0:
        train_utils.log_loss_acc(model, subset_val, subset_test)
        tf.summary.image('x', x, max_outputs=1)
        logging.info('Iteration: %d', step_counter.numpy())
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model.
      optimizer.apply_gradients(zip(grads, model.variables))
      tf.summary.scalar('loss_train', loss_train)
      tf.summary.scalar('lr', lr_schedule(step_counter))
    # End of an epoch.
    c_epoch += 1
    current_epoch.assign(c_epoch)
    # Save every `checkpoint_every_n_epoch` epochs OR after the last epoch.
    if (c_epoch - 1) % checkpoint_every_n_epoch == 0 or c_epoch > epochs:
      logging.info('Checkpoint after epoch: %d', c_epoch - 1)
      checkpoint_manager.save(checkpoint_number=c_epoch - 1)
  test_loss, test_acc, n_samples = cross_entropy_loss(
      model, dataset_test, calculate_accuracy=True)
  tf.summary.scalar('test_loss_all', test_loss)
  tf.summary.scalar('test_acc_all', test_acc)
  logging.info(
      'Overall_test_loss: %.4f, Overall_test_acc: %.4f, n_samples: %d',
      test_loss, test_acc, n_samples)

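# A small sanity-check sketch (not in the original file) of the staircase
# schedule configured in `train_model`: it evaluates
# tf.keras.optimizers.schedules.ExponentialDecay at a given step using the
# same default arguments, i.e. lr * lr_decay ** (step // lr_drop_iter).
def _example_staircase_lr(step, lr=1e-2, lr_drop_iter=1500, lr_decay=0.5):
  """Returns the staircase-decayed learning rate at `step` (illustrative)."""
  schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      lr, decay_steps=lr_drop_iter, decay_rate=lr_decay, staircase=True)
  return float(schedule(step))
# e.g. _example_staircase_lr(0) == 0.01 and _example_staircase_lr(1500) == 0.005.
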
def prune_and_finetune_model(dataset_name='imagenet_vgg',
                             flop_regularizer=0,
                             n_units_target=4000,
                             checkpoint_interval=5,
                             log_interval=5,
                             n_finetune=100,
                             lr=1e-4,
                             momentum=0.9,
                             seed=8):
  """Prunes and fine-tunes the model with regular logging and checkpoints.

  Args:
    dataset_name: str, dataset to train on.
    flop_regularizer: float, multiplier for the flop regularization. If 0, no
      regularization is applied during pruning.
    n_units_target: int, number of units to prune.
    checkpoint_interval: int, number of pruning steps between two checkpoints.
    log_interval: int, number of steps between two logging events.
    n_finetune: int, number of steps between two pruning steps. Starting from
      the first iteration, we prune 1 unit every `n_finetune` gradient
      updates.
    lr: float, learning rate for the fine-tuning steps.
    momentum: float, momentum multiplier for the fine-tuning steps.
    seed: int, random seed, set to make experiments reproducible.

  Raises:
    ValueError: when n_finetune is not positive.
  """
  if n_finetune <= 0:
    raise ValueError('n_finetune must be positive: given %d' % n_finetune)
  tf.random.set_seed(seed)
  optimizer = tf.keras.optimizers.SGD(lr, momentum=momentum)
  # imagenet_vgg -> imagenet.
  dataset_basename = dataset_name.split('_')[0]
  model = model_load.get_model(dataset_name=dataset_basename)
  # Uncomment the following if your model is defunable.
  # model.call = tf.function(model.call)
  datasets = data.get_datasets(dataset_name=dataset_name)
  (dataset_train, _, subset_val, subset_test, subset_val2) = datasets
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))
  unit_pruner = pruner.UnitPruner(model, subset_val)
  # We prune all conv layers.
  pruning_pool = _all_vgg_conv_layers
  baselines = {
      l_name: c_flop * flop_regularizer
      for l_name, c_flop in zip(pruning_pool, _vgg16_flop_regularizition)
  }
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  c_pruning_step = tf.Variable(1)
  # TODO: check whether the ckpt prefix is needed.
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model,
                                   c_pruning_step=c_pruning_step)
  # Keep all checkpoints (no limit).
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, directory=FLAGS.outdir, max_to_keep=None)
  latest_cpkt = checkpoint_manager.latest_checkpoint
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with pruning step: %d', c_pruning_step.numpy())
  pruning_step = c_pruning_step.numpy()
  while pruning_step <= n_units_target:
    for (x, y) in dataset_train:
      # Every `n_finetune` steps perform pruning.
      if step_counter.numpy() % n_finetune == 0:
        if pruning_step > n_units_target:
          # This holds once we have pruned for the last time and fine-tuned
          # for N more iterations. We break here, and the while loop above
          # terminates, too.
          break
        logging.info('Pruning Step: %d', pruning_step)
        start = time.time()
        unit_pruner.prune_one_unit(pruning_pool, baselines=baselines)
        end = time.time()
        logging.info('\nTrain time for Pruning Step #%d (step %d): %f',
                     pruning_step, step_counter.numpy(), end - start)
        pruning_step += 1
        c_pruning_step.assign(pruning_step)
        if (pruning_step - 1) % checkpoint_interval == 0:
          checkpoint_manager.save()
      if step_counter.numpy() % log_interval == 0:
        train_utils.log_loss_acc(model, subset_val2, subset_test)
        train_utils.log_sparsity(model)
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model.
      optimizer.apply_gradients(zip(grads, model.variables))
      if step_counter.numpy() % log_interval == 0:
        tf.summary.scalar('loss_train', loss_train)
        tf.summary.image('x', x, max_outputs=1)

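# A hedged illustration (toy numbers, not the real VGG-16 flop counts or layer
# names) of how the per-layer `baselines` regularization terms used above are
# formed from `flop_regularizer`:
#
#   pruning_pool = ['conv_1', 'conv_2']   # hypothetical layer names
#   flops = [1.8e9, 0.9e9]                # illustrative per-layer flop counts
#   flop_regularizer = 1e-9
#   baselines = {l: f * flop_regularizer for l, f in zip(pruning_pool, flops)}
#   # -> {'conv_1': 1.8, 'conv_2': 0.9}
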