def testSaveAndLoad(self):
  model_arch_name = 'small_conv'
  dataset_name = 'cifar10'
  # Let's create and save model.
  model = model_load.get_model(model_arch_name=model_arch_name,
                               dataset_name=dataset_name)
  load_path = self._create_and_save_model(model)
  model_loaded = model_load.get_model(model_arch_name=model_arch_name,
                                      dataset_name=dataset_name,
                                      load_path=load_path)
  # If loaded correctly all should be equal.
  self.assertAllClose(model.conv_1.weights[0].numpy(),
                      model_loaded.conv_1.weights[0].numpy())
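# The `_create_and_save_model` helper used above is defined elsewhere in the
# test class. A minimal, hypothetical sketch of what it has to do (assuming
# `model_load.get_model(load_path=...)` restores from a `tf.train.Checkpoint`
# prefix, and that `os` and `tf` are imported at module level) could be:
def _create_and_save_model(self, model):
  # Hypothetical sketch, not the repo implementation.
  checkpoint = tf.train.Checkpoint(model=model)
  # `save` returns a checkpoint prefix such as '<temp_dir>/ckpt-1'.
  return checkpoint.save(os.path.join(self.get_temp_dir(), 'ckpt'))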
def testInit(self, model_arch_name, dataset_name, n_classes):
  model = model_load.get_model(model_arch_name=model_arch_name,
                               dataset_name=dataset_name)
  # Check that the output layer has the correct number of units.
  self.assertEqual(model.output_1.units, n_classes)
  # Let's check the number of channels in the first conv layer.
  arch_definition = getattr(model_defs, model_arch_name)
  n_units_in_first_layer = arch_definition[0][1]
  self.assertEqual(model.conv_1.filters, n_units_in_first_layer)
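# For reference, the test above relies only on `arch_definition[0][1]` being
# the filter count of the first conv layer. A hypothetical `small_conv`
# definition consistent with that indexing (the real layout lives in
# `model_defs` and may differ) would look like:
#
#   small_conv = (('C', 64, 5), ('C', 128, 5), ('C', 256, 3),
#                 ('F', 256), ('F', 10))
#
# where the second element of the first tuple (64) is `n_units_in_first_layer`.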
def testPreparePruning(self, is_prepared):
  # Arrange
  model_arch_name = 'small_conv'
  dataset_name = 'cifar10'
  # Let's create and save model.
  model = model_load.get_model(model_arch_name=model_arch_name,
                               dataset_name=dataset_name)
  load_path = self._create_and_save_model(model)
  # Act
  model_loaded = model_load.get_model(model_arch_name=model_arch_name,
                                      dataset_name=dataset_name,
                                      load_path=load_path,
                                      prepare_for_pruning=is_prepared)
  # The architecture defined in `model_defs.small_conv` generates 3 conv
  # layers and 2 dense layers with the following attribute names.
  all_layers = ['conv_1', 'conv_2', 'conv_3', 'dense_1', 'dense_2']
  # Assert whether the new layers are injected or not.
  for l_name in all_layers:
    self.assertEqual(('%s_ts' % l_name) in model_loaded.forward_chain,
                     is_prepared)
    self.assertEqual(
        isinstance(getattr(model_loaded, l_name), layers.MaskedLayer),
        is_prepared)
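# Illustration of what the assertions above check (the exact ordering inside
# `forward_chain` is an assumption): preparing for pruning wraps each prunable
# layer in a `layers.MaskedLayer` and injects a companion '<l_name>_ts'
# TaylorScorer entry into the forward chain, e.g.
#
#   prepare_for_pruning=False: ['conv_1', 'conv_2', ..., 'dense_2', ...]
#   prepare_for_pruning=True:  ['conv_1', 'conv_1_ts', 'conv_2', 'conv_2_ts',
#                               ..., 'dense_2', 'dense_2_ts', ...]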
def prune_layer_and_eval(dataset_name='cifar10',
                         pruning_methods=(('rand', True),),
                         model_dir=gin.REQUIRED,
                         l_name=gin.REQUIRED,
                         n_exps=gin.REQUIRED,
                         val_size=gin.REQUIRED,
                         max_pruning_fraction=gin.REQUIRED):
  """Loads and prunes a model with various sparsity targets.

  This function assumes that `model_dir` exists and holds a checkpoint.

  Args:
    dataset_name: str, either 'cifar10' or 'imagenet'.
    pruning_methods: iterator of (scoring, is_bp) tuples, where `is_bp` is a
      boolean indicating the usage of Mean Replacement Pruning and `scoring`
      is a string from ['norm', 'mrs', 'rs', 'rand', 'abs_mrs', 'abs_rs'].
        'norm': unitscorers.norm_score.
        '{abs_}mrs': `compute_mean_replacement_saliency=True` for the
          `TaylorScorer` layer. If the `abs_` prefix exists, absolute values
          are used before aggregation, i.e. `is_abs=True`.
        '{abs_}rs': `compute_removal_saliency=True` for the `TaylorScorer`
          layer. If the `abs_` prefix exists, absolute values are used before
          aggregation, i.e. `is_abs=True`.
        'rand': unitscorers.random_score.
      is_bp: bool, if True, the mean values of the units are propagated to
        the next layer prior to pruning. `bp` stands for bias propagation.
    model_dir: str, path to the checkpoint directory.
    l_name: str, a valid layer name from the loaded model.
    n_exps: int, number of pruning experiments to be made. This number is
      used to generate pruning counts for the different experiments.
    val_size: int, size of the first validation subset, passed to
      `data.get_datasets`.
    max_pruning_fraction: float, maximum sparsity for pruning. Multiplying
      this number with the total number of units gives the upper limit for
      `pruning_count`.

  Raises:
    ValueError: when the scoring function key is not valid.
    OSError: when there is no checkpoint found.
  """
  logging.info('Looking for a checkpoint at: %s', model_dir)
  latest_cpkt = tf.train.latest_checkpoint(model_dir)
  if not latest_cpkt:
    raise OSError('No checkpoint found in %s' % model_dir)
  logging.info('Using latest checkpoint at %s', latest_cpkt)
  model = model_load.get_model(load_path=latest_cpkt,
                               dataset_name=dataset_name)
  datasets = data.get_datasets(dataset_name=dataset_name, val_size=val_size)
  _, _, subset_val, subset_test, subset_val2 = datasets
  input_shapes = {l_name: getattr(model, l_name + '_ts').xshape}
  layers2prune = [l_name]
  measurements = pruner.get_pruning_measurements(model, subset_val,
                                                 layers2prune)
  (all_scores, mean_values, _) = measurements
  for scoring, is_bp in pruning_methods:
    if scoring not in pruner.ALL_SCORING_FUNCTIONS:
      raise ValueError('%s is not one of %s' %
                       (scoring, pruner.ALL_SCORING_FUNCTIONS))
    scores = all_scores[scoring]
    d_path = os.path.join(
        FLAGS.outdir,
        '%d-%s-%s-%s.pickle' % (val_size, l_name, scoring, str(is_bp)))
    logging.info(d_path)
    if tf.io.gfile.exists(d_path):
      logging.warning('File %s exists, skipping.', d_path)
    else:
      ls_train_loss = []
      ls_train_acc = []
      ls_test_loss = []
      ls_test_acc = []
      n_units = int(input_shapes[l_name][-1])
      n_unit_pruned_max = n_units * max_pruning_fraction
      c_slice = np.linspace(0, n_unit_pruned_max, n_exps, dtype=np.int32)
      logging.info('Layer:%s, n_units:%d, c_slice:%s', l_name, n_units,
                   str(c_slice))
      for pruning_count in c_slice:
        # Cast from np.int32 to int.
        pruning_count = int(pruning_count)
        copied_model = model.clone()
        pruning_factor = None
        pruner.prune_model_with_scores(copied_model, scores, is_bp,
                                       layers2prune, pruning_factor,
                                       pruning_count, mean_values,
                                       input_shapes)
        test_loss, test_acc, _ = train_utils.cross_entropy_loss(
            copied_model, subset_test, calculate_accuracy=True)
        train_loss, train_acc, _ = train_utils.cross_entropy_loss(
            copied_model, subset_val2, calculate_accuracy=True)
        logging.info('is_bp: %s, n: %d, test_loss: %f, train_loss: %f',
                     str(is_bp), pruning_count, test_loss, train_loss)
        ls_train_loss.append(train_loss.numpy())
        ls_test_loss.append(test_loss.numpy())
        ls_test_acc.append(test_acc.numpy())
        ls_train_acc.append(train_acc.numpy())
      utils.pickle_object(
          (ls_train_loss, ls_train_acc, ls_test_loss, ls_test_acc), d_path)
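# A minimal usage sketch for `prune_layer_and_eval`, assuming it is registered
# with `@gin.configurable` (implied by the `gin.REQUIRED` defaults) and that
# the script's flags (e.g. --outdir) have been parsed. All paths and values
# below are placeholders, not repo configuration.
def _example_prune_layer_and_eval():
  """Hypothetical usage sketch; not part of the original script."""
  gin.parse_config([
      "prune_layer_and_eval.model_dir = '/tmp/small_conv_ckpts'",
      "prune_layer_and_eval.l_name = 'conv_2'",
      'prune_layer_and_eval.n_exps = 10',
      'prune_layer_and_eval.val_size = 1000',
      'prune_layer_and_eval.max_pruning_fraction = 0.5',
  ])
  prune_layer_and_eval(pruning_methods=(('abs_mrs', True), ('rand', False)))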
def prune_and_finetune_model(pruning_schedule=gin.REQUIRED,
                             model_dir=gin.REQUIRED,
                             dataset_name='cifar10',
                             checkpoint_interval=5,
                             log_interval=5,
                             n_finetune=100,
                             epochs=20,
                             lr=1e-4,
                             momentum=0.9):
  """Loads and prunes the model layer by layer according to the schedule.

  The model is fine-tuned between pruning tasks (as given in the schedule).

  Args:
    pruning_schedule: list of (str, float) tuples, where the str is a valid
      layer name and the float is the pruning fraction of that layer. Layers
      are pruned in the order they are given, with `n_finetune` steps taken
      in between.
    model_dir: str, path to the checkpoint directory.
    dataset_name: str, either 'cifar10' or 'imagenet'.
    checkpoint_interval: int, number of epochs between two checkpoints.
    log_interval: int, number of steps between two logging events.
    n_finetune: int, number of steps between two pruning steps.
    epochs: int, total number of epochs to run.
    lr: float, learning rate for the fine-tuning steps.
    momentum: float, momentum multiplier for the fine-tuning steps.

  Raises:
    ValueError: when n_finetune is not positive.
    OSError: when there is no checkpoint found.
  """
  if n_finetune <= 0:
    raise ValueError('n_finetune must be positive: given %d' % n_finetune)
  optimizer = tf.keras.optimizers.SGD(learning_rate=lr, momentum=momentum)
  logging.info('Using outdir: %s', model_dir)
  latest_cpkt = tf.train.latest_checkpoint(model_dir)
  if not latest_cpkt:
    raise OSError('No checkpoint found in %s' % model_dir)
  logging.info('Using latest checkpoint at %s', latest_cpkt)
  model = model_load.get_model(load_path=latest_cpkt,
                               dataset_name=dataset_name)
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))
  datasets = data.get_datasets(dataset_name=dataset_name)
  dataset_train, dataset_test, subset_val, subset_test, subset_val2 = datasets
  unit_pruner = pruner.UnitPruner(model, subset_val)
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  current_epoch = tf.Variable(1)
  current_layer_index = tf.Variable(0)
  checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                   model=model,
                                   current_epoch=current_epoch,
                                   current_layer_index=current_layer_index)
  latest_cpkt = tf.train.latest_checkpoint(FLAGS.outdir)
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with epoch: %d', current_epoch.numpy())
  c_epoch = current_epoch.numpy()
  c_layer_index = current_layer_index.numpy()
  # Starting from the first batch, we perform pruning every `n_finetune` steps.
  # Layers are pruned one by one according to the given pruning schedule.
  while c_epoch <= epochs:
    logging.info('Starting Epoch: %d', c_epoch)
    for (x, y) in dataset_train:
      # Every `n_finetune` steps perform pruning.
      if (step_counter.numpy() % n_finetune == 0 and
          c_layer_index < len(pruning_schedule)):
        logging.info('Pruning at iteration: %d', step_counter.numpy())
        l_name, pruning_factor = pruning_schedule[c_layer_index]
        unit_pruner.prune_layer(l_name, pruning_factor=pruning_factor)
        train_utils.log_loss_acc(model, subset_val2, subset_test)
        train_utils.log_sparsity(model)
        # Re-init optimizer and therefore remove previous momentum.
        optimizer = tf.keras.optimizers.SGD(learning_rate=lr,
                                            momentum=momentum)
        # Carry the step count over to the fresh optimizer so the
        # `n_finetune` / `log_interval` schedules stay consistent.
        optimizer.iterations.assign(step_counter.numpy())
        step_counter = optimizer.iterations
        tf.summary.experimental.set_step(step_counter)
        c_layer_index += 1
        current_layer_index.assign(c_layer_index)
      else:
        if step_counter.numpy() % log_interval == 0:
          logging.info('Iteration: %d', step_counter.numpy())
          train_utils.log_loss_acc(model, subset_val2, subset_test)
          train_utils.log_sparsity(model)
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model.
      optimizer.apply_gradients(zip(grads, model.variables))
      if step_counter.numpy() % log_interval == 0:
        tf.summary.scalar('loss_train', loss_train)
        tf.summary.image('x', x, max_outputs=1)
    # End of an epoch.
    c_epoch += 1
    current_epoch.assign(c_epoch)
    # Save every n OR after last epoch.
    if (tf.equal((current_epoch - 1) % checkpoint_interval, 0)
        or c_epoch > epochs):
      # Re-init checkpoint to ensure the masks are captured. The reason for
      # this is that the masks are initially not generated.
      checkpoint = tf.train.Checkpoint(
          optimizer=optimizer,
          model=model,
          current_epoch=current_epoch,
          current_layer_index=current_layer_index)
      logging.info('Checkpoint after epoch: %d', c_epoch - 1)
      checkpoint.save(os.path.join(FLAGS.outdir, 'ckpt-%d' % (c_epoch - 1)))
      # Test model.
      test_loss, test_acc, n_samples = cross_entropy_loss(
          model, dataset_test, calculate_accuracy=True)
      tf.summary.scalar('test_loss_all', test_loss)
      tf.summary.scalar('test_acc_all', test_acc)
      logging.info(
          'Overall_test_loss: %.4f, Overall_test_acc: %.4f, n_samples: %d',
          test_loss, test_acc, n_samples)
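# A minimal usage sketch for the schedule-driven `prune_and_finetune_model`
# above, assuming it is registered with `@gin.configurable` (implied by the
# `gin.REQUIRED` defaults) and that the script's flags (e.g. --outdir) have
# been parsed. Layer names and fractions are placeholders.
def _example_prune_and_finetune_model():
  """Hypothetical usage sketch; not part of the original script."""
  gin.parse_config([
      "prune_and_finetune_model.model_dir = '/tmp/small_conv_ckpts'",
      'prune_and_finetune_model.pruning_schedule = '
      "[('conv_1', 0.5), ('conv_2', 0.5), ('dense_1', 0.25)]",
      'prune_and_finetune_model.n_finetune = 100',
      'prune_and_finetune_model.epochs = 20',
  ])
  prune_and_finetune_model()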
def train_model(dataset_name='cifar10',
                checkpoint_every_n_epoch=5,
                log_interval=1000,
                epochs=10,
                lr=1e-2,
                lr_drop_iter=1500,
                lr_decay=0.5,
                momentum=0.9,
                seed=8):
  """Trains the model with regular logging and checkpoints.

  Args:
    dataset_name: str, either 'cifar10' or 'imagenet'.
    checkpoint_every_n_epoch: int, number of epochs between two checkpoints.
    log_interval: int, number of steps between two logging events.
    epochs: int, total number of epochs to train for.
    lr: float, initial learning rate.
    lr_drop_iter: int, number of iterations between two consecutive learning
      rate drops.
    lr_decay: float, multiplier for the step learning rate reduction.
    momentum: float, momentum multiplier.
    seed: int, random seed, set to produce reproducible experiments.

  Raises:
    AssertionError: when the arguments don't match the specs.
  """
  assert dataset_name in ['cifar10', 'imagenet', 'cub200', 'imagenet_vgg']
  tf.random.set_seed(seed)
  # The model is configured through gin parameters.
  model = model_load.get_model(dataset_name=dataset_name)
  # model.call = tf.function(model.call)
  datasets = data.get_datasets(dataset_name=dataset_name)
  dataset_train, dataset_test, _, subset_test, subset_val = datasets
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      lr, decay_steps=lr_drop_iter, decay_rate=lr_decay, staircase=True)
  optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=momentum)
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))
  current_epoch = tf.Variable(1)
  # Create checkpoint object. TODO: check whether the ckpt-prefix is needed.
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer, model=model, current_epoch=current_epoch)
  # Keep all checkpoints by setting `max_to_keep=None`.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, directory=FLAGS.outdir, max_to_keep=None)
  latest_cpkt = checkpoint_manager.latest_checkpoint
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with epoch: %d', current_epoch.numpy())
  c_epoch = current_epoch.numpy()
  while c_epoch <= epochs:
    logging.info('Starting Epoch:%d', c_epoch)
    for (x, y) in dataset_train:
      if step_counter % log_interval == 0:
        train_utils.log_loss_acc(model, subset_val, subset_test)
        tf.summary.image('x', x, max_outputs=1)
        logging.info('Iteration:%d', step_counter.numpy())
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model.
      optimizer.apply_gradients(zip(grads, model.variables))
      tf.summary.scalar('loss_train', loss_train)
      tf.summary.scalar('lr', optimizer.lr(step_counter))
    # End of an epoch.
    c_epoch += 1
    current_epoch.assign(c_epoch)
    # Save every n OR after last epoch.
    if (tf.equal((current_epoch - 1) % checkpoint_every_n_epoch, 0)
        or c_epoch > epochs):
      logging.info('Checkpoint after epoch: %d', c_epoch - 1)
      checkpoint_manager.save(checkpoint_number=c_epoch - 1)
      test_loss, test_acc, n_samples = cross_entropy_loss(
          model, dataset_test, calculate_accuracy=True)
      tf.summary.scalar('test_loss_all', test_loss)
      tf.summary.scalar('test_acc_all', test_acc)
      logging.info('Overall_test_loss:%.4f, Overall_test_acc:%.4f, n_samples:%d',
                   test_loss, test_acc, n_samples)
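# The staircase schedule above multiplies the learning rate by `lr_decay` once
# every `lr_drop_iter` steps. A self-contained sketch using the default values
# of `train_model` (lr=1e-2, lr_drop_iter=1500, lr_decay=0.5) illustrates the
# resulting step function; the printed values follow from the formula
# lr * lr_decay**floor(step / lr_drop_iter), not from measured training output.
def _example_lr_staircase():
  """Hypothetical illustration; not part of the original script."""
  schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      1e-2, decay_steps=1500, decay_rate=0.5, staircase=True)
  for step in (0, 1499, 1500, 2999, 3000):
    # Prints 0.01, 0.01, 0.005, 0.005, 0.0025.
    print(step, float(schedule(step)))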
def testValueError(self, dataset_name, model_arch_name):
  with self.assertRaises(ValueError):
    _ = model_load.get_model(model_arch_name=model_arch_name,
                             dataset_name=dataset_name)
def prune_and_finetune_model(dataset_name='imagenet_vgg',
                             flop_regularizer=0,
                             n_units_target=4000,
                             checkpoint_interval=5,
                             log_interval=5,
                             n_finetune=100,
                             lr=1e-4,
                             momentum=0.9,
                             seed=8):
  """Prunes the model one unit at a time, fine-tuning in between, with regular logging and checkpoints.

  Args:
    dataset_name: str, dataset to train on.
    flop_regularizer: float, multiplier for the flop regularization. If 0, no
      regularization is applied during pruning.
    n_units_target: int, number of units to prune.
    checkpoint_interval: int, number of pruning steps between two checkpoints.
    log_interval: int, number of steps between two logging events.
    n_finetune: int, number of steps between two pruning steps. Starting from
      the first iteration, we prune 1 unit every `n_finetune` gradient
      updates.
    lr: float, learning rate for the fine-tuning steps.
    momentum: float, momentum multiplier for the fine-tuning steps.
    seed: int, random seed, set to produce reproducible experiments.

  Raises:
    ValueError: when n_finetune is not positive.
  """
  if n_finetune <= 0:
    raise ValueError('n_finetune must be positive: given %d' % n_finetune)
  tf.random.set_seed(seed)
  optimizer = tf.keras.optimizers.SGD(lr, momentum=momentum)
  # imagenet_vgg -> imagenet.
  dataset_basename = dataset_name.split('_')[0]
  model = model_load.get_model(dataset_name=dataset_basename)
  # Uncomment the following if your model is defunable.
  # model.call = tf.function(model.call)
  datasets = data.get_datasets(dataset_name=dataset_name)
  (dataset_train, _, subset_val, subset_test, subset_val2) = datasets
  logging.info('Model init-config: %s', model.init_config)
  logging.info('Model forward chain: %s', str(model.forward_chain))
  unit_pruner = pruner.UnitPruner(model, subset_val)
  # We prune all conv layers.
  pruning_pool = _all_vgg_conv_layers
  baselines = {
      l_name: c_flop * flop_regularizer
      for l_name, c_flop in zip(pruning_pool, _vgg16_flop_regularizition)
  }
  step_counter = optimizer.iterations
  tf.summary.experimental.set_step(step_counter)
  c_pruning_step = tf.Variable(1)
  # Create checkpoint object. TODO: check whether the ckpt-prefix is needed.
  checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                   model=model,
                                   c_pruning_step=c_pruning_step)
  # Keep all checkpoints by setting `max_to_keep=None`.
  checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                  directory=FLAGS.outdir,
                                                  max_to_keep=None)
  latest_cpkt = checkpoint_manager.latest_checkpoint
  if latest_cpkt:
    logging.info('Using latest checkpoint at %s', latest_cpkt)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)
    logging.info('Resuming with pruning step: %d', c_pruning_step.numpy())
  pruning_step = c_pruning_step.numpy()
  while pruning_step <= n_units_target:
    for (x, y) in dataset_train:
      # Every `n_finetune` steps perform pruning.
      if step_counter.numpy() % n_finetune == 0:
        if pruning_step > n_units_target:
          # This holds true when we prune for the last time and fine-tune for
          # N more iterations. We break here and the while loop above breaks,
          # too.
          break
        logging.info('Pruning Step:%d', pruning_step)
        start = time.time()
        unit_pruner.prune_one_unit(pruning_pool, baselines=baselines)
        end = time.time()
        logging.info('\nTrain time for Pruning Step #%d (step %d): %f',
                     pruning_step, step_counter.numpy(), end - start)
        pruning_step += 1
        c_pruning_step.assign(pruning_step)
        if tf.equal((pruning_step - 1) % checkpoint_interval, 0):
          checkpoint_manager.save()
      if step_counter.numpy() % log_interval == 0:
        train_utils.log_loss_acc(model, subset_val2, subset_test)
        train_utils.log_sparsity(model)
      with tf.GradientTape() as tape:
        loss_train, _, _ = cross_entropy_loss(model, (x, y), training=True)
      grads = tape.gradient(loss_train, model.variables)
      # Updating the model.
      optimizer.apply_gradients(zip(grads, model.variables))
      if step_counter.numpy() % log_interval == 0:
        tf.summary.scalar('loss_train', loss_train)
        tf.summary.image('x', x, max_outputs=1)
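# Back-of-the-envelope note derived from the defaults above (not measured
# output): one unit is pruned every `n_finetune` gradient updates, so reaching
# `n_units_target` pruned units takes roughly n_units_target * n_finetune
# updates, e.g. 4000 * 100 = 400,000 gradient steps, plus the trailing
# fine-tuning iterations before the final break.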