def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                 freq_iter=2):
  """Sets up a trivial training procedure for sparse training."""
  tf.reset_default_graph()
  optim = tf.train.GradientDescentOptimizer(0.1)
  sparse_optim = sparse_optimizers.SparseStaticOptimizer(
      optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
  x = tf.random.uniform((1, n_inp))
  y = layers.masked_fully_connected(x, n_out, activation_fn=None)
  # There is one masked layer to be trained.
  weight = pruning.get_weights()[0]
  mask = pruning.get_masks()[0]
  # Around half of the values of the mask are set to zero with `mask_update`.
  mask_update = tf.assign(
      mask,
      tf.constant(
          np.random.choice([0, 1], size=(n_inp, n_out), p=[1. / 2, 1. / 2]),
          dtype=tf.float32))
  loss = tf.reduce_mean(y)
  global_step = tf.train.get_or_create_global_step()
  train_op = sparse_optim.minimize(loss, global_step)
  # Init
  sess = tf.Session()
  init = tf.global_variables_initializer()
  sess.run(init)
  sess.run([mask_update])
  return sess, train_op, mask, weight, global_step
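# A sketch of a check this fixture enables (the helper name and step count
# are illustrative, not from the original tests): a *static* sparse optimizer
# may drop and regrow connections, but the number of nonzeros in the mask
# should stay constant across training steps.
def _check_sparsity_is_preserved(self):
  sess, train_op, mask, _, _ = self._setup_graph(
      n_inp=4, n_out=6, drop_frac=0.3)
  nnz_before = np.count_nonzero(sess.run(mask))
  for _ in range(5):
    sess.run([train_op])
  nnz_after = np.count_nonzero(sess.run(mask))
  self.assertEqual(nnz_before, nnz_after)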
def snip_op():
  all_masks = pruning.get_masks()
  assigner = sparse_utils.get_mask_init_fn(
      all_masks,
      self._mask_init_method,
      self._default_sparsity,
      self._custom_sparsity_map,
      mask_fn=snip_fn)
  with ops.control_dependencies([assigner]):
    assign_op = state_ops.assign(
        self.is_snipped, True, name='assign_true_after_snipped')
  return assign_op
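# A sketch (wiring assumed, not verbatim from the optimizer) of how
# `snip_op` can be gated so SNIP initialization runs only once. Both cond
# branches return a boolean tensor, so once `self.is_snipped` is True the
# conditional is effectively a no-op; `control_flow_ops` is assumed imported
# alongside `ops` and `state_ops` above.
maybe_snip_op = control_flow_ops.cond(
    self.is_snipped, lambda: tf.constant(True), snip_op)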
def check_global_sparsity():
  """Add a summary for the global weight sparsity."""
  weight_masks = magnitude_pruning.get_masks()
  weights_per_layer = []
  nonzero_per_layer = []
  for mask in weight_masks:
    nonzero_per_layer.append(tf.reduce_sum(mask))
    weights_per_layer.append(tf.size(mask))
  total_nonzero = tf.add_n(nonzero_per_layer)
  total_weights = tf.add_n(weights_per_layer)
  sparsity = (1.0 - (tf.cast(total_nonzero, tf.float32) /
                     tf.cast(total_weights, tf.float32)))
  tf.summary.scalar("global_weight_sparsity", sparsity)
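# A worked example of the formula above in plain numpy (shapes made up): one
# 10x10 mask with 40 nonzeros plus one 20x5 mask with 10 nonzeros gives
# sparsity = 1 - (40 + 10) / (100 + 100) = 0.75.
def _global_sparsity_np(masks):
  total_nonzero = sum(int(np.count_nonzero(m)) for m in masks)
  total_weights = sum(m.size for m in masks)
  return 1.0 - float(total_nonzero) / total_weights

_m1 = np.zeros((10, 10))
_m1.flat[:40] = 1.0
_m2 = np.zeros((20, 5))
_m2.flat[:10] = 1.0
assert _global_sparsity_np([_m1, _m2]) == 0.75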
def scaffold_fn():
  """For initialization, passed to the estimator."""
  if FLAGS.initial_value_checkpoint:
    initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint)
  all_masks = pruning.get_masks()
  assigner = sparse_utils.get_mask_init_fn(
      all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity,
      CUSTOM_SPARSITY_MAP)

  def init_fn(scaffold, session):
    """A callable for restoring variables from a checkpoint."""
    del scaffold  # Unused.
    session.run(assigner)

  return tf.train.Scaffold(init_fn=init_fn)
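# A sketch of where `scaffold_fn` plugs in: `TPUEstimatorSpec` accepts a
# `scaffold_fn` callable, so the mask assignment above runs once when the
# session is created. `mode`, `loss`, and `train_op` are assumed to exist in
# the enclosing model_fn.
spec = contrib_tpu.TPUEstimatorSpec(
    mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn)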
def testMaskedLSTMCell(self):
  expected_num_masks = 1
  expected_num_rows = 2 * self.dim
  expected_num_cols = 4 * self.dim
  with self.cached_session():
    inputs = variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))
    c = variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))
    h = variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))
    state = tf_rnn_cells.LSTMStateTuple(c, h)
    lstm_cell = rnn_cells.MaskedLSTMCell(self.dim)
    lstm_cell(inputs, state)

    self.assertEqual(len(pruning.get_masks()), expected_num_masks)
    self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
    self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
    self.assertEqual(len(pruning.get_weights()), expected_num_masks)

    for mask in pruning.get_masks():
      self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
    for weight in pruning.get_weights():
      self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))
def metric_fn(labels, logits, cross_loss, reg_loss):
  """Calculates eval metrics."""
  logging.info('In metric function')
  eval_metrics = {}
  predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
  in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
  eval_metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
  eval_metrics['cross_loss'] = tf.metrics.mean(cross_loss)
  eval_metrics['reg_loss'] = tf.metrics.mean(reg_loss)
  eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
      labels=labels, predictions=predictions)
  # If evaluating once, also calculate the sparsities.
  if FLAGS.mode == 'eval_once':
    sparsity_summaries = utils.mask_summaries(pruning.get_masks())
    # We call mean on a scalar to create the (tensor, update_op) pairs.
    sparsity_summaries = {
        k: tf.metrics.mean(v) for k, v in sparsity_summaries.items()
    }
    eval_metrics.update(sparsity_summaries)
  return eval_metrics
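# A sketch of how `metric_fn` is consumed on TPU (surrounding tensors are
# assumed from the enclosing model_fn): TPUEstimatorSpec takes a
# `(metric_fn, tensor_list)` pair and calls the function with those tensors.
eval_metrics = (metric_fn, [labels, logits, cross_loss, reg_loss])
spec = contrib_tpu.TPUEstimatorSpec(
    mode=tf.estimator.ModeKeys.EVAL, loss=loss, eval_metrics=eval_metrics)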
def _setup_graph(self, default_sparsity, mask_init_method,
                 custom_sparsity_map, n_inp=3, n_out=5):
  """Sets up a trivial training procedure for sparse training."""
  tf.reset_default_graph()
  optim = tf.train.GradientDescentOptimizer(1e-3)
  sparse_optim = sparse_optimizers.SparseDNWOptimizer(
      optim, default_sparsity, mask_init_method,
      custom_sparsity_map=custom_sparsity_map)
  inp_values = np.arange(1, n_inp + 1)
  scale_vector_values = np.random.uniform(size=(n_out,)) - 0.5
  # The gradient is the outer product of the input and the output gradients.
  # Since the loss is the sum over the outputs, the output gradient is equal
  # to the scale vector.
  expected_grads = np.outer(inp_values, scale_vector_values)
  x = tf.reshape(tf.constant(inp_values, dtype=tf.float32), (1, n_inp))
  y = layers.masked_fully_connected(x, n_out, activation_fn=None)
  scale_vector = tf.constant(scale_vector_values, dtype=tf.float32)
  y = y * scale_vector
  loss = tf.reduce_sum(y)
  global_step = tf.train.get_or_create_global_step()
  grads_and_vars = sparse_optim.compute_gradients(loss)
  train_op = sparse_optim.apply_gradients(
      grads_and_vars, global_step=global_step)
  # Init
  sess = tf.Session()
  init = tf.global_variables_initializer()
  sess.run(init)
  mask = pruning.get_masks()[0]
  weights = pruning.get_weights()[0]
  return (sess, train_op, (expected_grads, grads_and_vars), mask, weights)
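# A sketch (hypothetical test body, not from the original suite) of how the
# returned tuple can be used: DNW-style optimizers compute dense gradients,
# so the weight gradient should match the analytic outer product. The
# `mask_init_method` value is an illustrative assumption.
def _check_gradients_match(self):
  (sess, _, (expected_grads, grads_and_vars), _,
   weights) = self._setup_graph(
       default_sparsity=0.5, mask_init_method='random',
       custom_sparsity_map={})
  weight_grad = [g for g, v in grads_and_vars if v is weights][0]
  np.testing.assert_allclose(
      sess.run(weight_grad), expected_grads, rtol=1e-5)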
def main(unused_args):
  tf.set_random_seed(FLAGS.seed)
  tf.get_variable_scope().set_use_resource(True)
  np.random.seed(FLAGS.seed)
  # Load the MNIST data and set up an iterator.
  mnist_data = input_data.read_data_sets(
      FLAGS.mnist, one_hot=False, validation_size=0)
  train_images = mnist_data.train.images
  test_images = mnist_data.test.images
  if FLAGS.input_mask_path:
    reader = tf.train.load_checkpoint(FLAGS.input_mask_path)
    input_mask = reader.get_tensor('layer1/mask')
    indices = np.sum(input_mask, axis=1) != 0
    train_images = train_images[:, indices]
    test_images = test_images[:, indices]
  dataset = tf.data.Dataset.from_tensor_slices(
      (train_images, mnist_data.train.labels.astype(np.int32)))
  num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size
  dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0])
  batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
  iterator = batched_dataset.make_one_shot_iterator()

  test_dataset = tf.data.Dataset.from_tensor_slices(
      (test_images, mnist_data.test.labels.astype(np.int32)))
  num_test_images = mnist_data.test.images.shape[0]
  test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images)
  test_iterator = test_dataset.make_one_shot_iterator()

  # Set up the loss function.
  use_model_pruning = FLAGS.training_method != 'baseline'
  if FLAGS.network_type == 'fc':
    cross_entropy_train, _ = mnist_network_fc(
        iterator.get_next(), model_pruning=use_model_pruning)
    cross_entropy_test, accuracy_test = mnist_network_fc(
        test_iterator.get_next(), reuse=True, model_pruning=use_model_pruning)
  else:
    raise RuntimeError(FLAGS.network_type + ' is an unknown network type.')

  # The current implementation adds the variables to the collections twice,
  # so remove the duplicated entries.
  # TODO: test the following with the convnet or any other network.
  if use_model_pruning:
    for k in ('masks', 'masked_weights', 'thresholds', 'kernel'):
      collection = tf.get_collection_ref(k)
      del collection[len(collection) // 2:]
      print(tf.get_collection_ref(k))

  # Set up the optimizer and update ops.
  global_step = tf.train.get_or_create_global_step()
  batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size

  if FLAGS.optimizer != 'adam':
    if not use_model_pruning:
      boundaries = [int(round(s * batch_per_epoch)) for s in [60, 70, 80]]
    else:
      boundaries = [
          int(round(s * batch_per_epoch))
          for s in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]
      ]
    learning_rate = tf.train.piecewise_constant(
        global_step,
        boundaries,
        values=[
            FLAGS.learning_rate / (3.**i) for i in range(len(boundaries) + 1)
        ])
  else:
    learning_rate = FLAGS.learning_rate

  if FLAGS.optimizer == 'adam':
    opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
  elif FLAGS.optimizer == 'momentum':
    opt = tf.train.MomentumOptimizer(
        learning_rate, FLAGS.momentum, use_nesterov=FLAGS.use_nesterov)
  elif FLAGS.optimizer == 'sgd':
    opt = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    raise RuntimeError(FLAGS.optimizer + ' is an unknown optimizer type.')

  custom_sparsities = {
      'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale,
      'layer3': FLAGS.end_sparsity * 0
  }

  if FLAGS.training_method == 'set':
    # We override the train op to also update the mask.
    opt = sparse_optimizers.SparseSETOptimizer(
        opt,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif FLAGS.training_method == 'static':
    # We override the train op to also update the mask.
    opt = sparse_optimizers.SparseStaticOptimizer(
        opt,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif FLAGS.training_method == 'momentum':
    # We override the train op to also update the mask.
    opt = sparse_optimizers.SparseMomentumOptimizer(
        opt,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        use_tpu=False)
  elif FLAGS.training_method == 'rigl':
    # We override the train op to also update the mask.
    opt = sparse_optimizers.SparseRigLOptimizer(
        opt,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale,
        use_tpu=False)
  elif FLAGS.training_method == 'snip':
    opt = sparse_optimizers.SparseSnipOptimizer(
        opt,
        mask_init_method=FLAGS.mask_init_method,
        default_sparsity=FLAGS.end_sparsity,
        custom_sparsity_map=custom_sparsities,
        use_tpu=False)
  elif FLAGS.training_method in ('scratch', 'baseline', 'prune'):
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method)

  train_op = opt.minimize(cross_entropy_train, global_step=global_step)

  if FLAGS.training_method == 'prune':
    hparams_string = (
        'begin_pruning_step={0},sparsity_function_begin_step={0},'
        'end_pruning_step={1},sparsity_function_end_step={1},'
        'target_sparsity={2},pruning_frequency={3},'
        'threshold_decay={4}'.format(FLAGS.prune_begin_step,
                                     FLAGS.prune_end_step, FLAGS.end_sparsity,
                                     FLAGS.pruning_frequency,
                                     FLAGS.threshold_decay))
    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
    pruning_hparams.set_hparam(
        'weight_sparsity_map',
        ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()])
    print(pruning_hparams)
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
    with tf.control_dependencies([train_op]):
      train_op = pruning_obj.conditional_mask_update_op()

  weight_sparsity_levels = pruning.get_weight_sparsity()
  global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks())
  tf.summary.scalar('test_accuracy', accuracy_test)
  tf.summary.scalar('global_sparsity', global_sparsity)
  for k, v in zip(pruning.get_masks(), weight_sparsity_levels):
    tf.summary.scalar('sparsity/%s' % k.name, v)

  if FLAGS.training_method in ('prune', 'snip', 'baseline'):
    mask_init_op = tf.no_op()
    tf.logging.info('No mask is set, starting dense.')
  else:
    all_masks = pruning.get_masks()
    mask_init_op = sparse_utils.get_mask_init_fn(
        all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity,
        custom_sparsities)

  if FLAGS.save_model:
    saver = tf.train.Saver()
  init_op = tf.global_variables_initializer()
  hyper_params_string = '_'.join([
      FLAGS.network_type,
      str(FLAGS.batch_size),
      str(FLAGS.learning_rate),
      str(FLAGS.momentum), FLAGS.optimizer,
      str(FLAGS.l2_scale), FLAGS.training_method,
      str(FLAGS.prune_begin_step),
      str(FLAGS.prune_end_step),
      str(FLAGS.end_sparsity),
      str(FLAGS.pruning_frequency),
      str(FLAGS.seed)
  ])
  tf.io.gfile.makedirs(FLAGS.save_path)
  filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt')
  merged_summary_op = tf.summary.merge_all()

  # Run the session.
  if not use_model_pruning:
    with tf.Session() as sess:
      summary_writer = tf.summary.FileWriter(
          FLAGS.save_path, graph=tf.get_default_graph())
      print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy')
      sess.run([init_op])
      tic = time.time()
      with tf.io.gfile.GFile(filename, 'w') as outputfile:
        for i in range(FLAGS.num_epochs * num_batches):
          sess.run([train_op])
          # Evaluate at the last batch of each epoch.
          if (i % num_batches) == (-1 % num_batches):
            epoch_time = time.time() - tic
            loss, accuracy, summary = sess.run(
                [cross_entropy_test, accuracy_test, merged_summary_op])
            # Write logs at every test iteration.
            summary_writer.add_summary(summary, i)
            log_str = '%d, %.4f, %.4f, %.4f' % (i // num_batches, epoch_time,
                                                loss, accuracy)
            print(log_str)
            print(log_str, file=outputfile)
            tic = time.time()
      if FLAGS.save_model:
        saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
  else:
    with tf.Session() as sess:
      summary_writer = tf.summary.FileWriter(
          FLAGS.save_path, graph=tf.get_default_graph())
      log_str = ','.join([
          'Epoch', 'Iteration', 'Test loss', 'Test accuracy', 'G_Sparsity',
          'Sparsity Layer 0', 'Sparsity Layer 1'
      ])
      sess.run(init_op)
      sess.run(mask_init_op)
      tic = time.time()
      mask_records = {}
      with tf.io.gfile.GFile(filename, 'w') as outputfile:
        print(log_str)
        print(log_str, file=outputfile)
        for i in range(FLAGS.num_epochs * num_batches):
          if (FLAGS.mask_record_frequency > 0 and
              i % FLAGS.mask_record_frequency == 0):
            mask_vals = sess.run(pruning.get_masks())
            # Cast into bool to save space.
            mask_records[i] = [a.astype(bool) for a in mask_vals]
          sess.run([train_op])
          weight_sparsity, global_sparsity_val = sess.run(
              [weight_sparsity_levels, global_sparsity])
          # Evaluate at the last batch of each epoch.
          if (i % num_batches) == (-1 % num_batches):
            epoch_time = time.time() - tic
            loss, accuracy, summary = sess.run(
                [cross_entropy_test, accuracy_test, merged_summary_op])
            # Write logs at every test iteration.
            summary_writer.add_summary(summary, i)
            log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % (
                i // num_batches, i, loss, accuracy, global_sparsity_val,
                weight_sparsity[0], weight_sparsity[1])
            print(log_str)
            print(log_str, file=outputfile)
            mask_vals = sess.run(pruning.get_masks())
            if FLAGS.network_type == 'fc':
              sparsities, sizes = get_compressed_fc(mask_vals)
              print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                    (sparsities, sizes))
              print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                    (sparsities, sizes), file=outputfile)
            tic = time.time()
      if FLAGS.save_model:
        saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
      if mask_records:
        np.save(os.path.join(FLAGS.save_path, 'mask_records'), mask_records)
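# An offline sketch (not part of the training script) for inspecting the
# saved `mask_records`: the path mirrors the `np.save` call above, and
# `allow_pickle` is needed because the records are a dict of lists of
# boolean arrays.
records = np.load(
    os.path.join(FLAGS.save_path, 'mask_records.npy'),
    allow_pickle=True).item()
for step in sorted(records):
  per_layer_sparsity = [1.0 - m.mean() for m in records[step]]
  print(step, ['%.3f' % s for s in per_layer_sparsity])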
def train_function(training_method, loss, cross_loss, reg_loss, output_dir,
                   use_tpu):
  """Training script for resnet model.

  Args:
    training_method: string indicating pruning method used to compress model.
    loss: tensor float32 of the cross entropy + regularization losses.
    cross_loss: tensor, only cross entropy loss, passed for logging.
    reg_loss: tensor, only regularization loss, passed for logging.
    output_dir: string tensor indicating the directory to save summaries.
    use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.
  """
  global_step = tf.train.get_global_step()
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
  learning_rate = lr_schedule(current_epoch)
  if FLAGS.use_adam:
    # We don't use step decrease for the learning rate.
    learning_rate = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  else:
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate,
        momentum=FLAGS.momentum,
        use_nesterov=True)

  if use_tpu:
    # Use CrossShardOptimizer when using TPU.
    optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

  if training_method == 'set':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseSETOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        stateless_seed_offset=FLAGS.seed)
  elif training_method == 'static':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseStaticOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        stateless_seed_offset=FLAGS.seed)
  elif training_method == 'momentum':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseMomentumOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init,
        stateless_seed_offset=FLAGS.seed,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        use_tpu=use_tpu)
  elif training_method == 'rigl':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseRigLOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        stateless_seed_offset=FLAGS.seed,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale,
        use_tpu=use_tpu)
  elif training_method == 'snip':
    optimizer = sparse_optimizers.SparseSnipOptimizer(
        optimizer,
        mask_init_method=FLAGS.mask_init_method,
        custom_sparsity_map=CUSTOM_SPARSITY_MAP,
        default_sparsity=FLAGS.end_sparsity,
        use_tpu=use_tpu)
  elif training_method in ('scratch', 'baseline'):
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % training_method)

  # UPDATE_OPS needs to be added as a dependency due to batch norm.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops), tf.name_scope('train'):
    train_op = optimizer.minimize(loss, global_step)

  metrics = {
      'global_step': tf.train.get_or_create_global_step(),
      'loss': loss,
      'cross_loss': cross_loss,
      'reg_loss': reg_loss,
      'learning_rate': learning_rate,
      'current_epoch': current_epoch,
  }

  # Log drop_fraction when doing dynamic sparse training.
  if training_method in ('set', 'momentum', 'rigl', 'static'):
    metrics['drop_fraction'] = optimizer.drop_fraction

  # Log some statistics from a single parameter-mask pair.
  # This is useful for debugging.
  test_var = pruning.get_weights()[0]
  test_var_mask = pruning.get_masks()[0]
  metrics.update({
      'fw_nz_weight': tf.count_nonzero(test_var),
      'fw_nz_mask': tf.count_nonzero(test_var_mask),
      'fw_l1_weight': tf.reduce_sum(tf.abs(test_var))
  })

  masks = pruning.get_masks()
  global_sparsity = sparse_utils.calculate_sparsity(masks)
  metrics['global_sparsity'] = global_sparsity
  metrics.update(
      utils.mask_summaries(
          masks[:4] + masks[-1:], with_img=FLAGS.log_mask_imgs_each_iteration))

  host_call = (functools.partial(utils.host_call_fn, output_dir),
               utils.format_tensors(metrics))

  return host_call, train_op
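# A sketch of how the returned pair is consumed (surrounding names assumed
# from the enclosing model_fn): TPUEstimatorSpec runs `host_call` on the
# host with the per-step tensors, which is where the metrics above end up
# as summaries.
host_call, train_op = train_function(
    FLAGS.training_method, loss, cross_loss, reg_loss, output_dir, use_tpu)
spec = contrib_tpu.TPUEstimatorSpec(
    mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op,
    host_call=host_call)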
def train_function(pruning_method, loss, output_dir, use_tpu):
  """Training script for resnet model.

  Args:
    pruning_method: string indicating pruning method used to compress model.
    loss: tensor float32 of the cross entropy + regularization losses.
    output_dir: string tensor indicating the directory to save summaries.
    use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.
  """
  global_step = tf.train.get_global_step()
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
  learning_rate = lr_schedule(current_epoch)
  optimizer = tf.train.MomentumOptimizer(
      learning_rate=learning_rate, momentum=FLAGS.momentum, use_nesterov=True)

  if use_tpu:
    # Use CrossShardOptimizer when using TPU.
    optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
  elif FLAGS.num_workers > 0:
    # Wrap the optimizer before building the train op so the replica
    # aggregation takes effect.
    optimizer = tf.train.SyncReplicasOptimizer(
        optimizer,
        replicas_to_aggregate=FLAGS.num_workers,
        total_num_replicas=FLAGS.num_workers)
    optimizer.make_session_run_hook(True)

  # UPDATE_OPS needs to be added as a dependency due to batch norm.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops), tf.name_scope('train'):
    train_op = optimizer.minimize(loss, global_step)

  metrics = {
      'global_step': tf.train.get_or_create_global_step(),
      'loss': loss,
      'learning_rate': learning_rate,
      'current_epoch': current_epoch
  }

  if pruning_method == 'threshold':
    # Construct the necessary hparams string from the FLAGS.
    hparams_string = ('begin_pruning_step={0},'
                      'sparsity_function_begin_step={0},'
                      'end_pruning_step={1},'
                      'sparsity_function_end_step={1},'
                      'target_sparsity={2},'
                      'pruning_frequency={3},'
                      'threshold_decay=0,'
                      'use_tpu={4}'.format(
                          FLAGS.sparsity_begin_step,
                          FLAGS.sparsity_end_step,
                          FLAGS.end_sparsity,
                          FLAGS.pruning_frequency,
                          FLAGS.use_tpu,
                      ))
    # Parse pruning hyperparameters.
    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

    # The first layer has so few parameters that we don't need to prune it,
    # and pruning it at higher sparsity levels has very negative effects.
    # Collect both per-layer overrides so neither one clobbers the other.
    weight_sparsity_map = []
    if FLAGS.prune_first_layer and FLAGS.first_layer_sparsity >= 0.:
      weight_sparsity_map.append(
          'resnet_model/initial_conv:%f' % FLAGS.first_layer_sparsity)
    if FLAGS.prune_last_layer and FLAGS.last_layer_sparsity >= 0:
      weight_sparsity_map.append(
          'resnet_model/final_dense:%f' % FLAGS.last_layer_sparsity)
    if weight_sparsity_map:
      pruning_hparams.set_hparam('weight_sparsity_map', weight_sparsity_map)

    # Create a pruning object using the pruning hyperparameters.
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # We override the train op to also update the mask.
    with tf.control_dependencies([train_op]):
      train_op = pruning_obj.conditional_mask_update_op()
    masks = pruning.get_masks()
    metrics.update(utils.mask_summaries(masks))
  elif pruning_method == 'scratch':
    masks = pruning.get_masks()
    # Make sure the masks have the sparsity we expect and that it doesn't
    # change during training.
    metrics.update(utils.mask_summaries(masks))
  elif pruning_method == 'variational_dropout':
    masks = utils.add_vd_pruning_summaries(
        threshold=FLAGS.log_alpha_threshold)
    metrics.update(masks)
  elif pruning_method == 'l0_regularization':
    summaries = utils.add_l0_summaries()
    metrics.update(summaries)
  elif pruning_method == 'baseline':
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % pruning_method)

  host_call = (functools.partial(utils.host_call_fn, output_dir),
               utils.format_tensors(metrics))

  return host_call, train_op
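# An illustrative round-trip of the hparams string built above (the numeric
# values are made up): parsing should populate the corresponding fields of
# the pruning hparams object.
_hparams = pruning.get_pruning_hparams().parse(
    'begin_pruning_step=1000,sparsity_function_begin_step=1000,'
    'end_pruning_step=20000,sparsity_function_end_step=20000,'
    'target_sparsity=0.9,pruning_frequency=500,threshold_decay=0,'
    'use_tpu=False')
assert _hparams.target_sparsity == 0.9
assert _hparams.pruning_frequency == 500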
def get_masks(self):
  return pruning.get_masks()
def wide_resnet_w_pruning(features, labels, mode, params):
  """The model_fn for ResNet wide with pruning.

  Args:
    features: A float32 batch of images.
    labels: An int32 batch of labels.
    mode: Specifies whether training or evaluation.
    params: Dictionary of parameters passed to the model.

  Returns:
    An EstimatorSpec for the model.

  Raises:
    ValueError: if mode is not recognized as train or eval.
  """
  if isinstance(features, dict):
    features = features['feature']

  train_dir = params['train_dir']
  training_method = params['training_method']

  global_step, accuracy, top_5_accuracy, logits = build_model(
      mode=mode,
      images=features,
      labels=labels,
      training_method=training_method,
      num_classes=FLAGS.num_classes,
      depth=FLAGS.resnet_depth,
      width=FLAGS.resnet_width)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  with tf.name_scope('computing_cross_entropy_loss'):
    entropy_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits)
    tf.summary.scalar('cross_entropy_loss', entropy_loss)

  with tf.name_scope('computing_total_loss'):
    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

  if mode == tf.estimator.ModeKeys.TRAIN:
    hooks, eval_metrics, train_op = train_fn(training_method, global_step,
                                             total_loss, train_dir, accuracy,
                                             top_5_accuracy)
  elif mode == tf.estimator.ModeKeys.EVAL:
    hooks = None
    train_op = None
    with tf.name_scope('summaries'):
      eval_metrics = create_eval_metrics(labels, logits)
  else:
    raise ValueError('mode not recognized as training or eval.')

  if FLAGS.training_method in ('prune', 'snip', 'baseline'):
    scaffold = None
    tf.logging.info('No mask is set, starting dense.')
  else:
    all_masks = pruning.get_masks()
    assigner = sparse_utils.get_mask_init_fn(
        all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity, {})

    def init_fn(scaffold, session):
      """A callable for restoring variables from a checkpoint."""
      del scaffold  # Unused.
      session.run(assigner)

    scaffold = tf.train.Scaffold(init_fn=init_fn)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      training_hooks=hooks,
      loss=total_loss,
      train_op=train_op,
      eval_metric_ops=eval_metrics,
      scaffold=scaffold)
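# A sketch of wiring the model_fn above into an Estimator; `train_input_fn`
# and the flag names here are assumptions for illustration, but `params`
# carries the two entries the function actually reads.
classifier = tf.estimator.Estimator(
    model_fn=wide_resnet_w_pruning,
    model_dir=FLAGS.train_dir,
    params={
        'train_dir': FLAGS.train_dir,
        'training_method': FLAGS.training_method
    })
classifier.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)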
def train_fn(training_method, global_step, total_loss, train_dir, accuracy,
             top_5_accuracy):
  """Training script for resnet model.

  Args:
    training_method: specifies the method used to sparsify networks.
    global_step: the current step of training/eval.
    total_loss: tensor float32 of the cross entropy + regularization losses.
    train_dir: string specifying the directory where summaries are saved.
    accuracy: tensor float32 batch classification accuracy.
    top_5_accuracy: tensor float32 batch classification accuracy (top_5
      classes).

  Returns:
    hooks: summary tensors to be computed at each training step.
    eval_metrics: set to None during training.
    train_op: the optimization term.
  """
  # Roughly drops the learning rate at every 30k steps.
  boundaries = [30000, 60000, 90000]
  if FLAGS.training_steps_multiplier != 1.0:
    multiplier = FLAGS.training_steps_multiplier
    boundaries = [int(x * multiplier) for x in boundaries]
    tf.logging.info(
        'Learning rate boundaries are updated with multiplier: %.2f',
        multiplier)
  learning_rate = tf.train.piecewise_constant(
      global_step,
      boundaries,
      values=[0.1 / (5.**i) for i in range(len(boundaries) + 1)],
      name='lr_schedule')

  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=FLAGS.momentum, use_nesterov=True)

  if training_method == 'set':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseSETOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'static':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseStaticOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'momentum':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseMomentumOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        use_tpu=False)
  elif training_method == 'rigl':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseRigLOptimizer(
        optimizer,
        begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step,
        grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale,
        use_tpu=False)
  elif training_method == 'snip':
    optimizer = sparse_optimizers.SparseSnipOptimizer(
        optimizer,
        mask_init_method=FLAGS.mask_init_method,
        default_sparsity=FLAGS.end_sparsity,
        use_tpu=False)
  elif training_method in ('scratch', 'baseline', 'prune'):
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % training_method)

  # Create the training op.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss, global_step)

  if training_method == 'prune':
    # Construct the necessary hparams string from the FLAGS.
    hparams_string = ('begin_pruning_step={0},'
                      'sparsity_function_begin_step={0},'
                      'end_pruning_step={1},'
                      'sparsity_function_end_step={1},'
                      'target_sparsity={2},'
                      'pruning_frequency={3},'
                      'threshold_decay=0,'
                      'use_tpu={4}'.format(
                          FLAGS.sparsity_begin_step,
                          FLAGS.sparsity_end_step,
                          FLAGS.end_sparsity,
                          FLAGS.pruning_frequency,
                          False,
                      ))
    # Parse pruning hyperparameters.
    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
    # Create a pruning object using the pruning hyperparameters.
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
    tf.logging.info('starting mask update op')
    # We override the train op to also update the mask.
    with tf.control_dependencies([train_op]):
      train_op = pruning_obj.conditional_mask_update_op()

  masks = pruning.get_masks()
  mask_metrics = utils.mask_summaries(masks)
  for name, tensor in mask_metrics.items():
    tf.summary.scalar(name, tensor)

  tf.summary.scalar('learning_rate', learning_rate)
  tf.summary.scalar('accuracy', accuracy)
  tf.summary.scalar('total_loss', total_loss)
  tf.summary.scalar('top_5_accuracy', top_5_accuracy)
  # Log drop_fraction when doing dynamic sparse training.
  if training_method in ('set', 'momentum', 'rigl', 'static'):
    tf.summary.scalar('drop_fraction', optimizer.drop_fraction)

  summary_op = tf.summary.merge_all()
  summary_hook = tf.train.SummarySaverHook(
      save_secs=300, output_dir=train_dir, summary_op=summary_op)
  hooks = [summary_hook]
  eval_metrics = None

  return hooks, eval_metrics, train_op
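# A pure-python mirror of the schedule defined in `train_fn` above, for
# illustration only: with boundaries [30000, 60000, 90000] and values
# 0.1 / 5**i, the learning rate is 0.1 through step 30k, then 0.02, 0.004,
# and finally 0.0008.
def _lr_at(step, boundaries=(30000, 60000, 90000), base=0.1, decay=5.0):
  drops = sum(step > b for b in boundaries)
  return base / decay**drops

assert _lr_at(0) == 0.1
assert _lr_at(45000) == 0.02
assert _lr_at(100000) == 0.1 / 125.0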