def testDefaultBatch(self):
  # Model returns logits for `n_sample` samples and `n_out` classes.
  n_sample, chunk_size, n_out = 42, 10, 10
  model = self._create_mock_model(n_out=n_out)
  x_all = tf.ones((n_sample, 2))
  y_all = tf.ones((n_sample,), dtype=tf.int32)
  d = tf.data.Dataset.from_tensor_slices((x_all, y_all))
  loss, acc, total_samples = train_utils.cross_entropy_loss(
      model, d.batch(chunk_size))
  self.assertEqual(total_samples, n_sample)
  self.assertIsNone(acc)
  all_logits = []
  for i in range(0, n_sample, chunk_size):
    c_size = min(n_sample, i + chunk_size) - i
    all_logits.append(self.get_logits(c_size, n_out))
  logits = tf.concat(all_logits, 0)
  cce = tf.keras.losses.SparseCategoricalCrossentropy()
  true_loss = cce(y_all, logits)
  self.assertAllClose(loss, true_loss, atol=1e-4)
  # 42 samples in chunks of 10 -> ceil(42 / 10) = 5 forward calls.
  self.assertEqual(model.call_count, 5)
  model.get_layer_keys.assert_not_called()
  train_utils.cross_entropy_loss(model, d.batch(chunk_size),
                                 aggregate_values=True)
  model.get_layer_keys.assert_called_once()
  model.conv_1.reset_saved_values.assert_called_once()
  model.conv_2.reset_saved_values.assert_called_once()
def get_pruning_measurements(model, subset_val, layers2prune):
  """Returns 6 different pruning scores and some other measurements.

  Args:
    model: tf.keras.Model, model to be pruned.
    subset_val: tf.data.Dataset, to calculate data-dependent scoring
      functions.
    layers2prune: list<str>, of layers which will be scored.

  Returns:
    dict, scores keyed by the 6 scoring functions `mrs`, `abs_mrs`, `rs`,
      `abs_rs`, `norm` and `rand`. Each value is itself a dict mapping every
      layer in `layers2prune` to its per-unit scores.
    dict, mean unit activations for each layer in `layers2prune`.
    dict, l2_norms, average squared activations for each unit. See
      `deadunits.layers.TaylorScorer` for further details.
  """
  # Run once to get mrs, rs and the mean activations calculated.
  train_utils.cross_entropy_loss(model, subset_val, training=False,
                                 compute_mean_replacement_saliency=True,
                                 compute_removal_saliency=True,
                                 is_abs=True,
                                 aggregate_values=True,
                                 run_gradient=True)
  scores = collections.defaultdict(dict)
  mean_values = {}
  l2_norms = {}
  for l_name in layers2prune:
    l_ts = getattr(model, l_name + '_ts')
    scores['abs_mrs'][l_name] = l_ts.get_saved_values('mrs')
    scores['abs_rs'][l_name] = l_ts.get_saved_values('rs')
    mean_values[l_name] = l_ts.get_saved_values('mean')
    l2_norms[l_name] = l_ts.get_saved_values('l2norm')
  # Run again to get the scores without the absolute value.
  train_utils.cross_entropy_loss(model, subset_val, training=False,
                                 compute_mean_replacement_saliency=True,
                                 compute_removal_saliency=True,
                                 is_abs=False,
                                 aggregate_values=True,
                                 run_gradient=True)
  for l_name in layers2prune:
    l_ts = getattr(model, l_name + '_ts')
    l = getattr(model, l_name)
    scores['mrs'][l_name] = l_ts.get_saved_values('mrs')
    scores['rs'][l_name] = l_ts.get_saved_values('rs')
    # The reason we calculate the data-independent scores here, too, is to
    # reduce the amount of code. `weights[0]` is the kernel, whereas
    # `weights[1]` is the bias.
    scores['norm'][l_name] = unitscorers.norm_score(l.get_layer().weights[0])
    scores['rand'][l_name] = unitscorers.random_score(
        l.get_layer().weights[0])
  return (scores, mean_values, l2_norms)
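# A minimal usage sketch for `get_pruning_measurements` (hedged: the layer
# names `conv_1`/`conv_2` and the batched validation set are illustrative;
# any model whose layers carry `<name>_ts` TaylorScorer twins should behave
# the same way):
#
#   scores, mean_values, l2_norms = get_pruning_measurements(
#       model, subset_val, ['conv_1', 'conv_2'])
#   # Per-unit weight-norm scores for `conv_1`:
#   norm_scores = scores['norm']['conv_1']
#   # Mean-replacement saliency, aggregated with absolute values:
#   abs_mrs_scores = scores['abs_mrs']['conv_1']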
def testSingleBatch(self):
  # Model returns logits for `n_sample` samples and `n_out` classes.
  n_sample, n_out = 8, 10
  model = self._create_mock_model(n_out=n_out)
  x = tf.ones((n_sample, 2))
  y = tf.ones((n_sample,), dtype=tf.int32)
  loss, acc, total_samples = train_utils.cross_entropy_loss(
      model, (x, y), calculate_accuracy=True)
  model2 = self._create_mock_model(n_out=n_out)
  d = tf.data.Dataset.from_tensor_slices((x, y))
  loss2, acc2, total_samples2 = train_utils.cross_entropy_loss(
      model2, d.batch(n_sample), calculate_accuracy=True)
  self.assertEqual(total_samples2, total_samples)
  self.assertAllClose(acc, acc2)
  self.assertAllClose(loss, loss2)
def testDefaultAccuracy(self):
  # Model returns logits for `n_sample` samples and `n_out` classes.
  n_sample, n_out = 8, 10
  model = self._create_mock_model(n_out=n_out)
  x = tf.ones((n_sample, 2))
  y = tf.ones((n_sample,), dtype=tf.int32)
  _, acc, total_samples = train_utils.cross_entropy_loss(
      model, (x, y), calculate_accuracy=True)
  self.assertEqual(total_samples, n_sample)
  logits = self.get_logits(n_sample, n_out)
  predictions = tf.cast(tf.argmax(logits, 1), y.dtype)
  acc_obj = tf.keras.metrics.Accuracy()
  acc_obj.update_state(tf.squeeze(y), predictions)
  true_acc = acc_obj.result().numpy()
  self.assertAllClose(acc, true_acc)
def testDefaultSingleBatch(self):
  # Model returns logits for `n_sample` samples and `n_out` classes.
  n_sample, n_out = 8, 10
  model = self._create_mock_model(n_out=n_out)
  x = tf.ones((n_sample, 2))
  y = tf.ones((n_sample,), dtype=tf.int32)
  loss, acc, total_samples = train_utils.cross_entropy_loss(model, (x, y))
  self.assertEqual(total_samples, n_sample)
  self.assertIsNone(acc)
  logits = self.get_logits(n_sample, n_out)
  cce = tf.keras.losses.SparseCategoricalCrossentropy()
  true_loss = cce(y, logits)
  self.assertAllClose(loss, true_loss)
  model.assert_called_once_with(x, training=False,
                                compute_mean_replacement_saliency=False,
                                compute_removal_saliency=False,
                                is_abs=True,
                                aggregate_values=False)
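# The tests above pin down the `train_utils.cross_entropy_loss` contract (a
# summary of what the mocks encode, not an implementation): the function
# returns a `(loss, accuracy, total_samples)` triple, `accuracy` is None
# unless `calculate_accuracy=True`, and the model is invoked once per batch
# as
#
#   model(x, training=False, compute_mean_replacement_saliency=False,
#         compute_removal_saliency=False, is_abs=True, aggregate_values=False)
#
# With `aggregate_values=True` it additionally queries
# `model.get_layer_keys(...)` and resets the saved values of the
# TaylorScorer layers.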
def prune_layer_and_eval(dataset_name='cifar10',
                         pruning_methods=(('rand', True),),
                         model_dir=gin.REQUIRED,
                         l_name=gin.REQUIRED,
                         n_exps=gin.REQUIRED,
                         val_size=gin.REQUIRED,
                         max_pruning_fraction=gin.REQUIRED):
  """Loads and prunes a model with various sparsity targets.

  This function assumes that `model_dir` exists and holds a checkpoint.

  Args:
    dataset_name: str, either 'cifar10' or 'imagenet'.
    pruning_methods: iterator of tuples, (scoring, is_bp) where `is_bp` is a
      boolean indicating the usage of Mean Replacement Pruning and scoring is
      a string from ['norm', 'mrs', 'rs', 'rand', 'abs_mrs', 'abs_rs'].
      'norm': unitscorers.norm_score.
      '{abs_}mrs': `compute_mean_replacement_saliency=True` for the
        `TaylorScorer` layer. If the {abs_} prefix exists, the absolute value
        is used before aggregation, i.e. is_abs=True.
      '{abs_}rs': `compute_removal_saliency=True` for the `TaylorScorer`
        layer. If the {abs_} prefix exists, the absolute value is used before
        aggregation, i.e. is_abs=True.
      'rand': unitscorers.random_score.
      is_bp: bool, if True, the mean values of the pruned units are
        propagated to the next layer prior to pruning. `bp` stands for bias
        propagation.
    model_dir: str, path to the checkpoint directory.
    l_name: str, a valid layer name from the model loaded.
    n_exps: int, number of pruning experiments to be made. This number is
      used to generate the pruning counts for the different experiments.
    val_size: int, size of the first dataset, passed to `get_datasets`.
    max_pruning_fraction: float, maximum sparsity for pruning. Multiplying
      this number with the total number of units gives the upper limit for
      the pruning count.

  Raises:
    ValueError: when the scoring function key is not valid.
    OSError: when there is no checkpoint found.
  """
  logging.info('Looking for checkpoint at: %s', model_dir)
  latest_cpkt = tf.train.latest_checkpoint(model_dir)
  if not latest_cpkt:
    raise OSError('No checkpoint found in %s' % model_dir)
  logging.info('Using latest checkpoint at %s', latest_cpkt)
  model = model_load.get_model(load_path=latest_cpkt,
                               dataset_name=dataset_name)
  datasets = data.get_datasets(dataset_name=dataset_name, val_size=val_size)
  _, _, subset_val, subset_test, subset_val2 = datasets
  input_shapes = {l_name: getattr(model, l_name + '_ts').xshape}
  layers2prune = [l_name]
  measurements = pruner.get_pruning_measurements(model, subset_val,
                                                 layers2prune)
  (all_scores, mean_values, _) = measurements
  for scoring, is_bp in pruning_methods:
    if scoring not in pruner.ALL_SCORING_FUNCTIONS:
      raise ValueError('%s is not one of %s' %
                       (scoring, pruner.ALL_SCORING_FUNCTIONS))
    scores = all_scores[scoring]
    d_path = os.path.join(
        FLAGS.outdir,
        '%d-%s-%s-%s.pickle' % (val_size, l_name, scoring, str(is_bp)))
    logging.info(d_path)
    if tf.io.gfile.exists(d_path):
      logging.warning('File %s exists, skipping.', d_path)
    else:
      ls_train_loss = []
      ls_train_acc = []
      ls_test_loss = []
      ls_test_acc = []
      n_units = input_shapes[l_name][-1].value
      n_unit_pruned_max = n_units * max_pruning_fraction
      c_slice = np.linspace(0, n_unit_pruned_max, n_exps, dtype=np.int32)
      logging.info('Layer:%s, n_units:%d, c_slice:%s', l_name, n_units,
                   str(c_slice))
      for pruning_count in c_slice:
        # Cast from np.int32 to int.
        pruning_count = int(pruning_count)
        copied_model = model.clone()
        pruning_factor = None
        pruner.prune_model_with_scores(copied_model, scores, is_bp,
                                       layers2prune, pruning_factor,
                                       pruning_count, mean_values,
                                       input_shapes)
        test_loss, test_acc, _ = train_utils.cross_entropy_loss(
            copied_model, subset_test, calculate_accuracy=True)
        train_loss, train_acc, _ = train_utils.cross_entropy_loss(
            copied_model, subset_val2, calculate_accuracy=True)
        logging.info('is_bp: %s, n: %d, test_loss: %f, train_loss: %f',
                     str(is_bp), pruning_count, test_loss, train_loss)
        ls_train_loss.append(train_loss.numpy())
        ls_test_loss.append(test_loss.numpy())
        ls_test_acc.append(test_acc.numpy())
        ls_train_acc.append(train_acc.numpy())
      utils.pickle_object(
          (ls_train_loss, ls_train_acc, ls_test_loss, ls_test_acc), d_path)
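# A hedged example of driving `prune_layer_and_eval` through gin (the
# binding names mirror the signature above; the concrete values and paths
# are illustrative only):
#
#   gin.parse_config("""
#     prune_layer_and_eval.model_dir = '/tmp/model_ckpts'
#     prune_layer_and_eval.l_name = 'conv_1'
#     prune_layer_and_eval.n_exps = 10
#     prune_layer_and_eval.val_size = 1000
#     prune_layer_and_eval.max_pruning_fraction = 0.5
#   """)
#   prune_layer_and_eval(pruning_methods=(('norm', True), ('rand', False)))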
def prune_and_finetune_model(pruning_schedule=gin.REQUIRED,
                             model_dir=gin.REQUIRED,
                             dataset_name='cifar10',
                             checkpoint_interval=5,
                             log_interval=5,
                             n_finetune=100,
                             epochs=20,
                             lr=1e-4,
                             momentum=0.9):
  """Loads and prunes the model layer by layer according to the schedule.

  The model is fine-tuned between pruning tasks (as given in the schedule).

  Args:
    pruning_schedule: list<str, float>, where the str is a valid layer name
      and the float is the pruning fraction of that layer. Layers are pruned
      in the order they are given and `n_finetune` steps are taken in
      between.
    model_dir: str, path to the checkpoint directory.
    dataset_name: str, either 'cifar10' or 'imagenet'.
    checkpoint_interval: int, number of epochs between two checkpoints.
    log_interval: int, number of steps between two logging events.
    n_finetune: int, number of steps between two pruning steps.
    epochs: int, total number of epochs to run.
    lr: float, learning rate for the fine-tuning steps.
    momentum: float, momentum multiplier for the fine-tuning steps.

  Raises:
    ValueError: when n_finetune is not positive.
    OSError: when there is no checkpoint found.
  """
  if n_finetune <= 0:
    raise ValueError('n_finetune must be positive: given %d' % n_finetune)
  optimizer = tf.keras.optimizers.SGD(learning_rate=lr, momentum=momentum)
  logging.info('Using outdir: %s', model_dir)
  latest_cpkt = tf.train.latest_checkpoint(model_dir)
  if not latest_cpkt:
    raise OSError('No checkpoint found in %s' % model_dir)
  logging.info('Using latest checkpoint at %s', latest_cpkt)
  model = model_load.get_model(load_path=latest_cpkt,
                               dataset_name=dataset_name)
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))
  datasets = data.get_datasets(dataset_name=dataset_name)
  dataset_train, dataset_test, subset_val, subset_test, subset_val2 = datasets
  unit_pruner = pruner.UnitPruner(model, subset_val)
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  current_epoch = tf.Variable(1)
  current_layer_index = tf.Variable(0)
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model,
                                   current_epoch=current_epoch,
                                   current_layer_index=current_layer_index)
  latest_cpkt = tf.train.latest_checkpoint(FLAGS.outdir)
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with epoch: %d', current_epoch.numpy())
  c_epoch = current_epoch.numpy()
  c_layer_index = current_layer_index.numpy()
  # Starting from the first batch, we perform pruning every `n_finetune`
  # steps. Layers are pruned one by one according to the given schedule.
  while c_epoch <= epochs:
    logging.info('Starting Epoch: %d', c_epoch)
    for (x, y) in dataset_train:
      # Every `n_finetune` steps perform pruning.
      if (step_counter.numpy() % n_finetune == 0 and
          c_layer_index < len(pruning_schedule)):
        logging.info('Pruning at iteration: %d', step_counter.numpy())
        l_name, pruning_factor = pruning_schedule[c_layer_index]
        unit_pruner.prune_layer(l_name, pruning_factor=pruning_factor)
        train_utils.log_loss_acc(model, subset_val2, subset_test)
        train_utils.log_sparsity(model)
        # Re-init the optimizer and therefore remove the accumulated
        # momentum.
        optimizer = tf.keras.optimizers.SGD(learning_rate=lr,
                                            momentum=momentum)
        c_layer_index += 1
        current_layer_index.assign(c_layer_index)
      else:
        if step_counter.numpy() % log_interval == 0:
          logging.info('Iteration: %d', step_counter.numpy())
          train_utils.log_loss_acc(model, subset_val2, subset_test)
          train_utils.log_sparsity(model)
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model.
      optimizer.apply_gradients(list(zip(grads, model.variables)))
      if step_counter.numpy() % log_interval == 0:
        tf.summary.scalar('loss_train', loss_train)
        tf.summary.image('x', x, max_outputs=1)
    # End of an epoch.
    c_epoch += 1
    current_epoch.assign(c_epoch)
    # Save every n epochs OR after the last epoch.
    if (tf.equal((current_epoch - 1) % checkpoint_interval, 0) or
        c_epoch > epochs):
      # Re-init the checkpoint to ensure the masks are captured. The reason
      # for this is that the masks are initially not generated.
      checkpoint = tf.train.Checkpoint(
          optimizer=optimizer, model=model, current_epoch=current_epoch,
          current_layer_index=current_layer_index)
      logging.info('Checkpoint after epoch: %d', c_epoch - 1)
      checkpoint.save(os.path.join(FLAGS.outdir, 'ckpt-%d' % (c_epoch - 1)))
  # Test the model.
  test_loss, test_acc, n_samples = cross_entropy_loss(
      model, dataset_test, calculate_accuracy=True)
  tf.summary.scalar('test_loss_all', test_loss)
  tf.summary.scalar('test_acc_all', test_acc)
  logging.info(
      'Overall_test_loss: %.4f, Overall_test_acc: %.4f, n_samples: %d',
      test_loss, test_acc, n_samples)
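# Example schedule (hedged: the layer names are illustrative): prune 10% of
# `conv_1`, then 20% of `conv_2`, fine-tuning for `n_finetune` steps between
# the two pruning events.
#
#   prune_and_finetune_model(
#       pruning_schedule=[('conv_1', 0.1), ('conv_2', 0.2)],
#       model_dir='/tmp/model_ckpts', n_finetune=100, epochs=20)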
def train_model(dataset_name='cifar10',
                checkpoint_every_n_epoch=5,
                log_interval=1000,
                epochs=10,
                lr=1e-2,
                lr_drop_iter=1500,
                lr_decay=0.5,
                momentum=0.9,
                seed=8):
  """Trains the model with regular logging and checkpoints.

  Args:
    dataset_name: str, either 'cifar10' or 'imagenet'.
    checkpoint_every_n_epoch: int, number of epochs between two checkpoints.
    log_interval: int, number of steps between two logging events.
    epochs: int, total number of epochs to train for.
    lr: float, initial learning rate.
    lr_drop_iter: int, number of iterations between two consecutive lr drops.
    lr_decay: float, multiplier for the step learning rate reduction.
    momentum: float, momentum multiplier.
    seed: int, random seed to be set to produce reproducible experiments.

  Raises:
    AssertionError: when the arguments don't match the specs.
  """
  assert dataset_name in ['cifar10', 'imagenet', 'cub200', 'imagenet_vgg']
  tf.random.set_seed(seed)
  # The model is configured through gin parameters.
  model = model_load.get_model(dataset_name=dataset_name)
  # model.call = tf.function(model.call)
  datasets = data.get_datasets(dataset_name=dataset_name)
  dataset_train, dataset_test, _, subset_test, subset_val = datasets
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      lr, decay_steps=lr_drop_iter, decay_rate=lr_decay, staircase=True)
  optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=momentum)
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))
  current_epoch = tf.Variable(1)
  # Create the checkpoint object. TODO: check whether the ckpt-prefix is
  # needed.
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer, model=model, current_epoch=current_epoch)
  # Keep all checkpoints by setting `max_to_keep=None`.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, directory=FLAGS.outdir, max_to_keep=None)
  latest_cpkt = checkpoint_manager.latest_checkpoint
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with epoch: %d', current_epoch.numpy())
  c_epoch = current_epoch.numpy()
  while c_epoch <= epochs:
    logging.info('Starting Epoch: %d', c_epoch)
    for (x, y) in dataset_train:
      if step_counter.numpy() % log_interval == 0:
        train_utils.log_loss_acc(model, subset_val, subset_test)
        tf.summary.image('x', x, max_outputs=1)
        logging.info('Iteration: %d', step_counter.numpy())
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model.
      optimizer.apply_gradients(zip(grads, model.variables))
      tf.summary.scalar('loss_train', loss_train)
      tf.summary.scalar('lr', lr_schedule(step_counter))
    # End of an epoch.
    c_epoch += 1
    current_epoch.assign(c_epoch)
    # Save every n epochs OR after the last epoch.
    if (tf.equal((current_epoch - 1) % checkpoint_every_n_epoch, 0) or
        c_epoch > epochs):
      logging.info('Checkpoint after epoch: %d', c_epoch - 1)
      checkpoint_manager.save(checkpoint_number=c_epoch - 1)
  test_loss, test_acc, n_samples = cross_entropy_loss(
      model, dataset_test, calculate_accuracy=True)
  tf.summary.scalar('test_loss_all', test_loss)
  tf.summary.scalar('test_acc_all', test_acc)
  logging.info('Overall_test_loss: %.4f, Overall_test_acc: %.4f, '
               'n_samples: %d', test_loss, test_acc, n_samples)
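# A minimal invocation sketch: all arguments have defaults, so a bare call
# trains CIFAR-10 with the step-decayed SGD schedule above (assuming
# `FLAGS.outdir` and the gin model config are set up as elsewhere in this
# file):
#
#   train_model(dataset_name='cifar10', epochs=10, lr=1e-2,
#               lr_drop_iter=1500, lr_decay=0.5)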
def prune_and_finetune_model(dataset_name='imagenet_vgg',
                             flop_regularizer=0,
                             n_units_target=4000,
                             checkpoint_interval=5,
                             log_interval=5,
                             n_finetune=100,
                             lr=1e-4,
                             momentum=0.9,
                             seed=8):
  """Prunes the model one unit at a time, fine-tuning in between.

  Args:
    dataset_name: str, dataset to train on.
    flop_regularizer: float, multiplier for the flop regularization. If 0, no
      regularization is made during pruning.
    n_units_target: int, number of units to prune.
    checkpoint_interval: int, number of pruning steps between two
      checkpoints.
    log_interval: int, number of steps between two logging events.
    n_finetune: int, number of steps between two pruning steps. Starting from
      the first iteration, we prune 1 unit every `n_finetune` gradient
      updates.
    lr: float, learning rate for the fine-tuning steps.
    momentum: float, momentum multiplier for the fine-tuning steps.
    seed: int, random seed to be set to produce reproducible experiments.

  Raises:
    ValueError: when n_finetune is not positive.
  """
  if n_finetune <= 0:
    raise ValueError('n_finetune must be positive: given %d' % n_finetune)
  tf.random.set_seed(seed)
  optimizer = tf.keras.optimizers.SGD(lr, momentum=momentum)
  # imagenet_vgg -> imagenet.
  dataset_basename = dataset_name.split('_')[0]
  model = model_load.get_model(dataset_name=dataset_basename)
  # Uncomment the following if your model can be compiled with `tf.function`.
  # model.call = tf.function(model.call)
  datasets = data.get_datasets(dataset_name=dataset_name)
  (dataset_train, _, subset_val, subset_test, subset_val2) = datasets
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))
  unit_pruner = pruner.UnitPruner(model, subset_val)
  # We prune all conv layers.
  pruning_pool = _all_vgg_conv_layers
  baselines = {
      l_name: c_flop * flop_regularizer
      for l_name, c_flop in zip(pruning_pool, _vgg16_flop_regularizition)
  }
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  c_pruning_step = tf.Variable(1)
  # Create the checkpoint object. TODO: check whether the ckpt-prefix is
  # needed.
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model,
                                   c_pruning_step=c_pruning_step)
  # Keep all checkpoints by setting `max_to_keep=None`.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, directory=FLAGS.outdir, max_to_keep=None)
  latest_cpkt = checkpoint_manager.latest_checkpoint
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with pruning step: %d', c_pruning_step.numpy())
  pruning_step = c_pruning_step.numpy()
  while pruning_step <= n_units_target:
    for (x, y) in dataset_train:
      # Every `n_finetune` steps perform pruning.
      if step_counter.numpy() % n_finetune == 0:
        if pruning_step > n_units_target:
          # This holds true when we have pruned for the last time and
          # fine-tuned for N iterations. We break here and the while loop
          # above terminates, too.
          break
        logging.info('Pruning Step: %d', pruning_step)
        start = time.time()
        unit_pruner.prune_one_unit(pruning_pool, baselines=baselines)
        end = time.time()
        logging.info('\nTrain time for Pruning Step #%d (step %d): %f',
                     pruning_step, step_counter.numpy(), end - start)
        pruning_step += 1
        c_pruning_step.assign(pruning_step)
        if tf.equal((pruning_step - 1) % checkpoint_interval, 0):
          checkpoint_manager.save()
      if step_counter.numpy() % log_interval == 0:
        train_utils.log_loss_acc(model, subset_val2, subset_test)
        train_utils.log_sparsity(model)
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model.
      optimizer.apply_gradients(list(zip(grads, model.variables)))
      if step_counter.numpy() % log_interval == 0:
        tf.summary.scalar('loss_train', loss_train)
        tf.summary.image('x', x, max_outputs=1)
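# Usage sketch (hedged: the regularizer value is illustrative): prune 4000
# units from the VGG conv layers, one unit every 100 gradient updates, with
# flop-based score regularization.
#
#   prune_and_finetune_model(dataset_name='imagenet_vgg',
#                            flop_regularizer=1e-8, n_units_target=4000,
#                            n_finetune=100)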
def prune_layer(self, layer_name, pruning_factor=0.1, pruning_count=None,
                pruning_method=None, is_bp=None):
  """Prunes a single layer using the given scoring function.

  Args:
    layer_name: str, layer name to prune.
    pruning_factor: float, 0 < pruning_factor < 1.
    pruning_count: int, if not None, sets the pruning_factor to None. This is
      because you can either prune a fraction or a number of units.
      pruning_count is used to determine how many units to prune per layer.
    pruning_method: str, from ['norm', 'mrs', 'rs', 'rand', 'abs_mrs',
      'abs_rs']. If given, overwrites the default value.
    is_bp: bool, if True Mean Replacement Pruning is used and bias
      propagation is made. If given, overwrites the default value.

  Raises:
    ValueError: if the arguments provided don't match the specs.
  """
  pruning_method = pruning_method if pruning_method else self.pruning_method
  is_bp = is_bp if is_bp is not None else self.is_bp
  if pruning_method not in ALL_SCORING_FUNCTIONS:
    raise ValueError('%s is not one of %s' %
                     (pruning_method, ALL_SCORING_FUNCTIONS))
  # Need to wrap the layer name in a list for the
  # `pruner.prune_model_with_scores()` call.
  layers2prune = [layer_name]
  # If pruning_count exists, invalidate the pruning_factor.
  if pruning_count is not None:
    if not isinstance(pruning_count, int):
      raise ValueError('pruning_count: %s should be an int' % pruning_count)
    elif pruning_count < 1:
      raise ValueError('pruning_count: %d should be at least 1.' %
                       pruning_count)
    pruning_factor = None
  # Validate pruning_factor.
  elif pruning_factor == 0:
    return
  elif pruning_factor <= 0 or pruning_factor >= 1:
    raise ValueError('pruning_factor: %s should be in (0, 1)' %
                     pruning_factor)
  # `pruning_factor` may be None here, so log it with %s.
  logging.info('Pruning layer `%s` with: %s, f: %s, n: %s', layer_name,
               pruning_method, pruning_factor, pruning_count)
  input_shapes = {layer_name: getattr(self.model, layer_name + '_ts').xshape}
  # Calculating the scoring function / mean value.
  is_abs = pruning_method.startswith('abs')
  is_mrs = pruning_method.endswith('mrs')
  is_rs = pruning_method.endswith('rs') and not is_mrs
  train_utils.cross_entropy_loss(
      self.model, self.subset_val, training=False,
      compute_mean_replacement_saliency=is_mrs,
      compute_removal_saliency=is_rs, is_abs=is_abs, aggregate_values=True,
      run_gradient=True)
  scores = {}
  mean_values = {}
  # `layer_ts` stands for the TaylorScorer twin of the layer.
  layer_ts = getattr(self.model, layer_name + '_ts')
  masked_layer = getattr(self.model, layer_name)
  masked_layer.apply_masks()
  if pruning_method == 'rand':
    scores[layer_name] = unitscorers.random_score(
        masked_layer.get_layer().weights[0])
  elif pruning_method == 'norm':
    scores[layer_name] = unitscorers.norm_score(
        masked_layer.get_layer().weights[0])
  else:
    # mrs or rs.
    score_name = 'rs' if is_rs else 'mrs'
    scores[layer_name] = layer_ts.get_saved_values(score_name)
  mean_values[layer_name] = layer_ts.get_saved_values('mean')
  prune_model_with_scores(self.model, scores, is_bp, layers2prune,
                          pruning_factor, pruning_count, mean_values,
                          input_shapes)
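# Usage sketches (hedged: `unit_pruner` is a `UnitPruner` as constructed in
# the training loops above, and the layer name is illustrative):
#
#   # Prune 10% of the units of `conv_2` with the default method.
#   unit_pruner.prune_layer('conv_2', pruning_factor=0.1)
#   # Or prune exactly 5 units using absolute mean-replacement saliency with
#   # bias propagation; `pruning_count` overrides `pruning_factor`.
#   unit_pruner.prune_layer('conv_2', pruning_count=5,
#                           pruning_method='abs_mrs', is_bp=True)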
def probe_pruning(model, subset_val, subset_val2, subset_test, f_retrain,
                  baselines=(0.0, 0.0), layers2prune='all', n_retrain=0,
                  pruning_factor=0.1, pruning_count=None,
                  pruning_methods=(('norm', True),)):
  """Prunes a copy of the network and calculates the change in the loss.

  By default calculates mrs, rs and mean values.

  Args:
    model: tf.keras.Model.
    subset_val: tf.data.Dataset, used for loss calculation.
    subset_val2: tf.data.Dataset, used for pruning scoring.
    subset_test: tf.data.Dataset, from the test set.
    f_retrain: function, used for retraining with the 2 arguments
      `copied_model` and `n_retrain`.
    baselines: tuple, <val_loss, test_loss> baselines to subtract from the
      losses.
    layers2prune: list or str, each element `name` in the list should be a
      valid MaskedLayer under the model, i.e. `model.name` -> MaskedLayer.
      One can also provide the following tokens:
      `all`: searches the model, finds all MaskedLayers and prunes them all.
      `firstconv`: prunes the first conv layer in the `forward_chain`.
      `midconv`: prunes the `mid=n_conv//2`th conv layer in the
        `forward_chain`.
      `lastconv`: prunes the last conv layer in the `forward_chain`.
      `firstdense`: prunes the first dense layer in the `forward_chain`.
    n_retrain: int, number of retraining updates to perform after pruning.
      If n_retrain <= 0, then nothing happens.
    pruning_factor: float, 0 < pruning_factor < 1.
    pruning_count: int, if not None, sets the pruning_factor to None. This is
      because you can either prune a fraction or a number of units.
      pruning_count is used to determine how many units to prune per layer.
    pruning_methods: iterator of tuples, (scoring, is_bp) where `is_bp` is a
      boolean indicating the usage of Mean Replacement Pruning and scoring is
      a string from ['norm', 'mrs', 'rs', 'rand', 'abs_mrs', 'abs_rs'].
      'norm': unitscorers.norm_score.
      '{abs_}mrs': `compute_mean_replacement_saliency=True` for the
        `TaylorScorer` layer. If the {abs_} prefix exists, the absolute value
        is used before aggregation, i.e. is_abs=True.
      '{abs_}rs': `compute_removal_saliency=True` for the `TaylorScorer`
        layer. If the {abs_} prefix exists, the absolute value is used before
        aggregation, i.e. is_abs=True.
      'rand': unitscorers.random_score.
      is_bp: bool, if True, the mean values of the pruned units are
        propagated to the next layer prior to pruning. `bp` stands for bias
        propagation.

  Raises:
    AssertionError: if the arguments provided don't match the specs.

  Returns:
    selected_units: dict, keys coming from `layers2prune` and each value is a
      tuple of (score, mask), where score is the pruning scores of the units
      and mask is the corresponding binary mask created for the given
      fraction.
    mean_values: dict, keys coming from `layers2prune` and each value is the
      mean activation under the training batch.
    l2_norms: dict, keys coming from `layers2prune` and each value is the
      average squared activation of the units.
  """
  # Check validity of `layers2prune` and process.
  layers2prune = process_layers2prune(layers2prune, model)
  # If pruning_count exists, invalidate the pruning_factor.
  if pruning_count is not None:
    assert isinstance(pruning_count, int) and pruning_count >= 1
    pruning_factor = None
  copied_model = model.clone()
  loss_val, loss_test = baselines
  selected_units = {}
  input_shapes = {
      l_name: getattr(model, l_name + '_ts').xshape
      for l_name in layers2prune
  }
  measurements = get_pruning_measurements(copied_model, subset_val2,
                                          layers2prune)
  (scores, mean_values, l2_norms) = measurements
  for scoring, is_bp in pruning_methods:
    assert scoring in ['norm', 'abs_mrs', 'abs_rs', 'mrs', 'rs', 'rand']
    copied_model = model.clone()
    scalar_summary_tag = 'pruning_penalty%s_%s' % ('_bp' if is_bp else '',
                                                   scoring)
    logging.info('Pruning following layers: %s, Using %s %s bias_prop.',
                 layers2prune, scoring, 'with' if is_bp else 'without')
    selected_units_scoring = prune_model_with_scores(
        copied_model, scores[scoring], is_bp, layers2prune, pruning_factor,
        pruning_count, mean_values, input_shapes)
    # There are going to be two passes for a given `scoring`: with and
    # without `bp`. We record them only once, since they are equal.
    if scoring not in selected_units:
      selected_units[scoring] = selected_units_scoring
    if n_retrain > 0:
      f_retrain(copied_model, n_retrain)
    # Setting training=True since the test mode uses the running averages
    # and revives the pruned units.
    loss_new, _, _ = train_utils.cross_entropy_loss(copied_model, subset_val,
                                                    training=True)
    tf.summary.scalar(scalar_summary_tag, loss_new - loss_val)
    # Setting training=True, otherwise BatchNorm uses the accumulated mean
    # and std during forward propagation, which causes the pruned units to
    # generate non-zero constants.
    loss_new, _, _ = train_utils.cross_entropy_loss(copied_model,
                                                    subset_test,
                                                    training=True)
    tf.summary.scalar(scalar_summary_tag + '_test', loss_new - loss_test)
  return selected_units, mean_values, l2_norms
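# A hedged sketch of probing with two scoring functions; `f_retrain` here is
# a placeholder fine-tuning callback, and the summary writer is assumed to
# be set up as in the training loops above.
#
#   def f_retrain(copied_model, n_retrain):
#     # E.g. run `n_retrain` SGD steps on `copied_model`.
#     pass
#
#   selected, means, l2s = probe_pruning(
#       model, subset_val, subset_val2, subset_test, f_retrain,
#       layers2prune='firstconv', pruning_factor=0.1,
#       pruning_methods=(('norm', True), ('abs_mrs', True)))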
def prune_one_unit(self, pruning_pool, baselines=None,
                   normalized_scores=True, pruning_method=None, is_bp=None):
  """Picks a layer and prunes a single unit using the scoring function.

  Args:
    pruning_pool: list, of layers that are considered for pruning.
    baselines: dict, if it exists, the given constants are subtracted from
      the scores of the individual layers. The keys should be a subset of
      `pruning_pool`.
    normalized_scores: bool, if True the scores are normalized with their l2
      norm.
    pruning_method: str, from ['norm', 'mrs', 'rs', 'rand', 'abs_mrs',
      'abs_rs']. If given, overwrites the default value.
    is_bp: bool, if True Mean Replacement Pruning is used and bias
      propagation is made. If given, overwrites the default value.

  Raises:
    ValueError: if the arguments provided don't match the specs.
  """
  pruning_method = pruning_method if pruning_method else self.pruning_method
  is_bp = is_bp if is_bp is not None else self.is_bp
  if pruning_method not in ALL_SCORING_FUNCTIONS:
    raise ValueError('%s is not one of %s' %
                     (pruning_method, ALL_SCORING_FUNCTIONS))
  if baselines is None:
    baselines = {}
  logging.info('Pruning with: %s, is_bp: %s', pruning_method, is_bp)
  # Calculating the scoring function / mean value.
  is_abs = pruning_method.startswith('abs')
  is_mrs = pruning_method.endswith('mrs')
  is_rs = pruning_method.endswith('rs') and not is_mrs
  is_grad = is_mrs or is_rs
  train_utils.cross_entropy_loss(
      self.model, self.subset_val, training=False,
      compute_mean_replacement_saliency=is_mrs,
      compute_removal_saliency=is_rs, is_abs=is_abs, aggregate_values=True,
      run_gradient=is_grad)
  scores = {}
  mean_values = {}
  smallest_score = None
  smallest_l_name = None
  smallest_nprune = None
  for l_name in pruning_pool:
    l_ts = getattr(self.model, l_name + '_ts')
    l = getattr(self.model, l_name)
    mean_values[l_name] = l_ts.get_saved_values('mean')
    # Make sure the masks are applied after the last gradient update. Note
    # that this is necessary for the `norm` scores, since they don't call
    # the model and therefore the masks are not applied.
    l.apply_masks()
    if pruning_method == 'rand':
      scores[l_name] = unitscorers.random_score(l.get_layer().weights[0])
    elif pruning_method == 'norm':
      scores[l_name] = unitscorers.norm_score(l.get_layer().weights[0])
    else:
      # mrs or rs.
      score_name = 'rs' if is_rs else 'mrs'
      scores[l_name] = l_ts.get_saved_values(score_name)
    if normalized_scores:
      scores[l_name] /= tf.norm(scores[l_name])
    baseline_score = baselines.get(l_name, 0)
    if baseline_score != 0:
      # Regularizing the scores with the c_flop weights.
      scores[l_name] -= baseline_score
    # If there is an existing mask, we have to make sure the pruned
    # connections are excluded, so we set their scores to a very small
    # negative number (-1e10). Note that the elements of `l.mask_bias`
    # consist of zeros and ones only.
    if l.mask_bias is not None:
      # Setting the scores of the pruned units to zero.
      scores[l_name] = scores[l_name] * l.mask_bias
      # Setting the scores of the pruned units to -1e10.
      scores[l_name] += -1e10 * (1 - l.mask_bias)
      # Number of previously pruned units.
      n_pruned = tf.math.count_nonzero(l.mask_bias - 1).numpy()
      layer_smallest_score = tf.reduce_min(
          tf.boolean_mask(scores[l_name], l.mask_bias)).numpy()
      # Do not prune the last remaining unit of a layer.
      if tf.equal(n_pruned + 1, tf.size(l.mask_bias)):
        continue
    else:
      n_pruned = 0
      layer_smallest_score = tf.reduce_min(scores[l_name]).numpy()
    logging.info('Layer: %s, min: %f', l_name, layer_smallest_score)
    if smallest_score is None or (layer_smallest_score < smallest_score):
      smallest_score = layer_smallest_score
      smallest_l_name = l_name
      # We want to prune one more unit than before.
      smallest_nprune = n_pruned + 1
  logging.info('UNIT_PRUNED, layer: %s, n_pruned: %d', smallest_l_name,
               smallest_nprune)
  mean_values = {smallest_l_name: mean_values[smallest_l_name]}
  scores = {smallest_l_name: scores[smallest_l_name]}
  input_shapes = {
      smallest_l_name: getattr(self.model, smallest_l_name + '_ts').xshape
  }
  layers2prune = [smallest_l_name]
  prune_model_with_scores(self.model, scores, is_bp, layers2prune, None,
                          smallest_nprune, mean_values, input_shapes)
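# Usage sketch (hedged: the pool and the baseline value are illustrative).
# Each call removes exactly one unit from the lowest-scoring layer in the
# pool:
#
#   pool = ['conv_1', 'conv_2', 'conv_3']
#   for _ in range(10):
#     unit_pruner.prune_one_unit(pool, baselines={'conv_1': 1e-6},
#                                pruning_method='mrs', is_bp=True)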