def testWeightSpecificSparsity(self):
    """Checks that `weight_sparsity_map` overrides `target_sparsity` per weight.

    Two identical 100-element weight vectors are masked under the scopes
    `layer1` and `layer2`. Only `layer2/weights` has an entry in
    `weight_sparsity_map` (0.75); the global `target_sparsity` is 0.5. After
    stepping the conditional mask update past `end_pruning_step`, the
    reported sparsities are asserted to be `[0.5, 0.75]`.
    """
    param_list = [
        "begin_pruning_step=1",
        "pruning_frequency=1",
        "end_pruning_step=100",
        "target_sparsity=0.5",
        "weight_sparsity_map=[layer2/weights:0.75]",
        "threshold_decay=0.0",
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    with variable_scope.variable_scope("layer1"):
        w1 = variables.Variable(
            math_ops.linspace(1.0, 100.0, 100), name="weights")
        _ = pruning.apply_mask(w1)
    with variable_scope.variable_scope("layer2"):
        w2 = variables.Variable(
            math_ops.linspace(1.0, 100.0, 100), name="weights")
        _ = pruning.apply_mask(w2)

    p = pruning.Pruning(pruning_hparams)
    mask_update_op = p.conditional_mask_update_op()
    increment_global_step = state_ops.assign_add(self.global_step, 1)

    # `test_session()` is deprecated; `cached_session()` is the replacement
    # already used by the sibling tests.
    with self.cached_session() as session:
        variables.global_variables_initializer().run()
        # 110 iterations carries the global step beyond end_pruning_step=100.
        for _ in range(110):
            session.run(mask_update_op)
            session.run(increment_global_step)

        self.assertAllEqual(
            session.run(pruning.get_weight_sparsity()), [0.5, 0.75])
def testPerLayerBlockSparsity(self):
    """Verifies per-layer block dimensions given through `block_dims_map`.

    `layer1/weights` is configured with 1x1 blocks and `layer2/weights` with
    1x2 blocks, using AVG block pooling. With a sparsity of 0.5, a single
    mask update is asserted to yield 0.5 sparsity for both layers and the
    expected mask patterns.
    """
    spec = ",".join((
        "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
        "block_pooling_function=AVG",
        "threshold_decay=0.0",
    ))
    hparams = pruning.get_pruning_hparams().parse(spec)

    # (scope name, constant weight values), created in this order.
    layer_values = (
        ("layer1", [[-0.1, 0.1], [-0.2, 0.2]]),
        ("layer2", [[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]]),
    )
    for scope_name, values in layer_values:
        with variable_scope.variable_scope(scope_name):
            weights = constant_op.constant(values, name="weights")
            pruning.apply_mask(weights)

    sparsity = variables.VariableV1(0.5, name="sparsity")
    pruner = pruning.Pruning(hparams, sparsity=sparsity)
    update_masks = pruner.mask_update_op()

    with self.cached_session() as session:
        variables.global_variables_initializer().run()
        session.run(update_masks)

        first_mask = session.run(pruning.get_masks()[0])
        second_mask = session.run(pruning.get_masks()[1])
        layer_sparsities = session.run(pruning.get_weight_sparsity())

        self.assertAllEqual(layer_sparsities, [0.5, 0.5])
        self.assertAllEqual(first_mask, [[0.0, 0.0], [1., 1.]])
        self.assertAllEqual(second_mask, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
def testWeightSpecificSparsity(self):
    """Verifies that a `weight_sparsity_map` entry overrides `target_sparsity`.

    Both layers hold the same 100-element linspace weights. Only
    `layer2/weights` is listed in `weight_sparsity_map` (0.75); the global
    target is 0.5. The final sparsities are asserted to be `[0.5, 0.75]`.
    """
    spec = ",".join((
        "begin_pruning_step=1",
        "pruning_frequency=1",
        "end_pruning_step=100",
        "target_sparsity=0.5",
        "weight_sparsity_map=[layer2/weights:0.75]",
        "threshold_decay=0.0",
    ))
    hparams = pruning.get_pruning_hparams().parse(spec)

    # Build identical masked weights under each scope, layer1 first.
    for scope_name in ("layer1", "layer2"):
        with variable_scope.variable_scope(scope_name):
            weights = variables.Variable(
                math_ops.linspace(1.0, 100.0, 100), name="weights")
            pruning.apply_mask(weights)

    pruner = pruning.Pruning(hparams)
    update_masks = pruner.conditional_mask_update_op()
    bump_step = state_ops.assign_add(self.global_step, 1)

    with self.cached_session() as session:
        variables.global_variables_initializer().run()
        # Step past end_pruning_step=100, updating masks before each increment.
        steps_left = 110
        while steps_left > 0:
            session.run(update_masks)
            session.run(bump_step)
            steps_left -= 1

        observed = session.run(pruning.get_weight_sparsity())
        self.assertAllEqual(observed, [0.5, 0.75])
def __init__(self, input_size, output_size, model_path: str, momentum=0.9, reg_str=0.0005, scope='ConvNet',
             pruning_start=int(10e4), pruning_end=int(10e5), pruning_freq=int(10), sparsity_start=0,
             sparsity_end=int(10e5), target_sparsity=0.0, dropout=0.5, initial_sparsity=0, wd=0.0):
    """Build the network graph together with its pruning machinery.

    The constructor stores the training hyper-parameters, builds the model,
    loss, optimizer and metrics inside `self.graph`, then creates a
    `pruning.Pruning` object driven by `self.global_step` and initializes
    the graph's global variables.

    Args:
        input_size, output_size, model_path: forwarded to the base class.
        momentum, reg_str, dropout, wd: stored on the instance as-is.
        scope: stored as `self.scope`; also used as the logger name and as
            the `name` pruning hyper-parameter.
        pruning_start, pruning_end, pruning_freq: begin/end pruning step and
            pruning frequency hyper-parameters.
        sparsity_start, sparsity_end: begin/end step of the sparsity function.
        target_sparsity, initial_sparsity: sparsity hyper-parameters.
    """
    super(ConvNet, self).__init__(input_size=input_size, output_size=output_size, model_path=model_path)

    self.scope = scope
    self.momentum = momentum
    self.reg_str = reg_str
    self.dropout = dropout
    self.logger = get_logger(scope)
    self.wd = wd

    self.logger.info("creating graph...")
    with self.graph.as_default():
        self.global_step = tf.Variable(0, trainable=False)
        self._build_placeholders()
        self.logits = self._build_model()
        self.weights_matrices = pruning.get_masked_weights()
        self.sparsity = pruning.get_weight_sparsity()
        self.loss = self._loss()
        self.train_op = self._optimizer()
        self._create_metrics()
        self.saver = tf.train.Saver(var_list=tf.global_variables())

        # Pruning hyper-parameter spec; the sparsity function exponent is
        # fixed at 3.
        hparams_template = ('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                            ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                            'pruning_frequency={},initial_sparsity={},'
                            ' sparsity_function_exponent={}')
        hparams_spec = hparams_template.format(scope,
                                               pruning_start,
                                               pruning_end,
                                               target_sparsity,
                                               sparsity_start,
                                               sparsity_end,
                                               pruning_freq,
                                               initial_sparsity,
                                               3)
        self.hparams = pruning.get_pruning_hparams().parse(hparams_spec)

        # The pruning schedule is keyed on the global step: as the step
        # advances, the applied sparsity moves toward the target.
        self.pruning_obj = pruning.Pruning(self.hparams, global_step=self.global_step)

        # Running this op is what actually prunes the model (updates masks).
        self.mask_update_op = self.pruning_obj.conditional_mask_update_op()

        # Initialize every global variable currently in the graph.
        self.init_variables(tf.global_variables())
def __init__(self, actor_input_dim, actor_output_dim, model_path, redundancy=None, last_measure=10e4, tau=0.01):
    """Build the student actor graph and attach pruning to it.

    Builds placeholders, actor logits, loss and train op inside
    `self.graph`, then configures a `pruning.Pruning` object from the
    module-level `cfg` settings, driven by `self.actor_global_step`.

    Args:
        actor_input_dim: stored as `(None, actor_input_dim)`.
        actor_output_dim: stored as `(None, actor_output_dim)`.
        model_path: forwarded to the base class.
        redundancy: stored as `self.redundancy`.
        last_measure: stored as `self.last_measure`.
        tau: stored as `self.tau`.
    """
    super(StudentActor, self).__init__(model_path=model_path)

    self.actor_input_dim = (None, actor_input_dim)
    self.actor_output_dim = (None, actor_output_dim)
    self.tau = tau
    self.redundancy = redundancy
    self.last_measure = last_measure

    with self.graph.as_default():
        self.actor_global_step = tf.Variable(0, trainable=False)
        self._build_placeholders()
        self.actor_logits = self._build_actor()
        # Disabled: self.gumbel_dist = self._build_gumbel(self.actor_logits)
        self.loss = self._build_loss()
        self.actor_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor')
        self.actor_pruned_weight_matrices = pruning.get_masked_weights()
        self.actor_train_op = self._build_actor_train_op()
        self.actor_saver = tf.train.Saver(var_list=self.actor_parameters, max_to_keep=100)
        self.init_variables(tf.global_variables())
        self.sparsity = pruning.get_weight_sparsity()

        # Pruning hyper-parameter spec taken from `cfg`; the sparsity
        # function exponent is fixed at 3.
        hparams_template = ('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                            ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                            'pruning_frequency={},initial_sparsity={},'
                            ' sparsity_function_exponent={}')
        hparams_spec = hparams_template.format('Actor',
                                               cfg.pruning_start,
                                               cfg.pruning_end,
                                               cfg.target_sparsity,
                                               cfg.sparsity_start,
                                               cfg.sparsity_end,
                                               cfg.pruning_freq,
                                               cfg.initial_sparsity,
                                               3)
        self.hparams = pruning.get_pruning_hparams().parse(hparams_spec)

        # The pruning schedule is keyed on the global step: as the step
        # advances, the applied sparsity moves toward the target.
        self.pruning_obj = pruning.Pruning(self.hparams, global_step=self.actor_global_step)

        # Running this op is what actually prunes the model (updates masks).
        self.mask_update_op = self.pruning_obj.conditional_mask_update_op()

        # Initialize the global variables again now that the pruning object
        # has been added to the graph.
        self.init_variables(tf.global_variables())
def main(unused_args):
    """Train and evaluate an MNIST network with the configured sparse method.

    Loads MNIST, optionally drops input columns according to a saved
    `layer1/mask`, builds the train/test losses, wraps the optimizer according
    to `FLAGS.training_method`, then runs training while logging test loss,
    accuracy and sparsity to stdout, a text file and TensorBoard summaries.

    Args:
        unused_args: ignored.

    Raises:
        RuntimeError: if `FLAGS.network_type` or `FLAGS.optimizer` is unknown.
        ValueError: if `FLAGS.training_method` is not supported.
    """
    tf.set_random_seed(FLAGS.seed)
    tf.get_variable_scope().set_use_resource(True)
    np.random.seed(FLAGS.seed)

    # Load the MNIST data and set up an iterator.
    mnist_data = input_data.read_data_sets(
        FLAGS.mnist, one_hot=False, validation_size=0)
    train_images = mnist_data.train.images
    test_images = mnist_data.test.images
    if FLAGS.input_mask_path:
        # Keep only the input columns whose row in the saved first-layer mask
        # has at least one non-zero entry.
        reader = tf.train.load_checkpoint(FLAGS.input_mask_path)
        input_mask = reader.get_tensor('layer1/mask')
        indices = np.sum(input_mask, axis=1) != 0
        train_images = train_images[:, indices]
        test_images = test_images[:, indices]

    dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, mnist_data.train.labels.astype(np.int32)))
    num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size
    dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0])
    batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
    iterator = batched_dataset.make_one_shot_iterator()

    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_images, mnist_data.test.labels.astype(np.int32)))
    num_test_images = mnist_data.test.images.shape[0]
    test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images)
    test_iterator = test_dataset.make_one_shot_iterator()

    # Set up loss function.
    use_model_pruning = FLAGS.training_method != 'baseline'
    if FLAGS.network_type == 'fc':
        cross_entropy_train, _ = mnist_network_fc(
            iterator.get_next(), model_pruning=use_model_pruning)
        cross_entropy_test, accuracy_test = mnist_network_fc(
            test_iterator.get_next(), reuse=True,
            model_pruning=use_model_pruning)
    else:
        # The flag is `network_type`; reading `FLAGS.network` here would fail
        # with an AttributeError before this message could be raised.
        raise RuntimeError(
            FLAGS.network_type + ' is an unknown network type.')

    # Remove extra added ones. Current implementation adds the variables twice
    # to the collection. Improve this hacky thing.
    # TODO test the following with the convnet or any other network.
    if use_model_pruning:
        for k in ('masks', 'masked_weights', 'thresholds', 'kernel'):
            collection = tf.get_collection_ref(k)
            # Drop the second half of the collection (the duplicates).
            del collection[len(collection) // 2:]
            print(tf.get_collection_ref(k))

    # Set up optimizer and update ops.
    global_step = tf.train.get_or_create_global_step()
    batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size

    if FLAGS.optimizer != 'adam':
        if not use_model_pruning:
            boundaries = [
                int(round(s * batch_per_epoch)) for s in [60, 70, 80]
            ]
        else:
            boundaries = [
                int(round(s * batch_per_epoch))
                for s in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]
            ]
        # Divide the learning rate by 3 at each boundary.
        learning_rate = tf.train.piecewise_constant(
            global_step,
            boundaries,
            values=[
                FLAGS.learning_rate / (3.**i)
                for i in range(len(boundaries) + 1)
            ])
    else:
        learning_rate = FLAGS.learning_rate

    if FLAGS.optimizer == 'adam':
        opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    elif FLAGS.optimizer == 'momentum':
        opt = tf.train.MomentumOptimizer(
            learning_rate, FLAGS.momentum, use_nesterov=FLAGS.use_nesterov)
    elif FLAGS.optimizer == 'sgd':
        opt = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise RuntimeError(FLAGS.optimizer + ' is unknown optimizer type')

    custom_sparsities = {
        'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale,
        'layer3': FLAGS.end_sparsity * 0
    }

    if FLAGS.training_method == 'set':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseSETOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'static':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseStaticOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'momentum':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseMomentumOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            momentum=FLAGS.s_momentum,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            grow_init=FLAGS.grow_init,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            use_tpu=False)
    elif FLAGS.training_method == 'rigl':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseRigLOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            initial_acc_scale=FLAGS.rigl_acc_scale,
            use_tpu=False)
    elif FLAGS.training_method == 'snip':
        opt = sparse_optimizers.SparseSnipOptimizer(
            opt,
            mask_init_method=FLAGS.mask_init_method,
            default_sparsity=FLAGS.end_sparsity,
            custom_sparsity_map=custom_sparsities,
            use_tpu=False)
    elif FLAGS.training_method in ('scratch', 'baseline', 'prune'):
        pass
    else:
        raise ValueError(
            'Unsupported pruning method: %s' % FLAGS.training_method)

    train_op = opt.minimize(cross_entropy_train, global_step=global_step)

    if FLAGS.training_method == 'prune':
        hparams_string = (
            'begin_pruning_step={0},sparsity_function_begin_step={0},'
            'end_pruning_step={1},sparsity_function_end_step={1},'
            'target_sparsity={2},pruning_frequency={3},'
            'threshold_decay={4}'.format(FLAGS.prune_begin_step,
                                         FLAGS.prune_end_step,
                                         FLAGS.end_sparsity,
                                         FLAGS.pruning_frequency,
                                         FLAGS.threshold_decay))
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
        pruning_hparams.set_hparam(
            'weight_sparsity_map',
            ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()])
        print(pruning_hparams)
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
        # Chain the mask update after the optimizer step so a single
        # `train_op` run performs both.
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()

    weight_sparsity_levels = pruning.get_weight_sparsity()
    global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks())
    tf.summary.scalar('test_accuracy', accuracy_test)
    tf.summary.scalar('global_sparsity', global_sparsity)
    for k, v in zip(pruning.get_masks(), weight_sparsity_levels):
        tf.summary.scalar('sparsity/%s' % k.name, v)

    if FLAGS.training_method in ('prune', 'snip', 'baseline'):
        mask_init_op = tf.no_op()
        tf.logging.info('No mask is set, starting dense.')
    else:
        all_masks = pruning.get_masks()
        mask_init_op = sparse_utils.get_mask_init_fn(
            all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity,
            custom_sparsities)

    if FLAGS.save_model:
        saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    hyper_params_string = '_'.join([
        FLAGS.network_type,
        str(FLAGS.batch_size),
        str(FLAGS.learning_rate),
        str(FLAGS.momentum),
        FLAGS.optimizer,
        str(FLAGS.l2_scale),
        FLAGS.training_method,
        str(FLAGS.prune_begin_step),
        str(FLAGS.prune_end_step),
        str(FLAGS.end_sparsity),
        str(FLAGS.pruning_frequency),
        str(FLAGS.seed)
    ])
    tf.io.gfile.makedirs(FLAGS.save_path)
    filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt')
    merged_summary_op = tf.summary.merge_all()

    # Run session.
    if not use_model_pruning:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy')
            sess.run([init_op])
            tic = time.time()
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                for i in range(FLAGS.num_epochs * num_batches):
                    sess.run([train_op])

                    # True on the last batch of every epoch.
                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %.4f, %.4f, %.4f' % (
                            i // num_batches, epoch_time, loss, accuracy)
                        print(log_str)
                        print(log_str, file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
            # Flush and release the event file.
            summary_writer.close()
    else:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            log_str = ','.join([
                'Epoch', 'Iteration', 'Test loss', 'Test accuracy',
                'G_Sparsity', 'Sparsity Layer 0', 'Sparsity Layer 1'
            ])
            sess.run(init_op)
            sess.run(mask_init_op)
            tic = time.time()
            mask_records = {}
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                print(log_str)
                print(log_str, file=outputfile)
                for i in range(FLAGS.num_epochs * num_batches):
                    if (FLAGS.mask_record_frequency > 0 and
                            i % FLAGS.mask_record_frequency == 0):
                        mask_vals = sess.run(pruning.get_masks())
                        # Cast into bool to save space. The `np.bool` alias
                        # was removed from NumPy, so use the builtin `bool`.
                        mask_records[i] = [a.astype(bool) for a in mask_vals]
                    sess.run([train_op])
                    weight_sparsity, global_sparsity_val = sess.run(
                        [weight_sparsity_levels, global_sparsity])

                    # True on the last batch of every epoch.
                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % (
                            i // num_batches, i, loss, accuracy,
                            global_sparsity_val, weight_sparsity[0],
                            weight_sparsity[1])
                        print(log_str)
                        print(log_str, file=outputfile)
                        mask_vals = sess.run(pruning.get_masks())
                        if FLAGS.network_type == 'fc':
                            sparsities, sizes = get_compressed_fc(mask_vals)
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes))
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes),
                                  file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
            if mask_records:
                np.save(
                    os.path.join(FLAGS.save_path, 'mask_records'),
                    mask_records)
            # Flush and release the event file.
            summary_writer.close()