# Imports assumed by this excerpt (TF 1.x contrib layout); `utils`,
# `lr_schedule`, and `FLAGS` are defined elsewhere in the package.
import functools

import tensorflow as tf

from tensorflow.contrib import tpu as contrib_tpu
from tensorflow.contrib.model_pruning.python import pruning


def train_function(pruning_method, loss, output_dir, use_tpu):
  """Training script for resnet model.

  Args:
    pruning_method: string indicating pruning method used to compress model.
    loss: float32 tensor of the cross-entropy + regularization losses.
    output_dir: string tensor indicating the directory to save summaries.
    use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.
  """
  global_step = tf.train.get_global_step()

  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  current_epoch = tf.cast(global_step, tf.float32) / steps_per_epoch
  learning_rate = lr_schedule(current_epoch)
  optimizer = tf.train.MomentumOptimizer(
      learning_rate=learning_rate, momentum=FLAGS.momentum, use_nesterov=True)

  if use_tpu:
    # Use CrossShardOptimizer when using TPU.
    optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
  elif FLAGS.num_workers > 0:
    # Wrap the optimizer before minimize() is called so that gradient
    # aggregation across replicas actually takes effect.
    optimizer = tf.train.SyncReplicasOptimizer(
        optimizer,
        replicas_to_aggregate=FLAGS.num_workers,
        total_num_replicas=FLAGS.num_workers)
    # The returned hook should be passed to the training session.
    optimizer.make_session_run_hook(True)

  # UPDATE_OPS needs to be added as a dependency due to batch norm.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops), tf.name_scope('train'):
    train_op = optimizer.minimize(loss, global_step)

  metrics = {
      'global_step': tf.train.get_or_create_global_step(),
      'loss': loss,
      'learning_rate': learning_rate,
      'current_epoch': current_epoch
  }

  if pruning_method == 'threshold':
    # Construct the necessary hparams string from the FLAGS.
    hparams_string = ('begin_pruning_step={0},'
                      'sparsity_function_begin_step={0},'
                      'end_pruning_step={1},'
                      'sparsity_function_end_step={1},'
                      'target_sparsity={2},'
                      'pruning_frequency={3},'
                      'threshold_decay=0,'
                      'use_tpu={4}'.format(
                          FLAGS.sparsity_begin_step,
                          FLAGS.sparsity_end_step,
                          FLAGS.end_sparsity,
                          FLAGS.pruning_frequency,
                          FLAGS.use_tpu,
                      ))

    # Parse pruning hyperparameters.
    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

    # The first layer has so few parameters that we don't need to prune it,
    # and pruning it at higher sparsity levels has very negative effects.
    # Collect the per-layer overrides into a single list so that setting the
    # last-layer sparsity does not overwrite the first-layer entry.
    weight_sparsity_map = []
    if FLAGS.prune_first_layer and FLAGS.first_layer_sparsity >= 0.:
      weight_sparsity_map.append(
          'resnet_model/initial_conv:%f' % FLAGS.first_layer_sparsity)
    if FLAGS.prune_last_layer and FLAGS.last_layer_sparsity >= 0.:
      weight_sparsity_map.append(
          'resnet_model/final_dense:%f' % FLAGS.last_layer_sparsity)
    if weight_sparsity_map:
      pruning_hparams.set_hparam('weight_sparsity_map', weight_sparsity_map)

    # Create a pruning object using the pruning hyperparameters.
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # We override the train op to also update the mask.
    with tf.control_dependencies([train_op]):
      train_op = pruning_obj.conditional_mask_update_op()

    masks = pruning.get_masks()
    metrics.update(utils.mask_summaries(masks))
  elif pruning_method == 'scratch':
    masks = pruning.get_masks()
    # Make sure the masks have the sparsity we expect and that it doesn't
    # change.
    metrics.update(utils.mask_summaries(masks))
  elif pruning_method == 'variational_dropout':
    masks = utils.add_vd_pruning_summaries(
        threshold=FLAGS.log_alpha_threshold)
    metrics.update(masks)
  elif pruning_method == 'l0_regularization':
    summaries = utils.add_l0_summaries()
    metrics.update(summaries)
  elif pruning_method == 'baseline':
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % pruning_method)

  host_call = (functools.partial(utils.host_call_fn, output_dir),
               utils.format_tensors(metrics))

  return host_call, train_op
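# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): how the
# (host_call, train_op) pair returned by `train_function` would typically be
# wired into a TPUEstimator model_fn. `build_resnet_model` is a hypothetical
# placeholder; `TPUEstimatorSpec` and its `host_call`/`train_op` arguments are
# the real TF 1.x contrib API.


def model_fn_sketch(features, labels, mode, params):
  """Minimal model_fn sketch consuming train_function's outputs."""
  del params  # Unused in this sketch.
  logits = build_resnet_model(features)  # Hypothetical model builder.
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  host_call, train_op = train_function(
      pruning_method=FLAGS.pruning_method,
      loss=loss,
      output_dir=FLAGS.output_dir,
      use_tpu=FLAGS.use_tpu)
  # `host_call` pairs a summary-writing fn with the formatted metric tensors;
  # the TPUEstimator runs it on the host at each training step.
  return contrib_tpu.TPUEstimatorSpec(
      mode=mode, loss=loss, train_op=train_op, host_call=host_call)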
def testMultipleConvMaskAdded(self, pruning_method):
  tf.reset_default_graph()
  g = tf.Graph()
  with g.as_default():
    number_of_layers = 5

    kernel_size = [3, 3]
    base_depth = 4
    depth_step = 7

    input_tensor = tf.ones((8, self.height, self.width, base_depth))

    top_layer = input_tensor

    # Stack sparse conv layers with linearly increasing depth.
    for ix in range(number_of_layers):
      units = base_depth + (ix + 1) * depth_step
      top_layer = pruning_layers.sparse_conv2d(
          x=top_layer,
          units=units,
          kernel_size=kernel_size,
          is_training=False,
          sparsity_technique=pruning_method)

    if pruning_method == 'variational_dropout':
      theta_logsigma2 = tf.get_collection(
          vd.layers.THETA_LOGSIGMA2_COLLECTION)
      self.assertLen(theta_logsigma2, number_of_layers)

      utils.add_vd_pruning_summaries(theta_logsigma2, threshold=3.0)

      dkl_loss_1 = utils.variational_dropout_dkl_loss(
          reg_scalar=1,
          start_reg_ramp_up=0,
          end_reg_ramp_up=1000,
          warm_up=False,
          use_tpu=False)
      dkl_loss_1 = tf.reshape(dkl_loss_1, [1])

      dkl_loss_2 = utils.variational_dropout_dkl_loss(
          reg_scalar=5,
          start_reg_ramp_up=0,
          end_reg_ramp_up=1000,
          warm_up=False,
          use_tpu=False)
      dkl_loss_2 = tf.reshape(dkl_loss_2, [1])

      # Each layer's variational parameters should match its conv kernel shape.
      for ix in range(number_of_layers):
        self.assertListEqual(theta_logsigma2[ix][0].get_shape().as_list(), [
            kernel_size[0], kernel_size[1], base_depth + ix * depth_step,
            base_depth + (ix + 1) * depth_step
        ])

      init_op = tf.global_variables_initializer()

      with self.test_session() as sess:
        sess.run(init_op)
        # A larger regularization scalar should yield a larger DKL loss.
        loss_1, loss_2 = sess.run([dkl_loss_1, dkl_loss_2])
        self.assertGreater(loss_2, loss_1)
    elif pruning_method == 'l0_regularization':
      theta_logalpha = tf.get_collection(
          l0.layers.THETA_LOGALPHA_COLLECTION)
      self.assertLen(theta_logalpha, number_of_layers)

      utils.add_l0_summaries(theta_logalpha)

      l0_norm_loss_1 = utils.l0_regularization_loss(
          reg_scalar=1,
          start_reg_ramp_up=0,
          end_reg_ramp_up=1000,
          warm_up=False,
          use_tpu=False)
      l0_norm_loss_1 = tf.reshape(l0_norm_loss_1, [1])

      l0_norm_loss_2 = utils.l0_regularization_loss(
          reg_scalar=5,
          start_reg_ramp_up=0,
          end_reg_ramp_up=1000,
          warm_up=False,
          use_tpu=False)
      l0_norm_loss_2 = tf.reshape(l0_norm_loss_2, [1])

      # Each layer's log-alpha parameters should match its conv kernel shape.
      for ix in range(number_of_layers):
        self.assertListEqual(theta_logalpha[ix][0].get_shape().as_list(), [
            kernel_size[0], kernel_size[1], base_depth + ix * depth_step,
            base_depth + (ix + 1) * depth_step
        ])

      init_op = tf.global_variables_initializer()

      with self.test_session() as sess:
        sess.run(init_op)
        # A larger regularization scalar should yield a larger L0 loss.
        loss_1, loss_2 = sess.run([l0_norm_loss_1, l0_norm_loss_2])
        self.assertGreater(loss_2, loss_1)
    else:
      # Weight-threshold techniques add one mask per conv layer instead.
      mask = tf.get_collection(core.MASK_COLLECTION)
      for ix in range(number_of_layers):
        self.assertListEqual(mask[ix].get_shape().as_list(), [
            kernel_size[0], kernel_size[1], base_depth + ix * depth_step,
            base_depth + (ix + 1) * depth_step
        ])
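# ---------------------------------------------------------------------------
# Illustrative harness sketch (not part of the original file): the test above
# takes `pruning_method` as an argument, which suggests it is driven by
# absl.testing.parameterized. A minimal wrapper under that assumption is shown
# below; the class name, `height`/`width` values, and the parameter list are
# illustrative guesses rather than the repo's actual setup.
from absl.testing import parameterized


class SparseConv2dTestSketch(parameterized.TestCase, tf.test.TestCase):
  height = 16
  width = 16

  @parameterized.parameters('threshold', 'variational_dropout',
                            'l0_regularization')
  def testMultipleConvMaskAdded(self, pruning_method):
    # The body would be the module-level test method above, verbatim.
    ...


if __name__ == '__main__':
  tf.test.main()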