class FullPrecLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Learner that trains the model at full precision, i.e. without any compression."""

    def __init__(self, sm_writer, model_helper, model_scope=None, enbl_dst=None):
        """Initialize the learner and build its training & evaluation graphs.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        * model_scope: name scope for the model (self.model_scope is left untouched if None)
        * enbl_dst: whether to add a distillation loss (FLAGS.enbl_dst is used if None)
        """

        super(FullPrecLearner, self).__init__(sm_writer, model_helper)

        # explicitly given arguments take precedence over the defaults
        if model_scope is not None:
            self.model_scope = model_scope
        if enbl_dst is None:
            self.enbl_dst = FLAGS.enbl_dst
        else:
            self.enbl_dst = enbl_dst

        # the distillation helper is used while building the graphs, so create it first
        if self.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)
        for for_training in (True, False):
            self.__build(is_train=for_training)

    def train(self):
        """Run the training loop; checkpoints are written and evaluated periodically."""

        # prepare the training session
        self.sess_train.run(self.init_op)
        self.warm_start(self.sess_train)
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # main loop over training iterations
        tic = timer()
        for idx_iter in range(self.nb_iters_train):
            step_no = idx_iter + 1

            # every <summ_step> iterations, summaries & log values are fetched as well
            if step_no % FLAGS.summ_step == 0:
                _, summ_proto, log_vals = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    elapsed = timer() - tic
                    self.__monitor_progress(summ_proto, log_vals, idx_iter, elapsed)
                    tic = timer()
            else:
                self.sess_train.run(self.train_op)

            # periodic checkpoint + evaluation, on the primary worker only
            if self.is_primary_worker('global') and step_no % FLAGS.save_step == 0:
                self.__save_model(is_train=True)
                self.evaluate()

        # final checkpoint: save the training model, then export & evaluate the evaluation model
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.__restore_model(is_train=False)
            self.__save_model(is_train=False)
            self.evaluate()

    def evaluate(self):
        """Load the latest checkpoint into the evaluation session and report averaged metrics."""

        self.__restore_model(is_train=False)

        # one row of results per evaluation mini-batch
        nb_batches = int(np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        rslts = np.zeros((nb_batches, len(self.eval_op)))
        self.dump_n_eval(outputs=None, action='init')
        for idx_batch in range(nb_batches):
            rslts[idx_batch], outputs = self.sess_eval.run([self.eval_op, self.outputs_eval])
            self.dump_n_eval(outputs=outputs, action='dump')
        self.dump_n_eval(outputs=None, action='eval')

        # log the mean of each evaluation operation over all mini-batches
        for idx, name in enumerate(self.eval_op_names):
            tf.logging.info('%s = %.4e' % (name, np.mean(rslts[:, idx])))

    def __build(self, is_train):  # pylint: disable=too-many-locals
        """Construct either the training graph or the evaluation graph, with its own session.

        Args:
        * is_train: whether to create the training graph
        """

        with tf.Graph().as_default():
            # session bound to the graph being built
            sess_config = tf.ConfigProto()
            if FLAGS.enbl_multi_gpu:
                device_idx = mgw.local_rank()
            else:
                device_idx = 0
            sess_config.gpu_options.visible_device_list = str(device_idx)  # pylint: disable=no-member
            session = tf.Session(config=sess_config)

            # input pipeline
            with tf.variable_scope(self.data_scope):
                if is_train:
                    iterator = self.build_dataset_train()
                else:
                    iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()
                tf.add_to_collection('images_final', images)

            # distilled model's logits (only needed for the distillation loss)
            if self.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(session, images)

            # primary model
            with tf.variable_scope(self.model_scope):
                # forward pass
                if is_train:
                    logits = self.forward_train(images)
                else:
                    logits = self.forward_eval(images)
                tf.add_to_collection('logits_final', logits)

                # loss, extra metrics, and their summaries
                loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
                if self.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)
                tf.summary.scalar('loss', loss)
                for metric_name, metric_value in metrics.items():
                    tf.summary.scalar(metric_name, metric_value)

                # learning rate, optimizer and gradients (training graph only)
                if is_train:
                    self.global_step = tf.train.get_or_create_global_step()
                    lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)
                    optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
                    if FLAGS.enbl_multi_gpu:
                        optimizer = mgw.DistributedOptimizer(optimizer)
                    grads = optimizer.compute_gradients(loss, self.trainable_vars)

            # expose the session, operations and saver of the graph just built
            if not is_train:
                self.sess_eval = session
                self.eval_op = [loss] + list(metrics.values())
                self.eval_op_names = ['loss'] + list(metrics.keys())
                self.outputs_eval = logits
                self.saver_eval = tf.train.Saver(self.vars)
            else:
                self.sess_train = session
                with tf.control_dependencies(self.update_ops):
                    self.train_op = optimizer.apply_gradients(
                        grads, global_step=self.global_step)
                self.summary_op = tf.summary.merge_all()
                self.log_op = [lrn_rate, loss] + list(metrics.values())
                self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
                self.init_op = tf.variables_initializer(self.vars)
                if FLAGS.enbl_multi_gpu:
                    self.bcast_op = mgw.broadcast_global_variables(0)
                self.saver_train = tf.train.Saver(self.vars)

    def __save_model(self, is_train):
        """Write a checkpoint from the training or the evaluation session.

        Args:
        * is_train: whether to save a model for training
        """

        if not is_train:
            ckpt_path = self.saver_eval.save(self.sess_eval, FLAGS.save_path_eval)
        else:
            # training checkpoints carry the global step
            ckpt_path = self.saver_train.save(
                self.sess_train, FLAGS.save_path, self.global_step)
        tf.logging.info('model saved to ' + ckpt_path)

    def __restore_model(self, is_train):
        """Load the latest checkpoint found next to FLAGS.save_path.

        Args:
        * is_train: whether to restore a model for training
        """

        ckpt_path = tf.train.latest_checkpoint(os.path.dirname(FLAGS.save_path))
        if is_train:
            saver, session = self.saver_train, self.sess_train
        else:
            saver, session = self.saver_eval, self.sess_eval
        saver.restore(session, ckpt_path)
        tf.logging.info('model restored from ' + ckpt_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Write TensorBoard summaries and log the monitored values & training speed.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """

        # TensorBoard
        self.sm_writer.add_summary(summary, idx_iter)

        # throughput over the last <summ_step> iterations, scaled by the number of workers
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # console log
        entries = ['%s = %.4e' % (name, value)
                   for name, value in zip(self.log_op_names, log_rslt)]
        log_str = ' | '.join(entries)
        tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))
class UniformQuantTFLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Uniform quantization learner with TensorFlow's quantization APIs.

    A full-precision model (scope 'model') is restored from a pre-trained checkpoint and its
    variables are copied into a quantization-aware copy (scope 'quant_model'), which is then
    fine-tuned and evaluated.
    """

    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """

        # class-independent initialization
        super(UniformQuantTFLearner, self).__init__(sm_writer, model_helper)

        # define scopes for full & uniform quantized models
        self.model_scope_full = 'model'
        self.model_scope_quan = 'quant_model'

        # download the pre-trained model on the local primary worker; all workers then
        # synchronize before listing the model files
        if self.is_primary_worker('local'):
            self.download_model()  # pre-trained model is required
        self.auto_barrier()
        tf.logging.info('model files: ' + ', '.join(os.listdir('./models')))

        # detect unquantized activations nodes (only when manual quantization is enabled)
        self.unquant_node_names = []
        if FLAGS.uqtf_enbl_manual_quant:
            self.unquant_node_names = find_unquant_act_nodes(
                model_helper, self.data_scope, self.model_scope_quan, self.mpi_comm)
        tf.logging.info('unquantized activation nodes: {}'.format(
            self.unquant_node_names))

        # class-dependent initialization
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)
        self.__build_train()
        self.__build_eval()

    def train(self):
        """Train a model and periodically produce checkpoint files."""

        # restore the full model from pre-trained checkpoints
        save_path = tf.train.latest_checkpoint(
            os.path.dirname(self.save_path_full))
        self.saver_full.restore(self.sess_train, save_path)

        # initialization (copies the full model's weights into the quantized model)
        self.sess_train.run([self.init_op, self.init_opt_op])
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # train the model through iterations and periodically save & evaluate the model
        time_prev = timer()
        for idx_iter in range(self.nb_iters_train):
            # train the model
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.train_op)
            else:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    time_step = timer() - time_prev
                    self.__monitor_progress(summary, log_rslt, idx_iter, time_step)
                    time_prev = timer()

            # save & evaluate the model at certain steps; only the primary worker saves and
            # evaluates, but every worker must reach the barrier (same pattern as in the
            # constructor), otherwise the primary worker would wait on it alone
            if (idx_iter + 1) % FLAGS.save_step == 0:
                if self.is_primary_worker('global'):
                    self.__save_model(is_train=True)
                    self.evaluate()
                self.auto_barrier()

        # save the final model
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.__restore_model(is_train=False)
            self.__save_model(is_train=False)
            self.evaluate()

    def evaluate(self):
        """Restore a model from the latest checkpoint files and then evaluate it."""

        self.__restore_model(is_train=False)
        nb_iters = int(
            np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
        self.dump_n_eval(outputs=None, action='init')
        for idx_iter in range(nb_iters):
            if (idx_iter + 1) % 100 == 0:
                tf.logging.info('process the %d-th mini-batch for evaluation' % (idx_iter + 1))
            eval_rslts[idx_iter], outputs = self.sess_eval.run(
                [self.eval_op, self.outputs_eval])
            self.dump_n_eval(outputs=outputs, action='dump')
        self.dump_n_eval(outputs=None, action='eval')

        # report each evaluation operation averaged over all mini-batches
        for idx, name in enumerate(self.eval_op_names):
            tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph."""

        with tf.Graph().as_default() as graph:
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - uniform quantized model - part 1
            with tf.variable_scope(self.model_scope_quan):
                logits_quan = self.forward_train(images)
                # NOTE(review): <outputs> is not read afterwards; the softmax node is presumably
                # added so that it exists in the graph before the rewrite below - confirm before
                # removing it
                if not isinstance(logits_quan, dict):
                    outputs = tf.nn.softmax(logits_quan)
                else:
                    outputs = tf.nn.softmax(logits_quan['cls_pred'])
                tf.contrib.quantize.experimental_create_training_graph(
                    weight_bits=FLAGS.uqtf_weight_bits,
                    activation_bits=FLAGS.uqtf_activation_bits,
                    quant_delay=FLAGS.uqtf_quant_delay,
                    freeze_bn_delay=FLAGS.uqtf_freeze_bn_delay,
                    scope=self.model_scope_quan)
                for node_name in self.unquant_node_names:
                    insert_quant_op(graph, node_name, is_train=True)
                self.vars_quan = get_vars_by_scope(self.model_scope_quan)
                self.global_step = tf.train.get_or_create_global_step()
                self.saver_quan_train = tf.train.Saver(
                    self.vars_quan['all'] + [self.global_step])

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model (only its variables are needed, as the source of the
            # quantized model's initial weights)
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - uniform quantized model - part 2
            with tf.variable_scope(self.model_scope_quan):
                # loss & extra evaluation metrics
                loss_bsc, metrics = self.calc_loss(labels, logits_quan,
                                                   self.vars_quan['trainable'])
                if not FLAGS.enbl_dst:
                    loss_fnl = loss_bsc
                else:
                    loss_fnl = loss_bsc + self.helper_dst.calc_loss(
                        logits_quan, logits_dst)
                tf.summary.scalar('loss_bsc', loss_bsc)
                tf.summary.scalar('loss_fnl', loss_fnl)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule, decreased by a constant factor for fine-tuning
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(
                    self.global_step)
                lrn_rate *= FLAGS.uqtf_lrn_rate_dcy

                # obtain the full list of trainable variables & update operations
                self.vars_all = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_quan)
                self.trainable_vars_all = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.model_scope_quan)
                self.update_ops_all = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_quan)

                # TF operations for initializing the uniform quantized model: first initialize
                # all of its variables, then overwrite them with the full model's values
                init_ops = []
                with tf.control_dependencies(
                        [tf.variables_initializer(self.vars_all)]):
                    for var_full, var_quan in zip(self.vars_full['all'],
                                                  self.vars_quan['all']):
                        init_ops += [var_quan.assign(var_full)]
                init_ops += [self.global_step.initializer]
                self.init_op = tf.group(init_ops)

                # TF operations for fine-tuning (Adam is used here)
                optimizer_base = tf.train.AdamOptimizer(lrn_rate)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads = optimizer.compute_gradients(loss_fnl, self.trainable_vars_all)
                with tf.control_dependencies(self.update_ops_all):
                    self.train_op = optimizer.apply_gradients(
                        grads, global_step=self.global_step)
                self.init_opt_op = tf.variables_initializer(
                    optimizer_base.variables())

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss_fnl] + list(metrics.values())
            self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)

    def __build_eval(self):
        """Build the evaluation graph."""

        with tf.Graph().as_default() as graph:
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - uniform quantized model - part 1
            with tf.variable_scope(self.model_scope_quan):
                logits = self.forward_eval(images)
                # NOTE(review): <outputs> is not read afterwards; presumably kept so that the
                # softmax node exists in the graph before the rewrite below - confirm
                if not isinstance(logits, dict):
                    outputs = tf.nn.softmax(logits)
                else:
                    outputs = tf.nn.softmax(logits['cls_pred'])
                tf.contrib.quantize.experimental_create_eval_graph(
                    weight_bits=FLAGS.uqtf_weight_bits,
                    activation_bits=FLAGS.uqtf_activation_bits,
                    scope=self.model_scope_quan)
                for node_name in self.unquant_node_names:
                    insert_quant_op(graph, node_name, is_train=False)
                vars_quan = get_vars_by_scope(self.model_scope_quan)
                global_step = tf.train.get_or_create_global_step()
                self.saver_quan_eval = tf.train.Saver(vars_quan['all'] + [global_step])

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition - uniform quantized model - part 2
            with tf.variable_scope(self.model_scope_quan):
                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits, vars_quan['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # TF operations for evaluation
                self.eval_op = [loss] + list(metrics.values())
                self.eval_op_names = ['loss'] + list(metrics.keys())
                self.outputs_eval = logits

            # add input & output tensors to certain collections; for dict-valued inputs /
            # outputs only the 'image' / 'cls_pred' entry is registered
            if not isinstance(images, dict):
                tf.add_to_collection('images_final', images)
            else:
                tf.add_to_collection('images_final', images['image'])
            if not isinstance(logits, dict):
                tf.add_to_collection('logits_final', logits)
            else:
                tf.add_to_collection('logits_final', logits['cls_pred'])

    def __save_model(self, is_train):
        """Save the current model for training or evaluation.

        Args:
        * is_train: whether to save a model for training
        """

        if is_train:
            save_path = self.saver_quan_train.save(self.sess_train, FLAGS.uqtf_save_path,
                                                   self.global_step)
        else:
            save_path = self.saver_quan_eval.save(self.sess_eval, FLAGS.uqtf_save_path_eval)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Args:
        * is_train: whether to restore a model for training
        """

        # both sessions restore from the directory of the training checkpoints
        save_path = tf.train.latest_checkpoint(
            os.path.dirname(FLAGS.uqtf_save_path))
        if is_train:
            self.saver_quan_train.restore(self.sess_train, save_path)
        else:
            self.saver_quan_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Monitor the training progress.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """

        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # compute the training speed (scaled by the number of workers in multi-GPU mode)
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # display monitored statistics
        log_str = ' | '.join([
            '%s = %.4e' % (name, value)
            for name, value in zip(self.log_op_names, log_rslt)
        ])
        tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))
class WeightSparseLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Weight sparsification learner.

    Each maskable variable gets a binary mask and an un-masked backup copy; during training the
    masks are refreshed periodically with a gradually increasing pruning ratio, and gradients
    of maskable variables are multiplied by their masks.
    """

    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """

        # class-independent initialization
        super(WeightSparseLearner, self).__init__(sm_writer, model_helper)

        # define the scope for masks
        self.mask_scope = 'mask'

        # compute the optimal pruning ratios; the pre-trained model is only downloaded for
        # the 'optimal' protocol
        pr_optimizer = PROptimizer(model_helper, self.mpi_comm)
        if FLAGS.ws_prune_ratio_prtl == 'optimal':
            if self.is_primary_worker('local'):
                self.download_model()  # pre-trained model is required
            self.auto_barrier()
            tf.logging.info('model files: ' + ', '.join(os.listdir('./models')))
        self.var_names_n_prune_ratios = pr_optimizer.run()

        # class-dependent initialization
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)
        self.__build_train()
        self.__build_eval()

    def train(self):
        """Train a model and periodically produce checkpoint files."""

        # initialization
        self.sess_train.run(self.init_op)
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # train the model through iterations and periodically save & evaluate the model
        last_mask_applied = False
        time_prev = timer()
        for idx_iter in range(self.nb_iters_train):
            # train the model
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.train_op)
            else:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    time_step = timer() - time_prev
                    self.__monitor_progress(summary, log_rslt, idx_iter, time_step)
                    time_prev = timer()

            # apply pruning: masks are refreshed every <ws_mask_update_step> iterations inside
            # the [beg, end] window, plus exactly once more at the first update step after it
            if (idx_iter + 1) % FLAGS.ws_mask_update_step == 0:
                iter_ratio = float(idx_iter + 1) / self.nb_iters_train
                if iter_ratio >= FLAGS.ws_iter_ratio_beg:
                    if iter_ratio <= FLAGS.ws_iter_ratio_end:
                        self.sess_train.run([self.prune_op, self.init_opt_op])
                    elif not last_mask_applied:
                        last_mask_applied = True
                        self.sess_train.run([self.prune_op, self.init_opt_op])

            # save the model at certain steps
            if self.is_primary_worker('global') and (idx_iter + 1) % FLAGS.save_step == 0:
                self.__save_model()
                self.evaluate()

        # save the final model
        if self.is_primary_worker('global'):
            self.__save_model()
            self.evaluate()

    def evaluate(self):
        """Restore a model from the latest checkpoint files and then evaluate it."""

        self.__restore_model(is_train=False)
        # the evaluation graph reads from build_dataset_eval(), so the number of mini-batches
        # is derived from the evaluation batch size, as in the other learners of this file
        nb_iters = int(np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
        for idx_iter in range(nb_iters):
            eval_rslts[idx_iter] = self.sess_eval.run(self.eval_op)

        # report each evaluation operation averaged over all mini-batches
        for idx, name in enumerate(self.eval_op_names):
            tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

    def __build_train(self):  # pylint: disable=too-many-locals
        """Build the training graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            if FLAGS.enbl_multi_gpu:
                config.gpu_options.visible_device_list = str(mgw.local_rank())  # pylint: disable=no-member
            else:
                config.gpu_options.visible_device_list = '0'  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - weight-sparsified model
            with tf.variable_scope(self.model_scope):
                # loss & extra evaluation metrics
                logits = self.forward_train(images)
                self.maskable_var_names = [
                    var.name for var in self.maskable_vars
                ]
                loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                self.global_step = tf.train.get_or_create_global_step()
                lrn_rate, self.nb_iters_train = setup_lrn_rate(
                    self.global_step, self.model_name, self.dataset_name)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.trainable_vars)
                pr_maskable = calc_prune_ratio(self.maskable_vars)
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_maskable', pr_maskable)

                # build masks and corresponding operations for weight sparsification
                # (requires <self.global_step> and <self.nb_iters_train> set above)
                self.masks, self.prune_op = self.__build_masks()

                # optimizer & gradients
                optimizer_base = tf.train.MomentumOptimizer(
                    lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads_origin = optimizer.compute_gradients(
                    loss, self.trainable_vars)
                grads_pruned = self.__calc_grads_pruned(grads_origin)

            # TF operations & model saver
            self.sess_train = sess
            with tf.control_dependencies(self.update_ops):
                self.train_op = optimizer.apply_gradients(
                    grads_pruned, global_step=self.global_step)
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_maskable] + list(
                metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(
                metrics.keys())
            self.init_op = tf.variables_initializer(self.vars)
            # re-initializes the optimizer's own variables; run together with <prune_op>
            self.init_opt_op = tf.variables_initializer(
                optimizer_base.variables())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)
            self.saver_train = tf.train.Saver(self.vars)

    def __build_eval(self):
        """Build the evaluation graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            if FLAGS.enbl_multi_gpu:
                config.gpu_options.visible_device_list = str(mgw.local_rank())  # pylint: disable=no-member
            else:
                config.gpu_options.visible_device_list = '0'  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition - weight-sparsified model
            with tf.variable_scope(self.model_scope):
                # loss & extra evaluation metrics
                logits = self.forward_eval(images)
                loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.trainable_vars)
                pr_maskable = calc_prune_ratio(self.maskable_vars)

                # TF operations for evaluation
                self.eval_op = [loss, pr_trainable, pr_maskable] + list(
                    metrics.values())
                self.eval_op_names = ['loss', 'pr_trn', 'pr_msk'] + list(
                    metrics.keys())
                self.saver_eval = tf.train.Saver(self.vars)

    def __build_masks(self):
        """Build masks and corresponding operations for weight sparsification.

        Returns:
        * masks: list of masks for weight sparsification
        * prune_op: pruning operation
        """

        masks, prune_ops = [], []
        with tf.variable_scope(self.mask_scope):
            for var, var_name_n_prune_ratio in zip(
                    self.maskable_vars, self.var_names_n_prune_ratios):
                # obtain the dynamic pruning ratio; each entry is indexed as
                # [0] = variable name, [1] = final pruning ratio
                assert var.name == var_name_n_prune_ratio[0], \
                    'unmatched variable names: %s vs. %s' % (var.name, var_name_n_prune_ratio[0])
                prune_ratio = self.__calc_prune_ratio_dyn(
                    var_name_n_prune_ratio[1])

                # create a mask and non-masked backup for each variable
                name = var.name.replace(':0', '_mask')
                mask = tf.get_variable(name, initializer=tf.ones(var.shape), trainable=False)
                name = var.name.replace(':0', '_var_bkup')
                var_bkup = tf.get_variable(name, initializer=var.initialized_value(),
                                           trainable=False)

                # create update operations, chained in this order:
                # 1) refresh the backup with the values of currently un-masked weights
                # 2) recompute the mask from the magnitude threshold
                # 3) overwrite the variable with the masked backup
                mask_thres = tf.contrib.distributions.percentile(
                    tf.abs(var), prune_ratio * 100)
                var_bkup_update_op = var_bkup.assign(
                    tf.where(mask > 0.5, var, var_bkup))
                with tf.control_dependencies([var_bkup_update_op]):
                    mask_update_op = mask.assign(
                        tf.cast(tf.abs(var) > mask_thres, tf.float32))
                with tf.control_dependencies([mask_update_op]):
                    prune_op = var.assign(var_bkup * mask)

                # record pruning masks & operations
                masks += [mask]
                prune_ops += [prune_op]

        return masks, tf.group(prune_ops)

    def __calc_prune_ratio_dyn(self, prune_ratio_fnl):
        """Calculate the dynamic pruning ratio.

        The ratio is 0 before the pruning window, <prune_ratio_fnl> after it, and follows
        prune_ratio_fnl * (1 - (1 - progress) ^ FLAGS.ws_prune_ratio_exp) in between.

        Args:
        * prune_ratio_fnl: final pruning ratio

        Returns:
        * prune_ratio_dyn: dynamic pruning ratio
        """

        idx_iter_beg = int(self.nb_iters_train * FLAGS.ws_iter_ratio_beg)
        idx_iter_end = int(self.nb_iters_train * FLAGS.ws_iter_ratio_end)
        # progress through the pruning window, clipped to [0, 1]
        base = tf.cast(self.global_step - idx_iter_beg, tf.float32) / (idx_iter_end - idx_iter_beg)
        base = tf.minimum(1.0, tf.maximum(0.0, base))
        prune_ratio_dyn = prune_ratio_fnl * (
            1.0 - tf.pow(1.0 - base, FLAGS.ws_prune_ratio_exp))

        return prune_ratio_dyn

    def __calc_grads_pruned(self, grads_origin):
        """Calculate the mask-pruned gradients.

        Args:
        * grads_origin: list of original gradients, as (gradient, variable) pairs

        Returns:
        * grads_pruned: list of mask-pruned gradients
        """

        grads_pruned = []
        for grad in grads_origin:
            if grad[1].name not in self.maskable_var_names:
                # non-maskable variables keep their gradients unchanged
                grads_pruned += [grad]
            else:
                idx_mask = self.maskable_var_names.index(grad[1].name)
                grads_pruned += [(grad[0] * self.masks[idx_mask], grad[1])]

        return grads_pruned

    def __save_model(self):
        """Save the current model."""

        save_path = self.saver_train.save(self.sess_train, FLAGS.ws_save_path, self.global_step)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Args:
        * is_train: whether to restore a model for training
        """

        save_path = tf.train.latest_checkpoint(
            os.path.dirname(FLAGS.ws_save_path))
        if is_train:
            self.saver_train.restore(self.sess_train, save_path)
        else:
            self.saver_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Monitor the training progress.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """

        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # compute the training speed (scaled by the number of workers in multi-GPU mode)
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # display monitored statistics
        log_str = ' | '.join([
            '%s = %.4e' % (name, value)
            for name, value in zip(self.log_op_names, log_rslt)
        ])
        tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))

    @property
    def maskable_vars(self):
        """List of all maskable variables."""

        return get_maskable_vars(self.trainable_vars)
class DisChnPrunedLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Discrimination-aware channel pruning learner.

    Two graphs are built: a training graph that holds both the full
    (pre-trained) model and its channel-pruned copy, and an evaluation graph
    that holds the channel-pruned model only. Channels are chosen block by
    block / layer by layer before the whole network is fine-tuned.
    """

    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """
        # class-independent initialization
        super(DisChnPrunedLearner, self).__init__(sm_writer, model_helper)

        # define scopes for full & channel-pruned models
        self.model_scope_full = 'model'
        self.model_scope_prnd = 'pruned_model'

        # download the pre-trained model (once per machine), then wait for it
        if self.is_primary_worker('local'):
            self.download_model()  # pre-trained model is required
        self.auto_barrier()
        tf.logging.info('model files: ' + ', '.join(os.listdir('./models')))

        # class-dependent initialization
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)
        self.__build_train()
        self.__build_eval()

    def train(self):
        """Train a model and periodically produce checkpoint files."""
        # restore the full model from pre-trained checkpoints
        save_path = tf.train.latest_checkpoint(os.path.dirname(self.save_path_full))
        self.saver_full.restore(self.sess_train, save_path)

        # initialization
        self.sess_train.run([self.init_op, self.init_opt_op])
        self.sess_train.run(self.layer_init_opt_ops)  # initialization for layer-wise fine-tuning
        self.sess_train.run(self.block_init_opt_ops)  # initialization for block-wise fine-tuning
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # choose discrimination-aware channels
        self.__choose_discr_chns()

        # fine-tune the model with chosen channels only
        time_prev = timer()
        for idx_iter in range(self.nb_iters_train):
            # train the model; every <summ_step> iterations also fetch summaries & logs
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.train_op)
            else:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    time_step = timer() - time_prev
                    self.__monitor_progress(summary, log_rslt, idx_iter, time_step)
                    time_prev = timer()

            # save the model at certain steps
            if self.is_primary_worker('global') and (idx_iter + 1) % FLAGS.save_step == 0:
                self.__save_model(is_train=True)
                self.evaluate()

        # save the final model: training checkpoint first, then reload it into the
        # evaluation graph and export the evaluation checkpoint
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.__restore_model(is_train=False)
            self.__save_model(is_train=False)
            self.evaluate()

    def evaluate(self):
        """Restore a model from the latest checkpoint files and then evaluate it."""
        self.__restore_model(is_train=False)
        nb_iters = int(np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        # one row per mini-batch, one column per evaluation metric
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
        for idx_iter in range(nb_iters):
            eval_rslts[idx_iter] = self.sess_eval.run(self.eval_op)
        for idx, name in enumerate(self.eval_op_names):
            tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph."""
        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - channel-pruned model
            # NOTE: everything below is nested in the pruned-model scope; the mask
            #   names drop the leading scope component of the variable name, and the
            #   initializer built below only covers variables collected under this
            #   scope, so the extra variables must be created inside it.
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images)
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.maskable_var_names = [var.name for var in self.vars_prnd['maskable']]
                self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])

                # loss & extra evaluation metrics
                loss_bsc, metrics = self.calc_loss(
                    labels, logits_prnd, self.vars_prnd['trainable'])
                if not FLAGS.enbl_dst:
                    loss_fnl = loss_bsc
                else:
                    loss_fnl = loss_bsc + self.helper_dst.calc_loss(logits_prnd, logits_dst)
                tf.summary.scalar('loss_bsc', loss_bsc)
                tf.summary.scalar('loss_fnl', loss_fnl)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                self.global_step = tf.train.get_or_create_global_step()
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
                pr_maskable = calc_prune_ratio(self.vars_prnd['maskable'])
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_maskable', pr_maskable)

                # create masks and corresponding operations for channel pruning
                self.masks = []
                self.mask_deltas = []
                self.mask_init_ops = []
                self.mask_updt_ops = []
                self.prune_ops = []
                for idx, var in enumerate(self.vars_prnd['maskable']):
                    name_base = '/'.join(var.name.split('/')[1:])
                    name = name_base.replace(':0', '_mask')
                    self.masks += [
                        tf.get_variable(name, initializer=tf.ones(var.shape), trainable=False)]
                    name = name_base.replace(':0', '_mask_delta')
                    self.mask_deltas += [tf.placeholder(tf.float32, shape=var.shape, name=name)]
                    # reset the mask to all-zeros / add a fed-in delta / apply the mask
                    self.mask_init_ops += [self.masks[idx].assign(tf.zeros(var.shape))]
                    self.mask_updt_ops += [self.masks[idx].assign_add(self.mask_deltas[idx])]
                    self.prune_ops += [var.assign(var * self.masks[idx])]

                # build extra losses for regression & discrimination
                self.reg_losses, self.dis_losses, self.idxs_layer_to_block = \
                    self.__build_extra_losses(labels)
                # append discrimination-aware loss for the last block
                self.dis_losses += [loss_bsc]
                self.nb_layers = len(self.reg_losses)
                self.nb_blocks = len(self.dis_losses)
                for idx, reg_loss in enumerate(self.reg_losses):
                    tf.summary.scalar('reg_loss_%d' % idx, reg_loss)
                for idx, dis_loss in enumerate(self.dis_losses):
                    tf.summary.scalar('dis_loss_%d' % idx, dis_loss)

                # obtain the full list of trainable variables & update operations
                self.vars_all = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_prnd)
                self.trainable_vars_all = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.model_scope_prnd)
                self.update_ops_all = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)

                # TF operations for initializing the channel-pruned model: run the
                # regular initializers first, then copy the full model's weights over
                init_ops = []
                with tf.control_dependencies([tf.variables_initializer(self.vars_all)]):
                    for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all']):
                        init_ops += [var_prnd.assign(var_full)]
                self.init_op = tf.group(init_ops)

                # TF operations for layer-wise, block-wise, and whole-network fine-tuning
                self.layer_train_ops, self.layer_init_opt_ops, self.grad_norms = \
                    self.__build_layer_ops()
                self.block_train_ops, self.block_init_opt_ops = self.__build_block_ops()
                self.train_op, self.init_opt_op = self.__build_network_ops(loss_fnl, lrn_rate)

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss_fnl, pr_trainable, pr_maskable] + list(metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)

    def __build_eval(self):
        """Build the evaluation graph."""
        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(self.sess_eval, images)

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                # loss & extra evaluation metrics
                logits = self.forward_eval(images)
                vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                loss, metrics = self.calc_loss(labels, logits, vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
                pr_maskable = calc_prune_ratio(vars_prnd['maskable'])

                # TF operations for evaluation
                self.eval_op = [loss, pr_trainable, pr_maskable] + list(metrics.values())
                self.eval_op_names = ['loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
                self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'])

            # add input & output tensors to certain collections
            tf.add_to_collection('images_final', images)
            tf.add_to_collection('logits_final', logits)

    def __build_extra_losses(self, labels):
        """Build extra losses for regression & discrimination.

        Args:
        * labels: one-hot label vectors

        Returns:
        * reg_losses: list of regression losses (one per layer)
        * dis_losses: list of discrimination-aware losses (one per intermediate block)
        * idxs_layer_to_block: list of mappings from layer index to block index
        """
        # insert additional losses to intermediate layers
        pattern = re.compile('Conv2D$')
        core_ops_full = get_ops_by_scope_n_pattern(self.model_scope_full, pattern)
        core_ops_prnd = get_ops_by_scope_n_pattern(self.model_scope_prnd, pattern)
        nb_layers = len(core_ops_full)
        nb_blocks = int(FLAGS.dcp_nb_stages + 1)
        nb_layers_per_block = int(math.ceil((nb_layers + 1) / nb_blocks))

        reg_losses = []
        dis_losses = []
        idxs_layer_to_block = []
        for idx_layer in range(nb_layers):
            # regression loss: L2 distance between full & pruned layer outputs
            output_full = core_ops_full[idx_layer].outputs[0]
            output_prnd = core_ops_prnd[idx_layer].outputs[0]
            reg_losses += [tf.nn.l2_loss(output_full - output_prnd)]
            idxs_layer_to_block += [int(idx_layer / nb_layers_per_block)]

            # the last layer of each block gets an auxiliary classifier:
            # BN -> ReLU -> global average pooling -> dense -> cross-entropy
            if (idx_layer + 1) % nb_layers_per_block == 0:
                x = output_prnd
                x = tf.layers.batch_normalization(x, axis=3, training=True)
                x = tf.nn.relu(x)
                x = tf.reduce_mean(x, axis=[1, 2])
                x = tf.layers.dense(x, FLAGS.nb_classes)
                dis_losses += [tf.losses.softmax_cross_entropy(labels, x)]

        tf.logging.info('layer-to-block mapping: {}'.format(idxs_layer_to_block))

        return reg_losses, dis_losses, idxs_layer_to_block

    def __build_layer_ops(self):
        """Build layer-wise fine-tuning operations.

        Returns:
        * layer_train_ops: list of training operations for each layer
        * layer_init_opt_ops: list of initialization operations for each layer's optimizer
        * layer_grad_norms: list of gradient norm vectors for each layer
        """
        layer_train_ops = []
        layer_init_opt_ops = []
        grad_norms = []
        for idx, var_prnd in enumerate(self.vars_prnd['maskable']):
            optimizer_base = tf.train.AdamOptimizer(FLAGS.dcp_lrn_rate_adam)
            if not FLAGS.enbl_multi_gpu:
                optimizer = optimizer_base
            else:
                optimizer = mgw.DistributedOptimizer(optimizer_base)

            # layer's own regression loss + discrimination-aware loss of its block
            loss_all = self.reg_losses[idx] + self.dis_losses[self.idxs_layer_to_block[idx]]
            grads_origin = optimizer.compute_gradients(loss_all, [var_prnd])
            grads_pruned = self.__calc_grads_pruned(grads_origin)
            with tf.control_dependencies(self.update_ops_all):
                layer_train_ops += [optimizer.apply_gradients(grads_pruned)]
            layer_init_opt_ops += [tf.variables_initializer(optimizer_base.variables())]
            # squared gradient norm reduced over every axis but axis 2 (un-masked gradients)
            grad_norms += [tf.reduce_sum(grads_origin[0][0] ** 2, axis=[0, 1, 3])]

        return layer_train_ops, layer_init_opt_ops, grad_norms

    def __build_block_ops(self):
        """Build block-wise fine-tuning operations.

        Returns:
        * block_train_ops: list of training operations for each block
        * block_init_opt_ops: list of initialization operations for each block's optimizer
        """
        block_train_ops = []
        block_init_opt_ops = []
        for dis_loss in self.dis_losses:
            optimizer_base = tf.train.AdamOptimizer(FLAGS.dcp_lrn_rate_adam)
            if not FLAGS.enbl_multi_gpu:
                optimizer = optimizer_base
            else:
                optimizer = mgw.DistributedOptimizer(optimizer_base)

            loss_all = dis_loss + self.dis_losses[-1]  # current stage + final loss
            grads_origin = optimizer.compute_gradients(loss_all, self.trainable_vars_all)
            grads_pruned = self.__calc_grads_pruned(grads_origin)
            with tf.control_dependencies(self.update_ops_all):
                block_train_ops += [optimizer.apply_gradients(grads_pruned)]
            block_init_opt_ops += [tf.variables_initializer(optimizer_base.variables())]

        return block_train_ops, block_init_opt_ops

    def __build_network_ops(self, loss, lrn_rate):
        """Build network training operations.

        Args:
        * loss: final loss of the channel-pruned network
        * lrn_rate: learning rate of the whole-network optimizer

        Returns:
        * train_op: training operation of the whole network
        * init_opt_op: initialization operation of the whole network's optimizer
        """
        optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        if not FLAGS.enbl_multi_gpu:
            optimizer = optimizer_base
        else:
            optimizer = mgw.DistributedOptimizer(optimizer_base)

        # all stages + final loss; the intermediate losses are multiplied by zero,
        # so they stay connected to the loss without changing its value
        # NOTE: tf.add_n() rejects an empty list, which is what <dis_losses[:-1]> is
        #   when there is no intermediate stage, hence the guard.
        dis_losses_extra = self.dis_losses[:-1]
        if dis_losses_extra:
            loss_all = tf.add_n(dis_losses_extra) * 0 + loss
        else:
            loss_all = loss
        grads_origin = optimizer.compute_gradients(loss_all, self.trainable_vars_all)
        grads_pruned = self.__calc_grads_pruned(grads_origin)
        with tf.control_dependencies(self.update_ops_all):
            train_op = optimizer.apply_gradients(grads_pruned, global_step=self.global_step)
        init_opt_op = tf.variables_initializer(optimizer_base.variables())

        return train_op, init_opt_op

    def __calc_grads_pruned(self, grads_origin):
        """Calculate the mask-pruned gradients.

        Args:
        * grads_origin: list of original gradients, as (gradient, variable) pairs

        Returns:
        * grads_pruned: list of mask-pruned gradients, as (gradient, variable) pairs
        """
        # one name -> mask table instead of a linear search per gradient
        mask_by_name = dict(zip(self.maskable_var_names, self.masks))

        grads_pruned = []
        for grad, var in grads_origin:
            mask = mask_by_name.get(var.name)
            if mask is None or grad is None:
                # non-maskable variable, or no gradient to mask: pass through untouched
                grads_pruned.append((grad, var))
            else:
                grads_pruned.append((grad * mask, var))

        return grads_pruned

    def __choose_discr_chns(self):  # pylint: disable=too-many-locals
        """Choose discrimination-aware channels."""
        # select the most discriminative channels through multiple stages
        nb_workers = mgw.size() if FLAGS.enbl_multi_gpu else 1
        nb_iters_block = int(FLAGS.dcp_nb_iters_block / nb_workers)
        nb_iters_layer = int(FLAGS.dcp_nb_iters_layer / nb_workers)
        for idx_block in range(self.nb_blocks):
            # fine-tune the current block
            block_train_op = self.block_train_ops[idx_block]
            for idx_iter in range(nb_iters_block):
                if (idx_iter + 1) % FLAGS.summ_step != 0:
                    self.sess_train.run(block_train_op)
                else:
                    summary, __ = self.sess_train.run([self.summary_op, block_train_op])
                    if self.is_primary_worker('global'):
                        self.sm_writer.add_summary(
                            summary, nb_iters_block * idx_block + idx_iter)

            # select the most discriminative channels for each layer
            for idx_layer in range(1, self.nb_layers):  # do not prune the first layer
                if self.idxs_layer_to_block[idx_layer] != idx_block:
                    continue

                # initialize the mask as all channels are pruned
                mask_shape = self.sess_train.run(tf.shape(self.masks[idx_layer]))
                tf.logging.info('layer #{}: mask\'s shape is {}'.format(idx_layer, mask_shape))
                nb_chns = mask_shape[2]
                grad_norm_mask = np.ones(nb_chns)  # 1.0: channel has not been chosen yet
                mask_vec = np.sum(self.sess_train.run(self.masks[idx_layer]), axis=(0, 1, 3))
                prune_ratio = 1.0 - float(np.count_nonzero(mask_vec)) / mask_vec.size
                tf.logging.info('layer #%d: prune_ratio = %.4f' % (idx_layer, prune_ratio))

                # add channels one at a time until the pruning ratio drops to the target;
                # the first pass always runs, since the mask starts as all-ones
                is_first_entry = True
                while is_first_entry or prune_ratio > FLAGS.dcp_prune_ratio:
                    # choose the most important channel and then update the mask
                    grad_norm = self.sess_train.run(self.grad_norms[idx_layer])
                    idx_chn_input = np.argmax(grad_norm * grad_norm_mask)
                    grad_norm_mask[idx_chn_input] = 0.0
                    tf.logging.info(
                        'adding channel #%d to the non-pruned set' % idx_chn_input)
                    mask_delta = np.zeros(mask_shape)
                    mask_delta[:, :, idx_chn_input, :] = 1.0
                    if is_first_entry:
                        is_first_entry = False
                        self.sess_train.run(self.mask_init_ops[idx_layer])
                    self.sess_train.run(
                        self.mask_updt_ops[idx_layer],
                        feed_dict={self.mask_deltas[idx_layer]: mask_delta})
                    self.sess_train.run(self.prune_ops[idx_layer])

                    # fine-tune the current layer
                    for idx_iter in range(nb_iters_layer):
                        self.sess_train.run(self.layer_train_ops[idx_layer])

                    # re-compute the pruning ratio
                    mask_vec = np.sum(
                        self.sess_train.run(self.masks[idx_layer]), axis=(0, 1, 3))
                    prune_ratio = 1.0 - float(np.count_nonzero(mask_vec)) / mask_vec.size
                    tf.logging.info(
                        'layer #%d: prune_ratio = %.4f' % (idx_layer, prune_ratio))

            # compute overall pruning ratios
            if self.is_primary_worker('global'):
                log_rslt = self.sess_train.run(self.log_op)
                log_str = ' | '.join(['%s = %.4e' % (name, value)
                                      for name, value in zip(self.log_op_names, log_rslt)])
                tf.logging.info('block #%d: %s' % (idx_block + 1, log_str))

    def __save_model(self, is_train):
        """Save the current model for training or evaluation.

        Args:
        * is_train: whether to save a model for training
        """
        if is_train:
            save_path = self.saver_prnd_train.save(
                self.sess_train, FLAGS.dcp_save_path, self.global_step)
        else:
            save_path = self.saver_prnd_eval.save(self.sess_eval, FLAGS.dcp_save_path_eval)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Args:
        * is_train: whether to restore a model for training
        """
        save_path = tf.train.latest_checkpoint(os.path.dirname(FLAGS.dcp_save_path))
        if is_train:
            self.saver_prnd_train.restore(self.sess_train, save_path)
        else:
            self.saver_prnd_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Monitor the training progress.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """
        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # compute the training speed
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # display monitored statistics
        log_str = ' | '.join(['%s = %.4e' % (name, value)
                              for name, value in zip(self.log_op_names, log_rslt)])
        tf.logging.info(
            'iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))
class ChannelPrunedGpuLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Channel pruning learner with GPU-based optimization.

    Two graphs are built: a training graph that holds both the full
    (pre-trained) model and its channel-pruned copy, and an evaluation graph
    that holds the channel-pruned model only. Channels are chosen layer by
    layer before the whole network is fine-tuned.
    """

    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """
        # class-independent initialization
        super(ChannelPrunedGpuLearner, self).__init__(sm_writer, model_helper)

        # define scopes for full & channel-pruned models
        self.model_scope_full = 'model'
        self.model_scope_prnd = 'pruned_model'

        # download the pre-trained model (once per machine), then wait for it
        if self.is_primary_worker('local'):
            self.download_model()  # pre-trained model is required
        self.auto_barrier()
        tf.logging.info('model files: ' + ', '.join(os.listdir('./models')))

        # class-dependent initialization
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)
        self.__build_train()
        self.__build_eval()

    def train(self):
        """Train a model and periodically produce checkpoint files."""
        # restore the full model from pre-trained checkpoints
        save_path = tf.train.latest_checkpoint(os.path.dirname(self.save_path_full))
        self.saver_full.restore(self.sess_train, save_path)

        # initialization
        self.sess_train.run([self.init_op, self.init_opt_op])
        self.sess_train.run([layer_op['init_opt'] for layer_op in self.layer_ops])
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # choose channels and evaluate the model before re-training
        # NOTE(review): auto_barrier() is assumed to be a collective synchronization
        #   point, so it is called by every worker, not only by the primary one;
        #   confirm against AbstractLearner.
        self.__choose_channels()
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.evaluate()
        self.auto_barrier()

        # fine-tune the model with chosen channels only
        time_prev = timer()
        for idx_iter in range(self.nb_iters_train):
            # train the model; every <summ_step> iterations also fetch summaries & logs
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.train_op)
            else:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    time_step = timer() - time_prev
                    self.__monitor_progress(summary, log_rslt, idx_iter, time_step)
                    time_prev = timer()

            # save the model at certain steps; only the primary worker saves &
            # evaluates, while all workers meet at the barrier afterwards
            if (idx_iter + 1) % FLAGS.save_step == 0:
                if self.is_primary_worker('global'):
                    self.__save_model(is_train=True)
                    self.evaluate()
                self.auto_barrier()

        # save the final model: training checkpoint first, then reload it into the
        # evaluation graph and export the evaluation checkpoint
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.__restore_model(is_train=False)
            self.__save_model(is_train=False)
            self.evaluate()

    def evaluate(self):
        """Restore a model from the latest checkpoint files and then evaluate it."""
        self.__restore_model(is_train=False)
        nb_iters = int(np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        # one row per mini-batch, one column per evaluation metric
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
        for idx_iter in range(nb_iters):
            eval_rslts[idx_iter] = self.sess_eval.run(self.eval_op)
        for idx, name in enumerate(self.eval_op_names):
            tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph."""
        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - channel-pruned model
            # NOTE: everything below is nested in the pruned-model scope; the mask
            #   names drop the leading scope component of the variable name, and the
            #   initializer built below only covers variables collected under this
            #   scope, so the extra variables must be created inside it.
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images)
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.maskable_var_names = [var.name for var in self.vars_prnd['maskable']]
                self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits_prnd, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                self.global_step = tf.train.get_or_create_global_step()
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
                pr_maskable = calc_prune_ratio(self.vars_prnd['maskable'])
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_maskable', pr_maskable)

                # create masks and corresponding operations for channel pruning
                self.masks = []
                self.mask_updt_ops = []
                for var in self.vars_prnd['maskable']:
                    tf.logging.info(
                        'creating a pruning mask for {} of size {}'.format(var.name, var.shape))
                    name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
                    self.masks += [
                        tf.get_variable(name, initializer=tf.ones(var.shape), trainable=False)]
                    # a slice along axis 2 is kept iff its L2-norm is non-zero; the
                    # resulting 0/1 vector is tiled back to the variable's shape
                    var_norm = tf.sqrt(
                        tf.reduce_sum(tf.square(var), axis=[0, 1, 3], keepdims=True))
                    mask_vec = tf.cast(var_norm > 0.0, tf.float32)
                    mask_new = tf.tile(mask_vec, [var.shape[0], var.shape[1], 1, var.shape[3]])
                    self.mask_updt_ops += [self.masks[-1].assign(mask_new)]

                # build extra losses for regression
                self.reg_losses = self.__build_extra_losses()
                self.nb_layers = len(self.reg_losses)
                for idx, reg_loss in enumerate(self.reg_losses):
                    tf.summary.scalar('reg_loss_%d' % idx, reg_loss)

                # obtain the full list of trainable variables & update operations
                self.vars_all = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_prnd)
                self.trainable_vars_all = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.model_scope_prnd)
                self.update_ops_all = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)

                # TF operations for initializing the channel-pruned model: run the
                # regular initializers first, then copy the full model's weights over
                init_ops = []
                with tf.control_dependencies([tf.variables_initializer(self.vars_all)]):
                    for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all']):
                        init_ops += [var_prnd.assign(var_full)]
                self.init_op = tf.group(init_ops)

                # TF operations for layer-wise and whole-network fine-tuning
                self.layer_ops, self.lrn_rates_pgd, self.prune_perctls = \
                    self.__build_layer_ops()
                self.train_op, self.init_opt_op = self.__build_network_ops(loss, lrn_rate)

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_maskable] + list(metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)

    def __build_eval(self):
        """Build the evaluation graph."""
        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(self.sess_eval, images)

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                # loss & extra evaluation metrics
                logits = self.forward_eval(images)
                vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                loss, metrics = self.calc_loss(labels, logits, vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
                pr_maskable = calc_prune_ratio(vars_prnd['maskable'])

                # TF operations for evaluation
                self.eval_op = [loss, pr_trainable, pr_maskable] + list(metrics.values())
                self.eval_op_names = ['loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
                self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'])

            # add input & output tensors to certain collections
            tf.add_to_collection('images_final', images)
            tf.add_to_collection('logits_final', logits)

    def __build_extra_losses(self):
        """Build extra losses for regression.

        Returns:
        * reg_losses: list of regression losses (one per layer)
        """
        # insert additional losses to intermediate layers
        pattern = re.compile('Conv2D$')
        core_ops_full = get_ops_by_scope_n_pattern(self.model_scope_full, pattern)
        core_ops_prnd = get_ops_by_scope_n_pattern(self.model_scope_prnd, pattern)

        # L2 distance between the matching outputs of the full & pruned models
        reg_losses = [
            tf.nn.l2_loss(core_op_full.outputs[0] - core_op_prnd.outputs[0])
            for core_op_full, core_op_prnd in zip(core_ops_full, core_ops_prnd)]

        return reg_losses

    def __build_layer_ops(self):
        """Build layer-wise fine-tuning operations.

        Returns:
        * layer_ops: list of training and initialization operations for each layer
        * lrn_rates_pgd: list of layer-wise learning rate
        * prune_perctls: list of layer-wise pruning percentiles
        """
        layer_ops = []
        lrn_rates_pgd = []  # list of layer-wise learning rate
        prune_perctls = []  # list of layer-wise pruning percentiles
        for idx, var_prnd in enumerate(self.vars_prnd['maskable']):
            # create placeholders
            lrn_rate_pgd = tf.placeholder(tf.float32, shape=[], name='lrn_rate_pgd_%d' % idx)
            prune_perctl = tf.placeholder(tf.float32, shape=[], name='prune_perctl_%d' % idx)

            # select channels for the current convolutional layer: a gradient step
            # followed by shrinking each slice along axis 2 by the percentile threshold
            optimizer = tf.train.GradientDescentOptimizer(lrn_rate_pgd)
            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(optimizer)
            grads = optimizer.compute_gradients(self.reg_losses[idx], [var_prnd])
            with tf.control_dependencies(self.update_ops_all):
                var_prnd_new = var_prnd - lrn_rate_pgd * grads[0][0]
                var_norm = tf.sqrt(
                    tf.reduce_sum(tf.square(var_prnd_new), axis=[0, 1, 3], keepdims=True))
                threshold = tf.contrib.distributions.percentile(var_norm, prune_perctl)
                # NOTE(review): <var_norm> is presumably zero for a slice that is
                #   already all-zero, which makes this division ill-defined; confirm
                #   whether a small epsilon is needed.
                shrk_vec = tf.maximum(1.0 - threshold / var_norm, 0.0)
                prune_op = var_prnd.assign(var_prnd_new * shrk_vec)

            # fine-tune with selected channels only
            optimizer_base = tf.train.AdamOptimizer(FLAGS.cpg_lrn_rate_adam)
            if not FLAGS.enbl_multi_gpu:
                optimizer = optimizer_base
            else:
                optimizer = mgw.DistributedOptimizer(optimizer_base)
            grads_origin = optimizer.compute_gradients(self.reg_losses[idx], [var_prnd])
            grads_pruned = self.__calc_grads_pruned(grads_origin)
            with tf.control_dependencies(self.update_ops_all):
                finetune_op = optimizer.apply_gradients(grads_pruned)
            init_opt_op = tf.variables_initializer(optimizer_base.variables())

            # append layer-wise operations & variables
            layer_ops += [{'prune': prune_op, 'finetune': finetune_op, 'init_opt': init_opt_op}]
            lrn_rates_pgd += [lrn_rate_pgd]
            prune_perctls += [prune_perctl]

        return layer_ops, lrn_rates_pgd, prune_perctls

    def __build_network_ops(self, loss, lrn_rate):
        """Build network training operations.

        Args:
        * loss: final loss of the channel-pruned network
        * lrn_rate: learning rate of the whole-network optimizer

        Returns:
        * train_op: training operation of the whole network
        * init_opt_op: initialization operation of the whole network's optimizer
        """
        optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        if not FLAGS.enbl_multi_gpu:
            optimizer = optimizer_base
        else:
            optimizer = mgw.DistributedOptimizer(optimizer_base)
        grads_origin = optimizer.compute_gradients(loss, self.trainable_vars_all)
        grads_pruned = self.__calc_grads_pruned(grads_origin)
        with tf.control_dependencies(self.update_ops_all):
            train_op = optimizer.apply_gradients(grads_pruned, global_step=self.global_step)
        init_opt_op = tf.variables_initializer(optimizer_base.variables())

        return train_op, init_opt_op

    def __calc_grads_pruned(self, grads_origin):
        """Calculate the mask-pruned gradients.

        Args:
        * grads_origin: list of original gradients, as (gradient, variable) pairs

        Returns:
        * grads_pruned: list of mask-pruned gradients, as (gradient, variable) pairs
        """
        # one name -> mask table instead of a linear search per gradient
        mask_by_name = dict(zip(self.maskable_var_names, self.masks))

        grads_pruned = []
        for grad, var in grads_origin:
            mask = mask_by_name.get(var.name)
            if mask is None or grad is None:
                # non-maskable variable, or no gradient to mask: pass through untouched
                grads_pruned.append((grad, var))
            else:
                grads_pruned.append((grad * mask, var))

        return grads_pruned

    def __choose_channels(self):  # pylint: disable=too-many-locals
        """Choose channels for all convolutional layers.

        Raises:
        * ValueError: if <FLAGS.cpg_prune_ratio_type> is neither 'uniform' nor 'list'
        """
        # obtain each layer's pruning ratio
        if FLAGS.cpg_prune_ratio_type == 'uniform':
            ratio_list = [FLAGS.cpg_prune_ratio] * self.nb_layers
            if FLAGS.cpg_skip_ht_layers:
                # leave the head & tail layers un-pruned
                ratio_list[0] = 0.0
                ratio_list[-1] = 0.0
        elif FLAGS.cpg_prune_ratio_type == 'list':
            # comma-separated ratios on the first line of the file
            with open(FLAGS.cpg_prune_ratio_file, 'r') as i_file:
                i_line = i_file.readline().strip()
                ratio_list = [float(sub_str) for sub_str in i_line.split(',')]
        else:
            raise ValueError('unrecognized pruning ratio type: ' + FLAGS.cpg_prune_ratio_type)

        # select channels for all convolutional layers
        nb_workers = mgw.size() if FLAGS.enbl_multi_gpu else 1
        nb_iters_layer = int(FLAGS.cpg_nb_iters_layer / nb_workers)
        for idx_layer in range(self.nb_layers):
            # skip if no pruning is required
            if ratio_list[idx_layer] == 0.0:
                continue
            if self.is_primary_worker('global'):
                tf.logging.info(
                    'layer #%d: pr = %.2f (target)' % (idx_layer, ratio_list[idx_layer]))
                tf.logging.info('mask.shape = {}'.format(self.masks[idx_layer].shape))

            # select channels for the current convolutional layer
            time_prev = timer()
            reg_loss_prev = 0.0
            lrn_rate_pgd = FLAGS.cpg_lrn_rate_pgd_init
            for idx_iter in range(nb_iters_layer):
                # take a stochastic proximal gradient descent step; the pruning
                # percentile grows linearly up to the layer's target ratio
                prune_perctl = ratio_list[idx_layer] * 100.0 * (idx_iter + 1) / nb_iters_layer
                __, reg_loss = self.sess_train.run(
                    [self.layer_ops[idx_layer]['prune'], self.reg_losses[idx_layer]],
                    feed_dict={self.lrn_rates_pgd[idx_layer]: lrn_rate_pgd,
                               self.prune_perctls[idx_layer]: prune_perctl})
                mask = self.sess_train.run(self.masks[idx_layer])
                if self.is_primary_worker('global'):
                    nb_chns_nnz = np.count_nonzero(np.sum(mask, axis=(0, 1, 3)))
                    tf.logging.info(
                        'iter %d: nnz-chns = %d | loss = %.2e | lr = %.2e | percentile = %.2f'
                        % (idx_iter + 1, nb_chns_nnz, reg_loss, lrn_rate_pgd, prune_perctl))

                # adjust the learning rate
                if reg_loss < reg_loss_prev:
                    lrn_rate_pgd *= FLAGS.cpg_lrn_rate_pgd_incr
                else:
                    lrn_rate_pgd *= FLAGS.cpg_lrn_rate_pgd_decr
                reg_loss_prev = reg_loss

            # fine-tune with selected channels only
            self.sess_train.run(self.mask_updt_ops[idx_layer])
            for idx_iter in range(nb_iters_layer):
                __, reg_loss = self.sess_train.run(
                    [self.layer_ops[idx_layer]['finetune'], self.reg_losses[idx_layer]])
                mask = self.sess_train.run(self.masks[idx_layer])
                if self.is_primary_worker('global'):
                    nb_chns_nnz = np.count_nonzero(np.sum(mask, axis=(0, 1, 3)))
                    tf.logging.info('iter %d: nnz-chns = %d | loss = %.2e'
                                    % (idx_iter + 1, nb_chns_nnz, reg_loss))

            # re-compute the pruning ratio
            mask_vec = np.mean(
                np.square(self.sess_train.run(self.masks[idx_layer])), axis=(0, 1, 3))
            prune_ratio = 1.0 - float(np.count_nonzero(mask_vec)) / mask_vec.size
            if self.is_primary_worker('global'):
                tf.logging.info('layer #%d: pr = %.2f (actual) | time = %.2f'
                                % (idx_layer, prune_ratio, timer() - time_prev))

        # compute overall pruning ratios
        if self.is_primary_worker('global'):
            log_rslt = self.sess_train.run(self.log_op)
            log_str = ' | '.join(['%s = %.4e' % (name, value)
                                  for name, value in zip(self.log_op_names, log_rslt)])
            # report the statistics (they were previously computed and discarded)
            tf.logging.info('after channel selection: %s' % log_str)

    def __save_model(self, is_train):
        """Save the current model for training or evaluation.

        Args:
        * is_train: whether to save a model for training
        """
        if is_train:
            save_path = self.saver_prnd_train.save(
                self.sess_train, FLAGS.cpg_save_path, self.global_step)
        else:
            save_path = self.saver_prnd_eval.save(self.sess_eval, FLAGS.cpg_save_path_eval)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Args:
        * is_train: whether to restore a model for training
        """
        save_path = tf.train.latest_checkpoint(os.path.dirname(FLAGS.cpg_save_path))
        if is_train:
            self.saver_prnd_train.restore(self.sess_train, save_path)
        else:
            self.saver_prnd_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Monitor the training progress.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """
        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # compute the training speed
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # display monitored statistics
        log_str = ' | '.join(['%s = %.4e' % (name, value)
                              for name, value in zip(self.log_op_names, log_rslt)])
        tf.logging.info(
            'iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))
class NonUniformQuantLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Learner that applies non-uniform quantization to weights and activations.

    The constructor builds a training graph and an evaluation graph, inserts
    quantization ops into both, and then runs a `BitOptimizer` to obtain the
    weight / activation bit lists that `train()` and `evaluate()` feed into
    the graphs.

    NOTE(review): block nesting in this class was reconstructed from a
    whitespace-flattened source. The extent of each `tf.variable_scope`
    block in the graph builders should be confirmed against version control.
    """

    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """
        # class-independent initialization
        super(NonUniformQuantLearner, self).__init__(sm_writer, model_helper)

        # class-dependent initialization
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper,
                                                 self.mpi_comm)

        # initialize class attributes
        self.ops = {}  # named TF operations, filled by the graph builders
        self.bit_placeholders = {}  # 'w_train' / 'a_train' / 'w_eval' / 'a_eval'
        self.statistics = {}  # 'num_weights', 'nb_matmuls', 'nb_activations'

        self.__build_train()  # for train
        self.__build_eval()  # for eval

        # only the local primary worker downloads; all workers wait for it
        if self.is_primary_worker('local'):
            self.download_model()
        self.auto_barrier()

        # determine the optimal policy.
        bit_optimizer = BitOptimizer(self.dataset_name, self.weights,
                                     self.statistics, self.bit_placeholders,
                                     self.ops, self.layerwise_tune_list,
                                     self.sess_train, self.sess_eval,
                                     self.saver_train, self.saver_quant,
                                     self.saver_eval, self.auto_barrier)
        self.optimal_w_bit_list, self.optimal_a_bit_list = bit_optimizer.run()
        self.auto_barrier()

    def train(self):
        """Fine-tune the quantized model and periodically save & evaluate it."""
        # initialization
        self.sess_train.run(self.ops['non_cluster_init'])
        # mgw_size = int(mgw.size()) if FLAGS.enbl_multi_gpu else 1
        total_iters = self.finetune_steps
        if FLAGS.enbl_warm_start:
            self.__restore_model(
                is_train=True)  # use the latest model for warm start

        # NOTE: initialize the clusters after restore weights
        self.sess_train.run(self.ops['cluster_init'], \
            feed_dict={self.bit_placeholders['w_train']: self.optimal_w_bit_list})

        # the same bit lists are fed at every training step
        feed_dict = {self.bit_placeholders['w_train']: self.optimal_w_bit_list, \
            self.bit_placeholders['a_train']: self.optimal_a_bit_list}

        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.ops['bcast'])

        time_prev = timer()

        for idx_iter in range(total_iters):
            # train the model
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.ops['train'], feed_dict=feed_dict)
            else:
                _, summary, log_rslt = self.sess_train.run([self.ops['train'], \
                    self.ops['summary'], self.ops['log']], feed_dict=feed_dict)
                # __monitor_progress returns the new reference time on the
                # primary worker and None on the other workers
                time_prev = self.__monitor_progress(summary, log_rslt,
                                                    time_prev, idx_iter)

            # save & evaluate the model at certain steps
            if (idx_iter + 1) % FLAGS.save_step == 0:
                self.__save_model()
                self.evaluate()
                tf.logging.info("Optimal Weight Quantization:{}".format(
                    self.optimal_w_bit_list))
                self.auto_barrier()

        # save the final model
        self.__save_model()
        self.evaluate()

    def evaluate(self):
        """Restore the latest quantized checkpoint and log its evaluation results.

        Only the primary worker evaluates; other workers return immediately.
        """
        # early break for non-primary workers
        if not self.is_primary_worker():
            return

        # evaluate the model
        self.__restore_model(is_train=False)
        losses, accuracies = [], []
        nb_iters = int(
            np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))

        # build the quantization bits
        feed_dict = {self.bit_placeholders['w_eval']: self.optimal_w_bit_list, \
            self.bit_placeholders['a_eval']: self.optimal_a_bit_list}

        for _ in range(nb_iters):
            eval_rslt = self.sess_eval.run(self.ops['eval'],
                                           feed_dict=feed_dict)
            # ops['eval'] is [loss, acc_top1, acc_top5]; acc_top5 is not logged
            losses.append(eval_rslt[0])
            accuracies.append(eval_rslt[1])
        tf.logging.info('loss: {}'.format(np.mean(np.array(losses))))
        tf.logging.info('accuracy: {}'.format(np.mean(np.array(accuracies))))
        tf.logging.info("Optimal Weight Quantization:{}".format(
            self.optimal_w_bit_list))

        if FLAGS.nuql_use_buckets:
            bucket_storage = self.sess_eval.run(self.ops['bucket_storage'],
                                                feed_dict=feed_dict)
            self.__show_bucket_storage(bucket_storage)

    def __build_train(self):
        """Build the training graph, its session, savers and named operations.

        Sets `self.sess_train`, `self.saver_train`, `self.weights`,
        `self.ft_step`, `self.finetune_steps`, `self.saver_quant` and the
        'train', 'rl_fintune', 'summary', 'log', 'non_cluster_init',
        'cluster_init', 'bcast' and 'reset_ft_step' entries of `self.ops`.

        Raises:
        * ValueError: if the dataset name or `FLAGS.nuql_opt_mode` is not recognized
        """
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(mgw.local_rank() \
                if FLAGS.enbl_multi_gpu else 0)
            self.sess_train = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                # pin the batch dimension to a static size
                images.set_shape((FLAGS.batch_size, images.shape[1], images.shape[2], \
                    images.shape[3]))

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_train, images)

            # model definition
            # NOTE(review): the end of this scope block is reconstructed; confirm
            # whether `finetune_step` and the optimizer belong inside it.
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_train(images)

                # created before the quantization ops are inserted
                self.saver_train = tf.train.Saver(self.vars)

                self.weights = [
                    v for v in self.trainable_vars
                    if 'kernel' in v.name or 'weight' in v.name
                ]
                if not FLAGS.nuql_quantize_all_layers:
                    # drop the first and last weight tensors
                    self.weights = self.weights[1:-1]
                self.statistics['num_weights'] = \
                    [tf.reshape(v, [-1]).shape[0].value for v in self.weights]

                self.__quantize_train_graph()

                # loss & accuracy
                # Strictly speaking, clusters should be not included for regularization.
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    # no top-5 metric for this dataset; a constant zero is used
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                else:
                    raise ValueError("Unrecognized dataset name")

                model_loss = loss
                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss
                    tf.summary.scalar('dst_loss', dst_loss)
                tf.summary.scalar('model_loss', model_loss)
                tf.summary.scalar('loss', loss)
                tf.summary.scalar('acc_top1', acc_top1)
                tf.summary.scalar('acc_top5', acc_top5)

                # step counter for fine-tuning; drives the learning rate schedule
                self.ft_step = tf.get_variable('finetune_step',
                                               shape=[],
                                               dtype=tf.int32,
                                               trainable=False)

            # optimizer & gradients
            init_lr, bnds, decay_rates, self.finetune_steps = \
                setup_bnds_decay_rates(self.model_name, self.dataset_name)
            lrn_rate = tf.train.piecewise_constant(self.ft_step, [i for i in bnds], \
                [init_lr * decay_rate for decay_rate in decay_rates])

            # optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
            optimizer = tf.train.AdamOptimizer(lrn_rate)
            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(optimizer)

            # acquire non-cluster-vars and clusters.
            clusters = [v for v in self.trainable_vars if 'clusters' in v.name]
            rest_trainable_vars = [v for v in self.trainable_vars \
                if v not in clusters]

            # determine the var_list optimize
            if FLAGS.nuql_opt_mode in ['cluster', 'both']:
                if FLAGS.nuql_opt_mode == 'both':
                    optimizable_vars = self.trainable_vars
                else:
                    optimizable_vars = clusters
                if FLAGS.nuql_enbl_rl_agent:
                    # separate plain-SGD optimizer for the 'rl_fintune' op
                    optimizer_fintune = tf.train.GradientDescentOptimizer(
                        lrn_rate)
                    if FLAGS.enbl_multi_gpu:
                        optimizer_fintune = mgw.DistributedOptimizer(
                            optimizer_fintune)
                    grads_fintune = optimizer_fintune.compute_gradients(
                        loss, var_list=optimizable_vars)
            elif FLAGS.nuql_opt_mode == 'weights':
                optimizable_vars = rest_trainable_vars
            else:
                raise ValueError("Unknown optimization mode")

            grads = optimizer.compute_gradients(loss,
                                                var_list=optimizable_vars)

            # sm write graph
            self.sm_writer.add_graph(self.sess_train.graph)

            # define the ops
            with tf.control_dependencies(self.update_ops):
                self.ops['train'] = optimizer.apply_gradients(
                    grads, global_step=self.ft_step)
                if FLAGS.nuql_opt_mode in ['both', 'cluster'
                                           ] and FLAGS.nuql_enbl_rl_agent:
                    self.ops['rl_fintune'] = optimizer_fintune.apply_gradients(
                        grads_fintune, global_step=self.ft_step)
                else:
                    # without a dedicated optimizer, reuse the regular train op
                    self.ops['rl_fintune'] = self.ops['train']
            self.ops['summary'] = tf.summary.merge_all()
            # the order of entries must match the unpacking in __monitor_progress
            if FLAGS.enbl_dst:
                self.ops['log'] = [
                    lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5
                ]
            else:
                self.ops['log'] = [
                    lrn_rate, model_loss, loss, acc_top1, acc_top5
                ]

            # NOTE: first run non_cluster_init_op, then cluster_init_op
            cluster_global_vars = [
                v for v in self.vars if 'clusters' in v.name
            ]
            # pay attention to beta1_power, beta2_power.
            # (taken from tf.global_variables() rather than self.vars)
            non_cluster_global_vars = [
                v for v in tf.global_variables()
                if v not in cluster_global_vars
            ]

            self.ops['non_cluster_init'] = tf.variables_initializer(
                var_list=non_cluster_global_vars)
            self.ops['cluster_init'] = tf.variables_initializer(
                var_list=cluster_global_vars)
            # None in single-GPU mode; train() only runs it when multi-GPU is on
            self.ops['bcast'] = mgw.broadcast_global_variables(0) \
                if FLAGS.enbl_multi_gpu else None
            self.ops['reset_ft_step'] = tf.assign(self.ft_step, \
                tf.constant(0, dtype=tf.int32))

            self.saver_quant = tf.train.Saver(self.vars)

    def __build_eval(self):
        """Build the evaluation graph, its session, saver and the 'eval' op.

        Sets `self.sess_eval`, `self.images_eval`, `self.saver_eval` and
        `self.ops['eval']` (a list `[loss, acc_top1, acc_top5]`).

        Raises:
        * ValueError: if the dataset name is not recognized
        """
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(mgw.local_rank() \
                if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()
                # NOTE(review): uses FLAGS.batch_size while evaluate() counts
                # iterations with FLAGS.batch_size_eval - confirm this is intended
                images.set_shape((FLAGS.batch_size, images.shape[1], images.shape[2], \
                    images.shape[3]))
                self.images_eval = images

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_eval(images)

                self.__quantize_eval_graph()

                # loss & accuracy
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                else:
                    raise ValueError("Unrecognized dataset name")

                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss

                self.ops['eval'] = [loss, acc_top1, acc_top5]
                self.saver_eval = tf.train.Saver(self.vars)

    def __quantize_train_graph(self):
        """Insert quantization nodes into the training graph.

        Creates the 'w_train' / 'a_train' bit placeholders, records the number
        of matmul and activation ops in `self.statistics`, and sets
        `self.layerwise_tune_list`.
        """
        nonuni_quant = NonUniformQuantization(self.sess_train,
                                              FLAGS.nuql_bucket_size,
                                              FLAGS.nuql_use_buckets,
                                              FLAGS.nuql_init_style,
                                              FLAGS.nuql_bucket_type)

        # Find Matmul Op and activation Op
        matmul_ops = nonuni_quant.search_matmul_op(
            FLAGS.nuql_quantize_all_layers)
        act_ops = nonuni_quant.search_activation_op()

        # the eval graph is asserted to contain the same op counts
        self.statistics['nb_matmuls'] = len(matmul_ops)
        self.statistics['nb_activations'] = len(act_ops)

        # Replace Conv2d Op with quantized weights
        matmul_op_names = [op.name for op in matmul_ops]
        act_op_names = [op.name for op in act_ops]

        # build the placeholder for nonuniform quantization bits.
        self.bit_placeholders['w_train'] = tf.placeholder(tf.int64, \
            shape=[self.statistics['nb_matmuls']], name="w_bit_list")
        self.bit_placeholders['a_train'] = tf.placeholder(tf.int64, \
            shape=[self.statistics['nb_activations']], name="a_bit_list")
        # map each op name to its own element of the placeholder
        w_bit_dict_train = self.__build_quant_dict(matmul_op_names, \
            self.bit_placeholders['w_train'])
        a_bit_dict_train = self.__build_quant_dict(act_op_names, \
            self.bit_placeholders['a_train'])

        # Insert Quant Op for weights and activations
        nonuni_quant.insert_quant_op_for_weights(w_bit_dict_train)
        # NOTE: Not necessary for activation quantization in non-uniform
        nonuni_quant.insert_quant_op_for_activations(a_bit_dict_train)

        # TODO: add layerwise finetuning. not working very well
        self.layerwise_tune_list = nonuni_quant.get_layerwise_tune_op(self.weights) \
            if FLAGS.nuql_enbl_rl_layerwise_tune else (None, None)

    def __quantize_eval_graph(self):
        """Insert quantization nodes into the evaluation graph.

        Creates the 'w_eval' / 'a_eval' bit placeholders and sets
        `self.ops['bucket_storage']`. Asserts that the op counts match those
        recorded while building the training graph.
        """
        nonuni_quant = NonUniformQuantization(self.sess_eval,
                                              FLAGS.nuql_bucket_size,
                                              FLAGS.nuql_use_buckets,
                                              FLAGS.nuql_init_style,
                                              FLAGS.nuql_bucket_type)

        # Find Matmul Op and activation Op
        matmul_ops = nonuni_quant.search_matmul_op(
            FLAGS.nuql_quantize_all_layers)
        act_ops = nonuni_quant.search_activation_op()
        # the training graph must have been quantized first
        assert self.statistics['nb_matmuls'] == len(matmul_ops), \
            'the length of matmul_ops does not match'
        assert self.statistics['nb_activations'] == len(act_ops), \
            'the length of act_ops does not match'

        # Replace Conv2d Op with quantized weights
        matmul_op_names = [op.name for op in matmul_ops]
        act_op_names = [op.name for op in act_ops]

        # build the placeholder for eval
        self.bit_placeholders['w_eval'] = tf.placeholder(tf.int64, \
            shape=[self.statistics['nb_matmuls']], name="w_bit_list")
        self.bit_placeholders['a_eval'] = tf.placeholder(tf.int64, \
            shape=[self.statistics['nb_activations']], name="a_bit_list")

        w_bit_dict_eval = self.__build_quant_dict(matmul_op_names, \
            self.bit_placeholders['w_eval'])
        a_bit_dict_eval = self.__build_quant_dict(act_op_names, \
            self.bit_placeholders['a_eval'])

        # Insert Quant Op for weights and activations
        nonuni_quant.insert_quant_op_for_weights(w_bit_dict_eval)
        # NOTE: no need for activation quantization
        nonuni_quant.insert_quant_op_for_activations(a_bit_dict_eval)

        # a constant zero stands in when buckets are disabled
        self.ops['bucket_storage'] = nonuni_quant.bucket_storage if FLAGS.nuql_use_buckets \
            else tf.constant(0, tf.int32)

    def __save_model(self):
        """Save the quantized training model (primary worker only)."""
        # early break for non-primary workers
        if not self.is_primary_worker():
            return

        # save quantization model
        save_quant_model_path = self.saver_quant.save(self.sess_train, \
            FLAGS.nuql_save_quant_model_path, self.ft_step)
        tf.logging.info('quantized model saved to ' + save_quant_model_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Args:
        * is_train: if True, restore the training session from the checkpoint
            directory of `FLAGS.save_path`; otherwise restore the evaluation
            session from the directory of `FLAGS.nuql_save_quant_model_path`
        """
        if is_train:
            save_path = tf.train.latest_checkpoint(
                os.path.dirname(FLAGS.save_path))
            self.saver_train.restore(self.sess_train, save_path)
        else:
            save_path = tf.train.latest_checkpoint(
                os.path.dirname(FLAGS.nuql_save_quant_model_path))
            self.saver_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, time_prev, idx_iter):
        """Write summaries and log training statistics.

        Args:
        * summary: summary protocol buffer
        * log_rslt: results of `self.ops['log']`
        * time_prev: timestamp of the previous call
        * idx_iter: index of the training iteration

        Returns:
        * the current timestamp on the primary worker; None on other workers
        """
        # early break for non-primary workers
        if not self.is_primary_worker():
            return None

        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # display monitored statistics
        speed = FLAGS.batch_size * FLAGS.summ_step / (timer() - time_prev)
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # the layout of log_rslt depends on whether distillation is enabled
        if FLAGS.enbl_dst:
            lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5 = log_rslt[0], \
                log_rslt[1], log_rslt[2], log_rslt[3], log_rslt[4], log_rslt[5]
            tf.logging.info('iter #%d: lr = %e | dst_loss = %.4f | model_loss = %.4f | loss = %.4f | acc_top1 = %.4f | acc_top5 = %.4f | speed = %.2f pics / sec' \
                % (idx_iter + 1, lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5, speed))
        else:
            lrn_rate, model_loss, loss, acc_top1, acc_top5 = log_rslt[0], \
                log_rslt[1], log_rslt[2], log_rslt[3], log_rslt[4]
            tf.logging.info('iter #%d: lr = %e | model_loss = %.4f | loss = %.4f | acc_top1 = %.4f | acc_top5 = %.4f | speed = %.2f pics / sec' \
                % (idx_iter + 1, lrn_rate, model_loss, loss, acc_top1, acc_top5, speed))

        return timer()

    def __show_bucket_storage(self, bucket_storage):
        """Log the bucket storage, the weight storage and their ratio.

        Args:
        * bucket_storage: evaluated value of `self.ops['bucket_storage']`
        """
        # bits per weight come from a different flag when the RL agent is enabled
        weight_storage = sum(self.statistics['num_weights']) * FLAGS.nuql_weight_bits \
            if not FLAGS.nuql_enbl_rl_agent \
            else sum(self.statistics['num_weights']) * FLAGS.nuql_equivalent_bits
        tf.logging.info('bucket storage: %d bit / %.3f kb | weight storage: %d bit / %.3f kb | ratio: %.3f' % \
            (bucket_storage, bucket_storage/(8.*1024.), weight_storage, \
            weight_storage/(8.*1024.), bucket_storage*1./weight_storage))

    @staticmethod
    def __build_quant_dict(keys, values):
        """Bind keys and values to dictionaries.

        Args:
        * keys: A list of op_names;
        * values: A Tensor with len(keys) elements;

        Returns:
        * dict: (key, value) for weight name and quant bits respectively.
        """
        dict_ = {}
        for (idx, v) in enumerate(keys):
            # values[idx] selects the element that belongs to this op
            dict_[v] = values[idx]
        return dict_
class ChannelPrunedLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Learner with channel/filter pruning.

    Channels are pruned by a `ChannelPruner`, either with ratios read from a
    file ('list'), with a uniform preserve ratio ('uniform') or with ratios
    found by a DDPG agent ('auto'). The pruned model is then fine-tuned with
    masked gradients so that pruned channels are not updated.

    NOTE(review): block nesting in this class was reconstructed from a
    whitespace-flattened source; confirm the scope boundaries in `__build`
    and the placement of the log calls at the end of `__prune_rl`.
    """

    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """
        # class-independent initialization
        super(ChannelPrunedLearner, self).__init__(sm_writer, model_helper)

        # class-dependent initialization
        if FLAGS.enbl_dst:
            self.learner_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)

        self.model_scope = 'model'
        self.sm_writer = sm_writer
        self.max_save_path = ''
        self.saver = None
        self.saver_train = None
        self.saver_eval = None
        self.model = None
        self.pruner = None
        self.sess_train = None
        self.sess_eval = None
        self.log_op = None
        self.train_op = None
        self.bcast_op = None
        self.train_init_op = None
        self.time_prev = None
        self.agent = None
        self.idx_iter = None
        self.accuracy_keys = None
        self.eval_op = None
        self.global_step = None
        self.summary_op = None
        self.nb_iters_train = 0
        self.bestinfo = None
        # bounds of the preserve ratio; assigned by __build on the primary worker
        self.lbound = None
        self.rbound = None

        self.__build(is_train=True)
        self.__build(is_train=False)

    def train(self):
        """Prune the model and fine-tune it.

        Raises:
        * ValueError: if `FLAGS.cp_prune_option` is not 'list', 'auto' or 'uniform'
        """
        # download pre-trained model
        if self.__is_primary_worker():
            self.download_model()
            self.__restore_model(True)
            self.saver_train.save(self.sess_train, FLAGS.cp_original_path)
            self.create_pruner()

        if FLAGS.enbl_multi_gpu:
            self.mpi_comm.Barrier()

        tf.logging.info('Start pruning')

        # channel pruning and finetuning
        if FLAGS.cp_prune_option == 'list':
            self.__prune_and_finetune_list()
        elif FLAGS.cp_prune_option == 'auto':
            self.__prune_and_finetune_auto()
        elif FLAGS.cp_prune_option == 'uniform':
            self.__prune_and_finetune_uniform()
        else:
            # fail loudly instead of silently skipping the whole pruning stage
            raise ValueError('unrecognized prune option: ' + str(FLAGS.cp_prune_option))

    def create_pruner(self):
        """Create a pruner from the model saved at `FLAGS.cp_original_path`.

        Sets `self.saver`, `self.sess_train`, `self.model` and `self.pruner`.
        """
        with tf.Graph().as_default():
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(0)  # pylint: disable=no-member
            sess = tf.Session(config=config)
            self.saver = tf.train.import_meta_graph(FLAGS.cp_original_path + '.meta')
            self.saver.restore(sess, FLAGS.cp_original_path)
            self.sess_train = sess
            self.sm_writer.add_graph(sess.graph)

            # tensors registered in graph collections by __build
            train_images = tf.get_collection('train_images')[0]
            train_labels = tf.get_collection('train_labels')[0]
            mem_images = tf.get_collection('mem_images')[0]
            mem_labels = tf.get_collection('mem_labels')[0]
            summary_op = tf.get_collection('summary_op')[0]
            loss = tf.get_collection('loss')[0]
            accuracy = tf.get_collection('accuracy')[0]

            metrics = {'loss': loss, 'accuracy': accuracy}
            for key in self.accuracy_keys:
                metrics[key] = tf.get_collection(key)[0]

            self.model = Model(self.sess_train)
            self.pruner = ChannelPruner(
                self.model,
                images=train_images,
                labels=train_labels,
                mem_images=mem_images,
                mem_labels=mem_labels,
                metrics=metrics,
                lbound=self.lbound,
                summary_op=summary_op,
                sm_writer=self.sm_writer)

    def evaluate(self):
        """Restore the latest checkpoint and log the loss and each accuracy metric.

        Only the primary worker evaluates; other workers return immediately.
        """
        # early break for non-primary workers
        if not self.__is_primary_worker():
            return

        if self.saver_eval is None:
            self.saver_eval = tf.train.Saver()
        self.__restore_model(is_train=False)

        # floor division: a trailing partial batch is not evaluated
        nb_iters = FLAGS.nb_smpls_eval // FLAGS.batch_size_eval

        self.sm_writer.add_graph(self.sess_eval.graph)

        losses = []
        accuracies = [[] for _ in self.accuracy_keys]
        for _ in range(nb_iters):
            eval_rslt = self.sess_eval.run(self.eval_op)
            # eval_op is [loss, <one entry per accuracy metric>]
            losses.append(eval_rslt[0])
            for idx, values in enumerate(accuracies):
                values.append(eval_rslt[idx + 1])

        tf.logging.info('loss: {}'.format(np.mean(np.array(losses))))
        for idx, key in enumerate(self.accuracy_keys):
            tf.logging.info('{}: {}'.format(key, np.mean(np.array(accuracies[idx]))))

    def __build(self, is_train):  # pylint: disable=too-many-locals
        """Build the initial training graph or the evaluation graph.

        Only the primary worker builds anything here.

        Args:
        * is_train: whether to create the training graph
        """
        # early break for non-primary workers
        if not self.__is_primary_worker():
            return

        if not is_train:
            self.__build_pruned_evaluate_model()
            return

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                train_images, train_labels = self.build_dataset_train().get_next()
                eval_images, eval_labels = self.build_dataset_eval().get_next()
                image_shape = train_images.shape.as_list()
                label_shape = train_labels.shape.as_list()
                image_shape[0] = FLAGS.batch_size
                label_shape[0] = FLAGS.batch_size

                # the model is built on placeholders; the dataset tensors are
                # rerouted onto them later by the __build_pruned_* methods
                mem_images = tf.placeholder(dtype=train_images.dtype, shape=image_shape)
                mem_labels = tf.placeholder(dtype=train_labels.dtype, shape=label_shape)

                tf.add_to_collection('train_images', train_images)
                tf.add_to_collection('train_labels', train_labels)
                tf.add_to_collection('eval_images', eval_images)
                tf.add_to_collection('eval_labels', eval_labels)
                tf.add_to_collection('mem_images', mem_images)
                tf.add_to_collection('mem_labels', mem_labels)

            # model definition
            with tf.variable_scope(self.model_scope):
                # forward pass
                logits = self.forward_train(mem_images)
                loss, accuracy = self.calc_loss(mem_labels, logits, self.trainable_vars)
                self.accuracy_keys = list(accuracy.keys())
                for key in self.accuracy_keys:
                    tf.add_to_collection(key, accuracy[key])
                tf.add_to_collection('loss', loss)
                tf.add_to_collection('logits', logits)

                tf.summary.scalar('loss', loss)
                for key in accuracy.keys():
                    tf.summary.scalar(key, accuracy[key])

            # learning rate & pruning ratio
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            tf.add_to_collection('summary_op', self.summary_op)
            self.saver_train = tf.train.Saver(self.vars)

            self.lbound = math.log(FLAGS.cp_preserve_ratio + 1, 10) * 1.5
            self.rbound = 1.0

    def __build_pruned_evaluate_model(self, path=None):
        """Build an evaluation model from a pruned model.

        Does nothing on non-primary workers or when no checkpoint exists at
        <path>.

        Args:
        * path: checkpoint path of the pruned model (default: FLAGS.save_path)
        """
        # early break for non-primary workers
        if not self.__is_primary_worker():
            return

        if path is None:
            path = FLAGS.save_path

        if not tf.train.checkpoint_exists(path):
            return

        with tf.Graph().as_default():
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)
            self.saver_eval = tf.train.import_meta_graph(path + '.meta')
            self.saver_eval.restore(self.sess_eval, path)
            eval_logits = tf.get_collection('logits')[0]
            tf.add_to_collection('logits_final', eval_logits)
            eval_images = tf.get_collection('eval_images')[0]
            tf.add_to_collection('images_final', eval_images)
            eval_labels = tf.get_collection('eval_labels')[0]
            mem_images = tf.get_collection('mem_images')[0]
            mem_labels = tf.get_collection('mem_labels')[0]

            # the session is closed while the graph is edited and re-created
            # (with the weights restored again) afterwards
            self.sess_eval.close()

            # feed the model from the evaluation dataset instead of the placeholders
            graph_editor.reroute_ts(eval_images, mem_images)
            graph_editor.reroute_ts(eval_labels, mem_labels)

            self.sess_eval = tf.Session(config=config)
            self.saver_eval.restore(self.sess_eval, path)
            trainable_vars = self.trainable_vars
            loss, accuracy = self.calc_loss(eval_labels, eval_logits, trainable_vars)
            self.eval_op = [loss] + list(accuracy.values())
            self.sm_writer.add_graph(self.sess_eval.graph)

    def __build_pruned_train_model(self, path=None, finetune=False):  # pylint: disable=too-many-locals
        """Build a training model from a pruned model.

        Args:
        * path: checkpoint path of the pruned model (default: FLAGS.save_path)
        * finetune: whether the model is built for intermediate fine-tuning;
            unless FLAGS.cp_retrain is set, this selects an Adam optimizer with
            the constant learning rate FLAGS.cp_lrn_rate_ft
        """
        if path is None:
            path = FLAGS.save_path

        with tf.Graph().as_default():
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_train = tf.Session(config=config)
            self.saver_train = tf.train.import_meta_graph(path + '.meta')
            self.saver_train.restore(self.sess_train, path)
            logits = tf.get_collection('logits')[0]
            train_images = tf.get_collection('train_images')[0]
            train_labels = tf.get_collection('train_labels')[0]
            mem_images = tf.get_collection('mem_images')[0]
            mem_labels = tf.get_collection('mem_labels')[0]

            # the session is closed while the graph is edited and re-created
            # (with the weights restored again) afterwards
            self.sess_train.close()

            # feed the model from the training dataset instead of the placeholders
            graph_editor.reroute_ts(train_images, mem_images)
            graph_editor.reroute_ts(train_labels, mem_labels)

            self.sess_train = tf.Session(config=config)
            self.saver_train.restore(self.sess_train, path)

            trainable_vars = self.trainable_vars
            loss, accuracy = self.calc_loss(train_labels, logits, trainable_vars)
            self.accuracy_keys = list(accuracy.keys())

            if FLAGS.enbl_dst:
                logits_dst = self.learner_dst.calc_logits(self.sess_train, train_images)
                loss += self.learner_dst.calc_loss(logits, logits_dst)

            tf.summary.scalar('loss', loss)
            for key in accuracy.keys():
                tf.summary.scalar(key, accuracy[key])
            self.summary_op = tf.summary.merge_all()

            global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, trainable=False)
            self.global_step = global_step
            lrn_rate, self.nb_iters_train = setup_lrn_rate(
                self.global_step, self.model_name, self.dataset_name)

            if finetune and not FLAGS.cp_retrain:
                mom_optimizer = tf.train.AdamOptimizer(FLAGS.cp_lrn_rate_ft)
                self.log_op = [tf.constant(FLAGS.cp_lrn_rate_ft), loss, list(accuracy.values())]
            else:
                mom_optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
                self.log_op = [lrn_rate, loss, list(accuracy.values())]

            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(mom_optimizer)
            else:
                optimizer = mom_optimizer
            grads_origin = optimizer.compute_gradients(loss, trainable_vars)
            grads_pruned, masks = self.__calc_grads_pruned(grads_origin)

            with tf.control_dependencies(self.update_ops):
                self.train_op = optimizer.apply_gradients(grads_pruned, global_step=global_step)

            self.sm_writer.add_graph(tf.get_default_graph())
            # only the newly created variables are initialized; model weights keep
            # their restored values (tf.initialize_variables is a deprecated alias)
            self.train_init_op = tf.variables_initializer(
                mom_optimizer.variables() + [global_step] + masks)

            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)

    def __calc_grads_pruned(self, grads_origin):
        """Calculate the pruned gradients.

        Args:
        * grads_origin: list of (gradient, variable) pairs

        Returns:
        * the pruned gradients, as a list of (gradient, variable) pairs
        * the mask variables created for the pruned gradients
        """
        grads_pruned = []
        masks = []
        maskable_var_names = {}
        fake_pruning_dict = {}
        if self.__is_primary_worker():
            fake_pruning_dict = self.pruner.fake_pruning_dict
            # variable name -> name of the operation it belongs to
            maskable_var_names = {
                self.pruner.model.get_var_by_op(
                    self.pruner.model.g.get_operation_by_name(op_name)).name: op_name
                for op_name in fake_pruning_dict}
            tf.logging.debug('maskable var names {}'.format(maskable_var_names))

        # only the primary worker owns a pruner; share its result with the others
        if FLAGS.enbl_multi_gpu:
            fake_pruning_dict = self.mpi_comm.bcast(fake_pruning_dict, root=0)
            maskable_var_names = self.mpi_comm.bcast(maskable_var_names, root=0)

        for grad in grads_origin:
            if grad[1].name not in maskable_var_names:
                grads_pruned.append(grad)
                continue

            pruned_idxs = fake_pruning_dict[maskable_var_names[grad[1].name]]
            mask_tensor = np.ones(grad[0].shape)
            # zero the entries along axes 2 and 3 whose flag in pruned_idxs is falsy
            mask_tensor[:, :, [not i for i in pruned_idxs[0]], :] = 0
            mask_tensor[:, :, :, [not i for i in pruned_idxs[1]]] = 0
            mask_initializer = tf.constant_initializer(mask_tensor)
            mask = tf.get_variable(
                grad[1].name.split(':')[0] + '_mask',
                shape=mask_tensor.shape,
                initializer=mask_initializer,
                trainable=False)
            masks.append(mask)
            grads_pruned.append((grad[0] * mask, grad[1]))

        return grads_pruned, masks

    def __train_pruned_model(self, finetune=False):
        """Train the pruned model.

        Args:
        * finetune: whether this is an intermediate fine-tuning run; unless
            FLAGS.cp_retrain is set, the number of iterations is then scaled
            by FLAGS.cp_nb_iters_ft_ratio
        """
        # initialize variables
        self.sess_train.run(self.train_init_op)

        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # fine-tuning & distilling
        self.time_prev = timer()

        nb_iters = int(FLAGS.cp_nb_iters_ft_ratio * self.nb_iters_train) \
            if finetune and not FLAGS.cp_retrain else self.nb_iters_train

        # the loop index is kept on self because __monitor_progress reads it
        for self.idx_iter in range(nb_iters):
            # train the model
            if (self.idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.train_op)
            else:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                self.__monitor_progress(summary, log_rslt)

            # save the model at certain steps
            if (self.idx_iter + 1) % FLAGS.save_step == 0:
                if self.__is_primary_worker():
                    self.__save_model()
                    self.evaluate()

                if FLAGS.enbl_multi_gpu:
                    self.mpi_comm.Barrier()

        if self.__is_primary_worker():
            self.__save_model()
            self.evaluate()
            self.__save_in_progress_pruned_model()

        if FLAGS.enbl_multi_gpu:
            self.max_save_path = self.mpi_comm.bcast(self.max_save_path, root=0)
        if self.__is_primary_worker():
            # load the fine-tuned weights back into the pruner's own session
            with self.pruner.model.g.as_default():
                self.pruner.saver = tf.train.Saver()
                self.pruner.saver.restore(self.pruner.model.sess, self.max_save_path)

    def __save_best_pruned_model(self):
        """Save the pruner's model to `FLAGS.cp_best_path`."""
        best_path = tf.train.Saver().save(self.pruner.model.sess, FLAGS.cp_best_path)
        tf.logging.info('model saved best model to ' + best_path)

    def __save_in_progress_pruned_model(self):
        """Save the evaluation model to `FLAGS.cp_best_path` and remember the path."""
        self.max_save_path = self.saver_eval.save(self.sess_eval, FLAGS.cp_best_path)
        tf.logging.info('model saved best model to ' + self.max_save_path)

    def __save_model(self):
        """Save the current training model to `FLAGS.save_path`."""
        save_path = self.saver_train.save(self.sess_train, FLAGS.save_path, self.global_step)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Args:
        * is_train: whether to restore the training (True) or evaluation model
        """
        save_path = tf.train.latest_checkpoint(os.path.dirname(FLAGS.save_path))
        if is_train:
            self.saver_train.restore(self.sess_train, save_path)
        else:
            self.saver_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt):
        """Write summaries and log training statistics (primary worker only).

        Args:
        * summary: summary protocol buffer
        * log_rslt: results of `self.log_op`: [learning rate, loss, accuracies]
        """
        # early break for non-primary workers
        if not self.__is_primary_worker():
            return

        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, self.idx_iter)

        # display monitored statistics
        lrn_rate, loss, accuracy = log_rslt[0], log_rslt[1], log_rslt[2]
        speed = FLAGS.batch_size * FLAGS.summ_step / (timer() - self.time_prev)
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()
        tf.logging.info('iter #%d: lr = %e | loss = %e | speed = %.2f pics / sec'
                        % (self.idx_iter + 1, lrn_rate, loss, speed))
        for idx, key in enumerate(self.accuracy_keys):
            tf.logging.info('{} = {}'.format(key, accuracy[idx]))
        self.time_prev = timer()

    def __prune_and_finetune_uniform(self):
        """Prune with the uniform preserve ratio, then fine-tune the pruned model."""
        if self.__is_primary_worker():
            done = False
            self.pruner.extract_features()

            start = timer()
            while not done:
                _, _, done, _ = self.pruner.compress(FLAGS.cp_uniform_preserve_ratio)

            tf.logging.info('uniform channl pruning time cost: {}s'.format(timer() - start))
            self.pruner.save_model()

        if FLAGS.enbl_multi_gpu:
            self.mpi_comm.Barrier()

        self.__finetune_pruned_model(path=FLAGS.cp_channel_pruned_path)

    def __prune_and_finetune_list(self):
        """Prune with the ratios listed in `FLAGS.cp_prune_list_file`, then fine-tune.

        Raises:
        * IOError: if the list file cannot be read
        * ValueError: if the list file cannot be parsed as floats
        """
        try:
            # ndmin=1 keeps a single-value file iterable (it would otherwise be 0-d)
            ratio_list = list(np.loadtxt(FLAGS.cp_prune_list_file, delimiter=',', ndmin=1))
        except (IOError, ValueError):
            # np.loadtxt reports unreadable files as IOError and malformed
            # content as ValueError; both are logged and re-raised
            tf.logging.error(
                'The prune list file cannot be read or its format is not correct. \n'
                'Its content should be a float list delimited by a comma.')
            raise

        # ratios are popped from the right, so reverse to consume them in file order
        ratio_list.reverse()
        queue = deque(ratio_list)

        done = False
        while not done:
            done = self.__prune_list_layers(queue, [FLAGS.cp_list_group])

    def __prune_list_layers(self, queue, ps=None):
        """Prune the layers group by group.

        Args:
        * queue: deque of preserve ratios, consumed from the right
        * ps: non-empty list with the number of layers to prune in each group

        Returns:
        * the completion flag returned by the last group

        Raises:
        * ValueError: if <ps> is None or empty
        """
        if not ps:
            # the original failed here with a TypeError / UnboundLocalError
            raise ValueError('ps must be a non-empty list of group sizes')

        done = False
        for p in ps:
            done = self.__prune_n_layers(p, queue)

        return done

    def __prune_n_layers(self, n, queue):
        """Prune at most <n> layers, then fine-tune the pruned model.

        Args:
        * n: maximal number of layers to prune in this call
        * queue: deque of preserve ratios; a ratio of 1 is used once it is empty

        Returns:
        * whether the pruner reported that pruning is complete
        """
        done = False
        if self.__is_primary_worker():
            self.pruner.extract_features()
            i = 0
            while not done and i < n:
                ratio = queue.pop() if queue else 1
                _, _, done, _ = self.pruner.compress(ratio)
                i += 1

            self.pruner.save_model()

        if FLAGS.enbl_multi_gpu:
            self.mpi_comm.Barrier()
            done = self.mpi_comm.bcast(done, root=0)

        if done:
            self.__finetune_pruned_model(path=FLAGS.cp_channel_pruned_path, finetune=False)
        else:
            self.__finetune_pruned_model(path=FLAGS.cp_channel_pruned_path,
                                         finetune=FLAGS.cp_finetune)

        return done

    def __finetune_pruned_model(self, path=None, finetune=False):
        """Rebuild the evaluation & training graphs from <path> and train.

        Args:
        * path: checkpoint path of the pruned model
            (default: FLAGS.cp_channel_pruned_path)
        * finetune: forwarded to the graph builder and the training loop
        """
        if path is None:
            path = FLAGS.cp_channel_pruned_path
        start = timer()
        tf.logging.info('build pruned evaluating model')
        self.__build_pruned_evaluate_model(path)
        tf.logging.info('build pruned training model')
        self.__build_pruned_train_model(path, finetune=finetune)
        tf.logging.info('training pruned model')
        self.__train_pruned_model(finetune=finetune)
        tf.logging.info('fintuning time cost: {}s'.format(timer() - start))

    def __prune_and_finetune_auto(self):
        """Search the pruning ratios with the RL agent, then prune and fine-tune."""
        if self.__is_primary_worker():
            self.__prune_rl()
            self.pruner.initialize_state()

        if FLAGS.enbl_multi_gpu:
            self.mpi_comm.Barrier()
            self.bestinfo = self.mpi_comm.bcast(self.bestinfo, root=0)

        ratio_list = self.bestinfo[0]
        tf.logging.info('best split ratio is: {}'.format(ratio_list))
        # reverse a copy so that the strategy stored in self.bestinfo stays intact
        queue = deque(reversed(ratio_list))

        done = False
        while not done:
            done = self.__prune_list_layers(queue, [FLAGS.cp_list_group])

    @classmethod
    def __calc_reward(cls, accuracy, flops):
        """Calculate the reward of a roll-out.

        Args:
        * accuracy: accuracy of the pruned model
        * flops: FLOPs term of the pruned model

        Returns:
        * the reward, as a 1 x 1 array

        Raises:
        * ValueError: if `FLAGS.cp_reward_policy` is not 'accuracy' or 'flops'
        """
        if FLAGS.cp_reward_policy == 'accuracy':
            reward = accuracy * np.ones((1, 1))
        elif FLAGS.cp_reward_policy == 'flops':
            reward = -np.maximum(
                FLAGS.cp_noise_tolerance, (1 - accuracy)) * np.log(flops) * np.ones((1, 1))
        else:
            raise ValueError('unrecognized reward type: ' + FLAGS.cp_reward_policy)

        return reward

    def __prune_rl(self):  # pylint: disable=too-many-locals
        """Search the pruning strategy with reinforcement learning.

        Stores the best [strategy, accuracy, flops] triple in `self.bestinfo`.
        """
        tf.logging.info(
            'preserve lower bound: {}, preserve ratio: {}, preserve upper bound: {}'.format(
                self.lbound, FLAGS.cp_preserve_ratio, self.rbound))
        config = tf.ConfigProto()
        config.gpu_options.visible_device_list = str(0)  # pylint: disable=no-member
        buf_size = len(self.pruner.states) * FLAGS.cp_nb_rlouts_min
        nb_rlouts = FLAGS.cp_nb_rlouts
        self.agent = DdpgAgent(
            tf.Session(config=config),
            len(self.pruner.states.loc[0].tolist()),
            1,
            nb_rlouts,
            buf_size,
            self.lbound,
            self.rbound)
        self.agent.init()
        self.bestinfo = None
        reward_best = -np.inf  # np.NINF was removed in NumPy 2.0

        for idx_rlout in range(FLAGS.cp_nb_rlouts):
            # execute roll-outs to obtain pruning ratios
            self.agent.init_rlout()
            states_n_actions = []
            self.create_pruner()
            self.pruner.initialize_state()
            self.pruner.extract_features()
            state = np.array(self.pruner.currentStates.loc[0].tolist())[None, :]

            start = timer()
            while True:
                tf.logging.info('state is {}'.format(state))
                action = self.agent.sess.run(self.agent.actions_noisy,
                                             feed_dict={self.agent.states: state})
                tf.logging.info('RL choosed preserv ratio: {}'.format(action))
                state_next, acc_flops, done, real_action = self.pruner.compress(action)
                tf.logging.info('Actural preserv ratio: {}'.format(real_action))
                # store the ratio that was actually applied, not the proposed one
                states_n_actions += [(state, real_action * np.ones((1, 1)))]
                state = state_next[None, :]
                actor_loss, critic_loss, noise_std = self.agent.train()
                if done:
                    break
            tf.logging.info('roll-out #%d: a-loss = %.2e | c-loss = %.2e | noise std. = %.2e'
                            % (idx_rlout, actor_loss, critic_loss, noise_std))

            reward = self.__calc_reward(acc_flops[0], acc_flops[1])

            rewards = reward * np.ones(len(self.pruner.states))
            self.agent.finalize_rlout(rewards)

            # record transactions for RL training
            strategy = []
            for idx, (state, action) in enumerate(states_n_actions):
                strategy.append(action[0, 0])
                if idx != len(states_n_actions) - 1:
                    terminal = np.zeros((1, 1))
                    state_next = states_n_actions[idx + 1][0]
                else:
                    terminal = np.ones((1, 1))
                    state_next = np.zeros_like(state)
                self.agent.record(state, action, reward, terminal, state_next)

            # record the best combination of pruning ratios
            if reward_best < reward:
                tf.logging.info('best reward updated: %.4f -> %.4f' % (reward_best, reward))
                reward_best = reward
                self.bestinfo = [strategy, acc_flops[0], acc_flops[1]]
                tf.logging.info(
                    'The best pruned model occured with strategy: {}, '
                    'accuracy: {} and pruned ratio: {}'.format(
                        self.bestinfo[0], self.bestinfo[1], self.bestinfo[2]))

            tf.logging.info('automatic channl pruning time cost: {}s'.format(timer() - start))

    @classmethod
    def __is_primary_worker(cls):
        """Whether it is the primary worker."""
        return not FLAGS.enbl_multi_gpu or mgw.rank() == 0
class ChannelPrunedRmtLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Channel pruning learner - remastered.

    Workflow, as implemented by the methods below:
    1. build a training graph that holds both a full model and a channel-pruned
       model, plus an evaluation graph that holds the channel-pruned model only;
    2. restore the full model and copy its weights into the channel-pruned model;
    3. select channels layer by layer (LASSO for the channel mask, followed by a
       least-square fit of the remaining weights);
    4. fine-tune the channel-pruned model with mask-pruned gradients.
    """

    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """

        # class-independent initialization
        super(ChannelPrunedRmtLearner, self).__init__(sm_writer, model_helper)

        # define scopes for full & channel-pruned models
        self.model_scope_full = 'model'
        self.model_scope_prnd = 'pruned_model'

        # download the pre-trained model (only once per machine); every worker
        # then waits at the barrier until the files are available
        if self.is_primary_worker('local'):
            self.download_model()  # pre-trained model is required
        self.auto_barrier()
        tf.logging.info('model files: ' + ', '.join(os.listdir('./models')))

        # class-dependent initialization
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)
        self.__build_train()
        self.__build_eval()

    def train(self):
        """Train a model and periodically produce checkpoint files."""

        # restore the full model from pre-trained checkpoints
        save_path = tf.train.latest_checkpoint(os.path.dirname(self.save_path_full))
        self.saver_full.restore(self.sess_train, save_path)

        # initialization
        self.sess_train.run(self.init_op)
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # choose channels and evaluate the model before re-training
        time_prev = timer()
        self.__choose_channels()
        tf.logging.info('time (channel selection): %.2f (s)' % (timer() - time_prev))
        self.sess_train.run(self.mask_updt_op)
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # evaluate the model before fine-tuning; the barrier is reached by all
        # workers (not only the primary one) so that nobody is left waiting
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.evaluate()
        self.auto_barrier()

        # fine-tune the model with chosen channels only
        time_prev = timer()
        for idx_iter in range(self.nb_iters_train):
            # train the model
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.train_op)
            else:
                _, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    time_step = timer() - time_prev
                    self.__monitor_progress(summary, log_rslt, idx_iter, time_step)
                    time_prev = timer()

            # save the model at certain steps
            if self.is_primary_worker('global') and (idx_iter + 1) % FLAGS.save_step == 0:
                self.__save_model(is_train=True)
                self.evaluate()
            self.auto_barrier()

        # save the final model
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.__restore_model(is_train=False)
            self.__save_model(is_train=False)
            self.evaluate()

    def evaluate(self):
        """Restore a model from the latest checkpoint files and then evaluate it."""

        self.__restore_model(is_train=False)
        nb_iters = int(np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
        self.dump_n_eval(outputs=None, action='init')
        for idx_iter in range(nb_iters):
            if (idx_iter + 1) % 100 == 0:
                tf.logging.info('process the %d-th mini-batch for evaluation' % (idx_iter + 1))
            eval_rslts[idx_iter], outputs = self.sess_eval.run(
                [self.eval_op, self.outputs_eval])
            self.dump_n_eval(outputs=outputs, action='dump')
        self.dump_n_eval(outputs=None, action='eval')

        # report each metric averaged over all the evaluation mini-batches
        for idx, name in enumerate(self.eval_op_names):
            tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph.

        The graph contains the full model (regression target during channel
        selection) and the channel-pruned model (the one being fine-tuned).
        """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model
            with tf.variable_scope(self.model_scope_full):
                _ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images)
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.conv_krnl_var_names = [var.name for var in self.vars_prnd['conv_krnl']]
                self.global_step = tf.train.get_or_create_global_step()
                self.saver_prnd_train = tf.train.Saver(
                    self.vars_prnd['all'] + [self.global_step])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits_prnd, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)

                # calculate pruning ratios
                pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
                pr_conv_krnl = calc_prune_ratio(self.vars_prnd['conv_krnl'])
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_conv_krnl', pr_conv_krnl)

                # create masks and corresponding operations for channel pruning;
                # each mask is refreshed from its kernel: an input channel is
                # kept (1.0) iff the kernel slice of that channel is non-zero
                self.masks = []
                mask_updt_ops = []
                for var in self.vars_prnd['conv_krnl']:
                    tf.logging.info(
                        'creating a pruning mask for {} of size {}'.format(var.name, var.shape))
                    mask_name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
                    mask_shape = [1, 1, var.shape[2], 1]  # 1 x 1 x c_in x 1
                    mask = tf.get_variable(
                        mask_name, initializer=tf.ones(mask_shape), trainable=False)
                    var_norm = tf.reduce_sum(tf.square(var), axis=[0, 1, 3], keepdims=True)
                    self.masks += [mask]
                    mask_updt_ops += [mask.assign(tf.cast(var_norm > 0.0, tf.float32))]
                self.mask_updt_op = tf.group(mask_updt_ops)

                # build operations for channel selection
                self.__build_chn_select_ops()

                # optimizer & gradients
                optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads_origin = optimizer.compute_gradients(loss, self.vars_prnd['trainable'])
                grads_pruned = self.__calc_grads_pruned(grads_origin)
                update_ops = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)
                with tf.control_dependencies(update_ops):
                    self.train_op = optimizer.apply_gradients(
                        grads_pruned, global_step=self.global_step)

                # TF operations for initializing the channel-pruned model: copy the
                # full model's weights, then reset the step & optimizer slots
                init_ops = [
                    var_prnd.assign(var_full)
                    for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all'])
                ]
                init_ops += [self.global_step.initializer]  # initialize the global step
                init_ops += [tf.variables_initializer(optimizer_base.variables())]
                self.init_op = tf.group(init_ops)

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_conv_krnl] + list(metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_krn'] + list(metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)

    def __build_eval(self):
        """Build the evaluation graph (channel-pruned model only)."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(self.sess_eval, images)

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits = self.forward_eval(images)
                vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                global_step = tf.train.get_or_create_global_step()
                self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'] + [global_step])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits, vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # calculate pruning ratios
                pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
                pr_conv_krnl = calc_prune_ratio(vars_prnd['conv_krnl'])

                # TF operations for evaluation
                self.eval_op = [loss, pr_trainable, pr_conv_krnl] + list(metrics.values())
                self.eval_op_names = ['loss', 'pr_trn', 'pr_krn'] + list(metrics.keys())
                self.outputs_eval = logits

            # add input & output tensors to certain collections
            tf.add_to_collection('images_final', images)
            tf.add_to_collection('logits_final', logits)

    def __build_chn_select_ops(self):
        """Build channel selection operations for convolutional layers.

        Nothing is returned; the results are stored as attributes:
        * self.reg_losses: list of layer-wise regression losses (full vs. pruned outputs)
        * self.conv_info_list: list of dicts (one per Conv2D operation) with the
          kernels, input / output tensors, update operation, strides, and padding
        * self.meta_lasso / self.meta_lstsq: meta optimization problems
        * self.nb_conv_layers: number of matched Conv2D operations
        """

        # find the matching Conv2D operations in the full & channel-pruned models
        pattern = re.compile(r'/Conv2D$')
        conv_ops_full = get_ops_by_scope_n_pattern(self.model_scope_full, pattern)
        conv_ops_prnd = get_ops_by_scope_n_pattern(self.model_scope_prnd, pattern)

        # build layer-wise regression losses
        reg_losses = [
            tf.nn.l2_loss(conv_op_full.outputs[0] - conv_op_prnd.outputs[0])
            for conv_op_full, conv_op_prnd in zip(conv_ops_full, conv_ops_prnd)
        ]

        # build layer-wise sampling operations
        conv_info_list = []
        for idx_layer, (conv_op_full, conv_op_prnd) in enumerate(
                zip(conv_ops_full, conv_ops_prnd)):
            conv_krnl_full = self.vars_full['conv_krnl'][idx_layer]
            conv_krnl_prnd = self.vars_prnd['conv_krnl'][idx_layer]
            # the placeholder is used to overwrite the pruned kernel with the
            # values computed (in NumPy) by the channel selection routine
            conv_krnl_prnd_ph = tf.placeholder(
                tf.float32, shape=conv_krnl_prnd.shape, name='conv_krnl_prnd_ph_%d' % idx_layer)
            conv_info_list += [{
                'conv_krnl_full': conv_krnl_full,
                'conv_krnl_prnd': conv_krnl_prnd,
                'conv_krnl_prnd_ph': conv_krnl_prnd_ph,
                'update_op': conv_krnl_prnd.assign(conv_krnl_prnd_ph),
                'input_full': conv_op_full.inputs[0],
                'input_prnd': conv_op_prnd.inputs[0],
                'output_full': conv_op_full.outputs[0],
                'output_prnd': conv_op_prnd.outputs[0],
                'strides': conv_op_full.get_attr('strides'),
                'padding': conv_op_full.get_attr('padding').decode('utf-8'),
            }]

        # build meta LASSO/least-square optimization problems
        self.meta_lasso = self.__build_meta_lasso()
        self.meta_lstsq = self.__build_meta_lstsq()
        self.reg_losses = reg_losses
        self.conv_info_list = conv_info_list
        self.nb_conv_layers = len(self.reg_losses)

    def __build_meta_lasso(self):
        """Build a meta LASSO optimization problem.

        The problem is "meta" in the sense that its data (<X^T * X>, <X^T * y>,
        and the initial mask) is fed through placeholders when 'init_op' is
        run, so that one set of operations is re-used for every layer.

        Returns:
        * meta_lasso: dict of placeholders, variables, and TF operations
        """

        with tf.variable_scope('meta_lasso'):
            # create placeholders to customize the LASSO problem
            xt_x_ph = tf.placeholder(tf.float32, name='xt_x_ph')
            xt_y_ph = tf.placeholder(tf.float32, name='xt_y_ph')
            mask_ph = tf.placeholder(tf.float32, name='mask_ph')
            gamma = tf.placeholder(tf.float32, shape=[], name='gamma')

            # create variables (shapes are unknown until the placeholders are fed)
            xt_x = tf.get_variable(
                'xt_x', initializer=xt_x_ph, trainable=False, validate_shape=False)
            xt_y = tf.get_variable(
                'xt_y', initializer=xt_y_ph, trainable=False, validate_shape=False)
            mask = tf.get_variable(
                'mask', initializer=mask_ph, trainable=True, validate_shape=False)

            # TF operations
            def prox_mapping(x, thres):
                """Soft-thresholding: shrink <x> towards zero by <thres>."""
                return tf.where(
                    x > thres, x - thres,
                    tf.where(x < -thres, x + thres, tf.zeros_like(x)))

            # one gradient step on the quadratic term, followed by soft-thresholding
            mask_gd = mask - FLAGS.cpr_ista_lrn_rate * (tf.matmul(xt_x, mask) - xt_y)
            train_op = mask.assign(prox_mapping(mask_gd, gamma * FLAGS.cpr_ista_lrn_rate))
            init_op = tf.variables_initializer([xt_x, xt_y, mask])

        # pack placeholders, variables, and TF operations into dict
        meta_lasso = {
            'xt_x_ph': xt_x_ph,
            'xt_y_ph': xt_y_ph,
            'mask_ph': mask_ph,
            'gamma': gamma,
            'xt_x': xt_x,
            'xt_y': xt_y,
            'xy_y': xt_y,  # legacy (misspelled) key, kept for backward compatibility
            'mask': mask,
            'init_op': init_op,
            'train_op': train_op,
        }

        return meta_lasso

    def __build_meta_lstsq(self):
        """Build a meta least-square optimization problem.

        Returns:
        * meta_lstsq: dict of placeholders and the closed-form solution tensor
        """

        with tf.variable_scope('meta_lstsq'):
            # create placeholders to customize the least-square problem
            feat_mat_ph = tf.placeholder(tf.float32, name='feat_mat_ph')
            rspn_mat_ph = tf.placeholder(tf.float32, name='rspn_mat_ph')

            # compute the closed-form solution (third argument: L2 regularizer)
            wei_mat = tf.linalg.lstsq(feat_mat_ph, rspn_mat_ph, FLAGS.loss_w_dcy)

        # pack placeholders and variables into dict
        meta_lstsq = {
            'feat_mat_ph': feat_mat_ph,
            'rspn_mat_ph': rspn_mat_ph,
            'wei_mat': wei_mat,
        }

        return meta_lstsq

    def __calc_grads_pruned(self, grads_origin):
        """Calculate the mask-pruned gradients.

        Gradients of convolutional kernels are multiplied with the kernel's
        channel mask; all the other gradients are passed through unchanged.

        Args:
        * grads_origin: list of original gradients, as (gradient, variable) pairs

        Returns:
        * grads_pruned: list of mask-pruned gradients, as (gradient, variable) pairs
        """

        # one mask per convolutional kernel, in the same order
        masks_by_name = {
            var.name: mask for var, mask in zip(self.vars_prnd['conv_krnl'], self.masks)
        }

        grads_pruned = []
        for grad, var in grads_origin:
            mask = masks_by_name.get(var.name)
            if mask is None or grad is None:
                # not a convolutional kernel, or no gradient flows to this variable
                grads_pruned += [(grad, var)]
            else:
                grads_pruned += [(grad * mask, var)]

        return grads_pruned

    def __choose_channels(self):  # pylint: disable=too-many-locals
        """Choose channels for all convolutional layers.

        For each layer to be pruned, inputs & outputs are sampled over several
        mini-batches, a sparsity-constrained regression problem is solved, and
        the channel-pruned kernel is overwritten with the solution.
        """

        # obtain each layer's pruning ratio
        prune_ratios = [FLAGS.cpr_prune_ratio] * self.nb_conv_layers
        if prune_ratios and FLAGS.cpr_skip_frst_layer:
            prune_ratios[0] = 0.0
        if prune_ratios and FLAGS.cpr_skip_last_layer:
            prune_ratios[-1] = 0.0

        # select channels for all the convolutional layers
        nb_iters_smpl = int(math.ceil(float(FLAGS.cpr_nb_smpl_insts) / FLAGS.batch_size))
        for idx_layer, (prune_ratio, conv_info) in enumerate(
                zip(prune_ratios, self.conv_info_list)):
            # skip if no pruning is required
            if prune_ratio == 0.0:
                continue
            if self.is_primary_worker('global'):
                tf.logging.info('layer #%d: pr = %.2f (target)' % (idx_layer, prune_ratio))
                tf.logging.info('kernel shape = {}'.format(conv_info['conv_krnl_prnd'].shape))

            # extract the current layer's information
            conv_krnl_full = self.sess_train.run(conv_info['conv_krnl_full'])
            conv_krnl_prnd = self.sess_train.run(conv_info['conv_krnl_prnd'])
            fetches = [
                conv_info['input_full'],
                conv_info['input_prnd'],
                conv_info['output_full'],
                conv_info['output_prnd'],
            ]
            strides = conv_info['strides']
            padding = conv_info['padding']
            nb_chns_input = conv_krnl_prnd.shape[2]

            # sample inputs & outputs through multiple mini-batches
            inputs_list = [[] for _ in range(nb_chns_input)]
            outputs_list = []
            for _ in range(nb_iters_smpl):
                inputs_full, inputs_prnd, outputs_full, outputs_prnd = \
                    self.sess_train.run(fetches)
                inputs_smpl, outputs_smpl = self.__smpl_inputs_n_outputs(
                    conv_krnl_full, conv_krnl_prnd, inputs_full, inputs_prnd,
                    outputs_full, outputs_prnd, strides, padding)
                for idx_chn_input in range(nb_chns_input):
                    inputs_list[idx_chn_input] += [inputs_smpl[idx_chn_input]]
                outputs_list += [outputs_smpl]
            inputs_np_list = [np.vstack(x) for x in inputs_list]
            outputs_np = np.vstack(outputs_list)

            # choose channels via solving the sparsity-constrained regression problem
            conv_krnl_prnd = self.__solve_sparse_regression(
                inputs_np_list, outputs_np, conv_krnl_prnd, prune_ratio)
            self.sess_train.run(
                conv_info['update_op'],
                feed_dict={conv_info['conv_krnl_prnd_ph']: conv_krnl_prnd})

            # evaluate the channel pruned model
            if FLAGS.cpr_eval_per_layer:
                if self.is_primary_worker('global'):
                    self.__save_model(is_train=True)
                    self.evaluate()
                self.auto_barrier()

        # evaluate the final channel pruned model
        if not FLAGS.cpr_eval_per_layer:
            if self.is_primary_worker('global'):
                self.__save_model(is_train=True)
                self.evaluate()
            self.auto_barrier()

    def __smpl_inputs_n_outputs(self, conv_krnl_full, conv_krnl_prnd, inputs_full, inputs_prnd,
                                outputs_full, outputs_prnd, strides, padding):
        """Sample inputs & outputs of sub-regions from full feature maps.

        Random output positions are drawn; for each of them, the corresponding
        input patch (of the kernel's spatial size) and the output vector are
        collected. The patches are validated by re-computing the outputs from
        the patches and the kernels.

        Args:
        * conv_krnl_full: convolutional kernel of the full model (kh x kw x c_i x c_o)
        * conv_krnl_prnd: convolutional kernel of the channel-pruned model (same layout)
        * inputs_full: input feature maps of the full model (indexed as N x H x W x C)
        * inputs_prnd: input feature maps of the channel-pruned model
        * outputs_full: output feature maps of the full model (indexed as N x H x W x C)
        * outputs_prnd: output feature maps of the channel-pruned model
        * strides: Conv2D operation's strides (entries #1 / #2 are used for H / W)
        * padding: Conv2D operation's padding scheme ('VALID', or anything else for
            zero-padding)

        Returns:
        * inputs_smpl: list of sampled inputs from the channel-pruned model
            (one per input channel, each of size N' x (kh * kw))
        * outputs_smpl: sampled outputs from the full model (N' x c_o)
        """

        # obtain parameters
        bs = inputs_full.shape[0]
        kh, kw = conv_krnl_full.shape[0], conv_krnl_full.shape[1]
        ic = inputs_full.shape[3]
        oh, ow, oc = outputs_full.shape[1], outputs_full.shape[2], outputs_full.shape[3]
        if padding == 'VALID':
            ph, pw = 0, 0
        else:
            ph = int(math.ceil((kh - 1) / 2))
            pw = int(math.ceil((kw - 1) / 2))

        # perform zero-padding on input feature maps
        if ph == 0 and pw == 0:
            inputs_full_pad = inputs_full
            inputs_prnd_pad = inputs_prnd
        else:
            pad_width = ((0, ), (ph, ), (pw, ), (0, ))
            inputs_full_pad = np.pad(inputs_full, pad_width, 'constant')
            inputs_prnd_pad = np.pad(inputs_prnd, pad_width, 'constant')

        # sample inputs & outputs of sub-regions
        inputs_smpl_full_list = []
        inputs_smpl_prnd_list = []
        outputs_smpl_full_list = []
        outputs_smpl_prnd_list = []
        for _ in range(FLAGS.cpr_nb_smpl_crops):
            idx_oh = np.random.randint(oh)
            idx_ow = np.random.randint(ow)
            # input patch (in padded coordinates) that produces the chosen output position
            idx_ih_low = idx_oh * strides[1]
            idx_ih_hgh = idx_ih_low + kh
            idx_iw_low = idx_ow * strides[2]
            idx_iw_hgh = idx_iw_low + kw
            inputs_smpl_full_list += [
                inputs_full_pad[:, idx_ih_low:idx_ih_hgh, idx_iw_low:idx_iw_hgh, :]]
            inputs_smpl_prnd_list += [
                inputs_prnd_pad[:, idx_ih_low:idx_ih_hgh, idx_iw_low:idx_iw_hgh, :]]
            outputs_smpl_full_list += [np.reshape(outputs_full[:, idx_oh, idx_ow, :], [bs, -1])]
            outputs_smpl_prnd_list += [np.reshape(outputs_prnd[:, idx_oh, idx_ow, :], [bs, -1])]

        # concatenate samples into a single np.array
        inputs_smpl_full = np.concatenate(inputs_smpl_full_list, axis=0)
        inputs_smpl_prnd = np.concatenate(inputs_smpl_prnd_list, axis=0)
        outputs_smpl_full = np.vstack(outputs_smpl_full_list)
        outputs_smpl_prnd = np.vstack(outputs_smpl_prnd_list)

        # validate inputs & outputs: the sampled patches, multiplied with the
        # kernel, must reproduce the sampled outputs
        wei_mat_full = np.reshape(conv_krnl_full, [-1, oc])
        wei_mat_prnd = np.reshape(conv_krnl_prnd, [-1, oc])
        preds_smpl_full = np.matmul(
            np.reshape(inputs_smpl_full, [-1, kh * kw * ic]), wei_mat_full)
        preds_smpl_prnd = np.matmul(
            np.reshape(inputs_smpl_prnd, [-1, kh * kw * ic]), wei_mat_prnd)
        err_full = norm(outputs_smpl_full - preds_smpl_full)**2 / outputs_smpl_full.shape[0]
        err_prnd = norm(outputs_smpl_prnd - preds_smpl_prnd)**2 / outputs_smpl_prnd.shape[0]
        assert err_full < 1e-10, 'unable to recover output feature maps - full (%e)' % err_full
        assert err_prnd < 1e-10, 'unable to recover output feature maps - prnd (%e)' % err_prnd

        # concatenate sampled inputs & outputs arrays: inputs come from the
        # channel-pruned model, whereas outputs (regression targets) come from
        # the full model
        inputs_smpl = [
            np.reshape(x, [-1, kh * kw])
            for x in np.split(inputs_smpl_prnd, ic, axis=3)  # one per input channel
        ]
        outputs_smpl = outputs_smpl_full

        return inputs_smpl, outputs_smpl

    def __solve_sparse_regression(self, inputs_np_list, outputs_np, conv_krnl, prune_ratio):
        """Solve the sparsity-constrained regression problem.

        A LASSO problem is solved (via iterative soft-thresholding) to choose
        which input channels to keep, with <gamma> searched so that the number
        of remaining channels meets the target. Then, the weights of remaining
        channels are re-fitted via least-square regression.

        Args:
        * inputs_np_list: list of input feature maps (one per input channel, N x k^2)
        * outputs_np: output feature maps (N x c_o)
        * conv_krnl: initial convolutional kernel (k * k * c_i * c_o)
        * prune_ratio: pruning ratio

        Returns:
        * conv_krnl: updated convolutional kernel (k * k * c_i * c_o)
        """

        # obtain parameters
        bs = outputs_np.shape[0]
        kh, kw, ic, oc = conv_krnl.shape[0], conv_krnl.shape[1], \
            conv_krnl.shape[2], conv_krnl.shape[3]
        max_bisect_iters = 64  # enough to exhaust the resolution of a float interval

        # compute the feature matrix & response vector
        rspn_vec_np = np.reshape(outputs_np, [-1, 1])  # N' x 1 (N' = N * c_o)
        feat_mat_np = np.zeros((rspn_vec_np.shape[0], ic))  # N' x c_i
        for idx in range(ic):
            wei_mat = np.reshape(conv_krnl[:, :, idx, :], [kh * kw, oc])
            feat_mat_np[:, idx] = np.matmul(inputs_np_list[idx], wei_mat).ravel()

        # compute <X^T * X> & <X^T * y> in advance
        xt_x_np = np.matmul(feat_mat_np.T, feat_mat_np) / bs
        xt_y_np = np.matmul(feat_mat_np.T, rspn_vec_np) / bs
        mask_np_init = np.ones((ic, 1))

        # solve the LASSO problem
        def solve_lasso(x):
            """Run ISTA with <gamma> = x; return the mask and its number of non-zeros."""
            self.sess_train.run(self.meta_lasso['init_op'], feed_dict={
                self.meta_lasso['xt_x_ph']: xt_x_np,
                self.meta_lasso['xt_y_ph']: xt_y_np,
                self.meta_lasso['mask_ph']: mask_np_init,
            })
            for _ in range(FLAGS.cpr_ista_nb_iters):
                self.sess_train.run(
                    self.meta_lasso['train_op'], feed_dict={self.meta_lasso['gamma']: x})
            mask_np = self.sess_train.run(self.meta_lasso['mask'])
            nb_chns_nnz = np.count_nonzero(mask_np)
            tf.logging.info('x = %e -> nb_chns_nnz = %d' % (x, nb_chns_nnz))
            return mask_np, nb_chns_nnz

        # determine <gamma>'s upper bound: double it until the budget is met
        nb_chns_nnz_target = int(ic * (1.0 - prune_ratio))
        ubnd = 0.1
        is_doubled = False
        while True:
            mask_np, nb_chns_nnz = solve_lasso(ubnd)
            if nb_chns_nnz <= nb_chns_nnz_target:
                break
            ubnd *= 2.0
            is_doubled = True

        # <ubnd / 2> is a valid lower bound only if it has actually been probed
        # (and left too many channels); otherwise, the solution may lie anywhere
        # below the initial guess, so the search must start from zero
        lbnd = ubnd / 2.0 if is_doubled else 0.0
        val = ubnd
        mask_np_ubnd = mask_np  # tightest mask known to satisfy the budget

        # determine <gamma> via binary search (bounded, so that it always terminates)
        for _ in range(max_bisect_iters):
            if nb_chns_nnz == nb_chns_nnz_target:
                break
            val = (lbnd + ubnd) / 2.0
            mask_np, nb_chns_nnz = solve_lasso(val)
            if nb_chns_nnz < nb_chns_nnz_target:
                ubnd = val
                mask_np_ubnd = mask_np
            elif nb_chns_nnz > nb_chns_nnz_target:
                lbnd = val
        else:
            if nb_chns_nnz != nb_chns_nnz_target:
                # no <gamma> yields exactly the target (e.g. several channels
                # vanish at once); fall back to the mask within the budget
                val = ubnd
                mask_np = mask_np_ubnd
                tf.logging.warning(
                    'binary search failed to reach %d channels; using %d channels instead'
                    % (nb_chns_nnz_target, np.count_nonzero(mask_np)))
        tf.logging.info('gamma-final: %e' % val)

        # construct a least-square regression problem (remaining channels only)
        rspn_mat_np = outputs_np
        bnry_vec_np = (mask_np > 0.0)
        inputs_np_list_msk = [bnry_vec_np[idx] * inputs_np_list[idx] for idx in range(ic)]
        feat_mat_np = np.reshape(
            np.concatenate([np.expand_dims(x, axis=-1) for x in inputs_np_list_msk], axis=-1),
            [bs, -1])
        wei_mat_np = self.sess_train.run(self.meta_lstsq['wei_mat'], feed_dict={
            self.meta_lstsq['feat_mat_ph']: feat_mat_np,
            self.meta_lstsq['rspn_mat_ph']: rspn_mat_np,
        })

        # pruned channels are explicitly zeroed-out in the updated kernel
        conv_krnl = np.reshape(wei_mat_np, conv_krnl.shape) \
            * np.reshape(bnry_vec_np, [1, 1, -1, 1])

        return conv_krnl

    def __save_model(self, is_train):
        """Save the current model for training or evaluation.

        Args:
        * is_train: whether to save a model for training
        """

        if is_train:
            save_path = self.saver_prnd_train.save(
                self.sess_train, FLAGS.cpr_save_path, self.global_step)
        else:
            save_path = self.saver_prnd_eval.save(self.sess_eval, FLAGS.cpr_save_path_eval)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Both the training & evaluation models are restored from the directory
        of <FLAGS.cpr_save_path>.

        Args:
        * is_train: whether to restore a model for training
        """

        save_path = tf.train.latest_checkpoint(os.path.dirname(FLAGS.cpr_save_path))
        if is_train:
            self.saver_prnd_train.restore(self.sess_train, save_path)
        else:
            self.saver_prnd_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Monitor the training progress.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """

        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # compute the training speed (summed over all the workers)
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # display monitored statistics
        log_str = ' | '.join(
            ['%s = %.4e' % (name, value) for name, value in zip(self.log_op_names, log_rslt)])
        tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))
class ChannelPrunedRmtLearner(AbstractLearner): # pylint: disable=too-many-instance-attributes """Channel pruning learner - remastered.""" def __init__(self, sm_writer, model_helper): """Constructor function. Args: * sm_writer: TensorFlow's summary writer * model_helper: model helper with definitions of model & dataset """ # class-independent initialization super(ChannelPrunedRmtLearner, self).__init__(sm_writer, model_helper) # define scopes for full & channel-pruned models self.model_scope_full = 'model' self.model_scope_prnd = 'pruned_model' # download the pre-trained model if self.is_primary_worker('local'): self.download_model() # pre-trained model is required self.auto_barrier() tf.logging.info('model files: ' + ', '.join(os.listdir('./models'))) # class-dependent initialization if FLAGS.enbl_dst: self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm) self.__build_train() self.__build_eval() # build the channel pruning graph self.__build_prune() def train(self): """Train a model and periodically produce checkpoint files.""" # choose channels or directly load a pre-pruned model as warm-start if not FLAGS.cpr_warm_start: time_prev = timer() self.__choose_channels() tf.logging.info('time (channel selection): %.2f (s)' % (timer() - time_prev)) save_path = tf.train.latest_checkpoint( os.path.dirname(FLAGS.cpr_save_path_ws)) self.saver_prnd_train.restore(self.sess_train, save_path) tf.logging.info('model restored from ' + save_path) # initialize all the remaining variables and then broadcast self.sess_train.run(self.init_op) if FLAGS.enbl_multi_gpu: self.sess_train.run(self.bcast_op) # evaluate the model before fine-tuning if self.is_primary_worker('global'): self.__save_model(is_train=True) self.evaluate() self.auto_barrier() # fine-tune the model with chosen channels only time_prev = timer() for idx_iter in range(self.nb_iters_train): # train the model if (idx_iter + 1) % FLAGS.summ_step != 0: self.sess_train.run(self.train_op) else: __, 
summary, log_rslt = self.sess_train.run( [self.train_op, self.summary_op, self.log_op]) if self.is_primary_worker('global'): time_step = timer() - time_prev self.__monitor_progress(summary, log_rslt, idx_iter, time_step) time_prev = timer() # save the model at certain steps if self.is_primary_worker('global') and (idx_iter + 1) % FLAGS.save_step == 0: self.__save_model(is_train=True) self.evaluate() self.auto_barrier() # save the final model if self.is_primary_worker('global'): self.__save_model(is_train=True) self.__restore_model(is_train=False) self.__save_model(is_train=False) self.evaluate() def evaluate(self): """Restore a model from the latest checkpoint files and then evaluate it.""" self.__restore_model(is_train=False) nb_iters = int( np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval)) eval_rslts = np.zeros((nb_iters, len(self.eval_op))) self.dump_n_eval(outputs=None, action='init') for idx_iter in range(nb_iters): if (idx_iter + 1) % 100 == 0: tf.logging.info('process the %d-th mini-batch for evaluation' % (idx_iter + 1)) eval_rslts[idx_iter], outputs = self.sess_eval.run( [self.eval_op, self.outputs_eval]) self.dump_n_eval(outputs=outputs, action='dump') self.dump_n_eval(outputs=None, action='eval') for idx, name in enumerate(self.eval_op_names): tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx]))) def __build_train(self): # pylint: disable=too-many-locals,too-many-statements """Build the training graph.""" with tf.Graph().as_default(): # create a TF session for the current graph config = tf.ConfigProto() config.gpu_options.allow_growth = True # pylint: disable=no-member config.gpu_options.visible_device_list = \ str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0) # pylint: disable=no-member sess = tf.Session(config=config) # data input pipeline with tf.variable_scope(self.data_scope): iterator = self.build_dataset_train() images, labels = iterator.get_next() # model definition - distilled model if FLAGS.enbl_dst: logits_dst = 
self.helper_dst.calc_logits(sess, images) # model definition - channel-pruned model with tf.variable_scope(self.model_scope_prnd): logits_prnd = self.forward_train(images) self.vars_prnd = get_vars_by_scope(self.model_scope_prnd) self.global_step = tf.train.get_or_create_global_step() self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'] + [self.global_step]) # loss & extra evaluation metrics loss, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable']) if FLAGS.enbl_dst: loss += self.helper_dst.calc_loss(logits_prnd, logits_dst) tf.summary.scalar('loss', loss) for key, value in metrics.items(): tf.summary.scalar(key, value) # learning rate schedule lrn_rate, self.nb_iters_train = self.setup_lrn_rate( self.global_step) # calculate pruning ratios pr_trainable = calc_prune_ratio(self.vars_prnd['trainable']) pr_conv_krnl = calc_prune_ratio(self.vars_prnd['conv_krnl']) tf.summary.scalar('pr_trainable', pr_trainable) tf.summary.scalar('pr_conv_krnl', pr_conv_krnl) # create masks and corresponding operations for channel pruning self.masks = [] for idx, var in enumerate(self.vars_prnd['conv_krnl']): tf.logging.info( 'creating a pruning mask for {} of size {}'.format( var.name, var.shape)) mask_name = '/'.join(var.name.split('/')[1:]).replace( ':0', '_mask') var_norm = tf.reduce_sum(tf.square(var), axis=[0, 1, 3], keepdims=True) mask_init = tf.cast(var_norm > 0.0, tf.float32) mask = tf.get_variable(mask_name, initializer=mask_init, trainable=False) self.masks += [mask] # optimizer & gradients optimizer_base = tf.train.MomentumOptimizer( lrn_rate, FLAGS.momentum) if not FLAGS.enbl_multi_gpu: optimizer = optimizer_base else: optimizer = mgw.DistributedOptimizer(optimizer_base) grads_origin = optimizer.compute_gradients( loss, self.vars_prnd['trainable']) grads_pruned = self.__calc_grads_pruned(grads_origin) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd) with tf.control_dependencies(update_ops): self.train_op = 
optimizer.apply_gradients( grads_pruned, global_step=self.global_step) # TF operations for logging & summarizing self.sess_train = sess self.summary_op = tf.summary.merge_all() self.init_op = tf.group( tf.variables_initializer([self.global_step] + self.masks + optimizer_base.variables())) self.log_op = [lrn_rate, loss, pr_trainable, pr_conv_krnl] + list( metrics.values()) self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_krn'] + list( metrics.keys()) if FLAGS.enbl_multi_gpu: self.bcast_op = mgw.broadcast_global_variables(0) def __build_eval(self): """Build the evaluation graph.""" with tf.Graph().as_default(): # create a TF session for the current graph config = tf.ConfigProto() config.gpu_options.allow_growth = True # pylint: disable=no-member config.gpu_options.visible_device_list = \ str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0) # pylint: disable=no-member self.sess_eval = tf.Session(config=config) # data input pipeline with tf.variable_scope(self.data_scope): iterator = self.build_dataset_eval() images, labels = iterator.get_next() # model definition - distilled model if FLAGS.enbl_dst: logits_dst = self.helper_dst.calc_logits( self.sess_eval, images) # model definition - channel-pruned model with tf.variable_scope(self.model_scope_prnd): logits = self.forward_eval(images) vars_prnd = get_vars_by_scope(self.model_scope_prnd) global_step = tf.train.get_or_create_global_step() self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'] + [global_step]) # loss & extra evaluation metrics loss, metrics = self.calc_loss(labels, logits, vars_prnd['trainable']) if FLAGS.enbl_dst: loss += self.helper_dst.calc_loss(logits, logits_dst) # calculate pruning ratios pr_trainable = calc_prune_ratio(vars_prnd['trainable']) pr_conv_krnl = calc_prune_ratio(vars_prnd['conv_krnl']) # TF operations for evaluation self.eval_op = [loss, pr_trainable, pr_conv_krnl] + list( metrics.values()) self.eval_op_names = ['loss', 'pr_trn', 'pr_krn'] + list( metrics.keys()) self.outputs_eval = 
logits # add input & output tensors to certain collections tf.add_to_collection('images_final', images) tf.add_to_collection('logits_final', logits) def __build_prune(self): """Build the channel pruning graph.""" with tf.Graph().as_default(): # create a TF session for the current graph config = tf.ConfigProto() config.gpu_options.allow_growth = True # pylint: disable=no-member config.gpu_options.visible_device_list = \ str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0) # pylint: disable=no-member sess = tf.Session(config=config) # data input pipeline with tf.variable_scope(self.data_scope): iterator = self.build_dataset_train() images, labels = iterator.get_next() if not isinstance(images, dict): images_ph = tf.placeholder(tf.float32, shape=images.shape, name='images_ph') else: images_ph = {} for key, value in images.items(): images_ph[key] = tf.placeholder(value.dtype, shape=value.shape, name=(key + '_ph')) # restore a pre-trained model as full model with tf.variable_scope(self.model_scope_full): __ = self.forward_train(images_ph) vars_full = get_vars_by_scope(self.model_scope_full) saver_full = tf.train.Saver(vars_full['all']) saver_full.restore( sess, tf.train.latest_checkpoint(os.path.dirname( FLAGS.save_path))) # restore a pre-trained model as channel-pruned model with tf.variable_scope(self.model_scope_prnd): logits_prnd = self.forward_train(images_ph) vars_prnd = get_vars_by_scope(self.model_scope_prnd) global_step = tf.train.get_or_create_global_step() saver_prnd = tf.train.Saver(vars_prnd['all'] + [global_step]) # loss & extra evaluation metrics loss, metrics = self.calc_loss(labels, logits_prnd, vars_prnd['trainable']) # calculate pruning ratios pr_trainable = calc_prune_ratio(vars_prnd['trainable']) pr_conv_krnl = calc_prune_ratio(vars_prnd['conv_krnl']) # use full model's weights to initialize channel-pruned model init_ops = [global_step.initializer] for var_full, var_prnd in zip(vars_full['all'], vars_prnd['all']): init_ops += 
[var_prnd.assign(var_full)] self.init_op_prune = tf.group(init_ops) # build a list of Conv2D operation's information self.conv_info_list = self.__build_conv_info_list( vars_prnd['conv_krnl']) # build meta LASSO/least-square optimization problems self.meta_lasso = self.__build_meta_lasso() self.meta_lstsq = self.__build_meta_lstsq() # TF operations for logging & summarizing self.sess_prune = sess self.images_prune = images self.images_prune_ph = images_ph self.saver_prune = saver_prnd self.pr_trn_prune = pr_trainable self.pr_krn_prune = pr_conv_krnl def __build_conv_info_list(self, conv_krnls_prnd): """Build a list of Conv2D operation's information. Args: * conv_krnls_prnd: list of convolutional kernels in the channel-pruned model Returns: * conv_info_list: list of Conv2D operation's information """ # find all the Conv2D operations pattern = re.compile(r'/Conv2D$') conv_ops_full = get_ops_by_scope_n_pattern(self.model_scope_full, pattern) conv_ops_prnd = get_ops_by_scope_n_pattern(self.model_scope_prnd, pattern) # build a list of Conv2D operation's information conv_info_list = [] for idx_layer, (conv_op_full, conv_op_prnd) in enumerate( zip(conv_ops_full, conv_ops_prnd)): conv_krnl_prnd = conv_krnls_prnd[idx_layer] conv_krnl_prnd_ph = tf.placeholder(tf.float32, shape=conv_krnl_prnd.shape, name='conv_krnl_prnd_ph_%d' % idx_layer) conv_info_list += [{ 'conv_krnl_full': conv_op_full.inputs[1], 'conv_krnl_prnd': conv_op_prnd.inputs[1], 'conv_krnl_prnd_ph': conv_krnl_prnd_ph, 'update_op': conv_krnl_prnd.assign(conv_krnl_prnd_ph), 'input_full': conv_op_full.inputs[0], 'input_prnd': conv_op_prnd.inputs[0], 'output_full': conv_op_full.outputs[0], 'output_prnd': conv_op_prnd.outputs[0], 'strides': conv_op_full.get_attr('strides'), 'padding': conv_op_full.get_attr('padding').decode('utf-8'), }] return conv_info_list def __build_meta_lasso(self): """Build a meta LASSO optimization problem.""" # build a meta LASSO optimization problem with tf.variable_scope('meta_lasso'): # 
create placeholders to customize the LASSO problem xt_x_ph = tf.placeholder(tf.float32, name='xt_x_ph') xt_y_ph = tf.placeholder(tf.float32, name='xt_y_ph') mask_ph = tf.placeholder(tf.float32, name='mask_ph') gamma = tf.placeholder(tf.float32, shape=[], name='gamma') # create variables xt_x = tf.get_variable('xt_x', initializer=xt_x_ph, trainable=False, validate_shape=False) xt_y = tf.get_variable('xt_y', initializer=xt_y_ph, trainable=False, validate_shape=False) mask = tf.get_variable('mask', initializer=mask_ph, trainable=True, validate_shape=False) # TF operations def prox_mapping(x, thres): return tf.where( x > thres, x - thres, tf.where(x < -thres, x + thres, tf.zeros_like(x))) mask_gd = mask - FLAGS.cpr_ista_lrn_rate * (tf.matmul(xt_x, mask) - xt_y) train_op = mask.assign( prox_mapping(mask_gd, gamma * FLAGS.cpr_ista_lrn_rate)) init_op = tf.variables_initializer([xt_x, xt_y, mask]) # pack placeholders, variables, and TF operations into dict meta_lasso = { 'xt_x_ph': xt_x_ph, 'xt_y_ph': xt_y_ph, 'mask_ph': mask_ph, 'gamma': gamma, 'xt_x': xt_x, 'xy_y': xt_y, 'mask': mask, 'init_op': init_op, 'train_op': train_op, } return meta_lasso def __build_meta_lstsq(self): """Build a meta least-square optimization problem.""" # build a meta least-square optimization problem beta1 = 0.9 beta2 = 0.999 epsilon = 1e-8 with tf.variable_scope('meta_lstsq'): # create placeholders to customize the least-square problem x_mat_ph = tf.placeholder(tf.float32, name='x_mat_ph') y_mat_ph = tf.placeholder(tf.float32, name='y_mat_ph') w_mat_ph = tf.placeholder(tf.float32, name='w_mat_ph') gacc1_ph = tf.placeholder(tf.float32, name='gacc1_ph') gacc2_ph = tf.placeholder(tf.float32, name='gacc2_ph') # create variables x_mat = tf.get_variable('x_mat', initializer=x_mat_ph, validate_shape=False) y_mat = tf.get_variable('y_mat', initializer=y_mat_ph, validate_shape=False) w_mat = tf.get_variable('w_mat', initializer=w_mat_ph, validate_shape=False) gacc1 = tf.get_variable('gacc1', 
initializer=gacc1_ph, validate_shape=False) gacc2 = tf.get_variable('gacc2', initializer=gacc2_ph, validate_shape=False) train_step = tf.get_variable('train_step', shape=[], initializer=tf.zeros_initializer) # TF operations nb_smpls = tf.cast(tf.shape(x_mat)[0], tf.float32) loss_reg = tf.nn.l2_loss(tf.matmul(x_mat, w_mat) - y_mat) / nb_smpls loss_dcy = FLAGS.loss_w_dcy * tf.nn.l2_loss(w_mat) grad = tf.matmul(tf.transpose(x_mat), tf.matmul(x_mat, w_mat) - y_mat) / nb_smpls + FLAGS.loss_w_dcy * w_mat update_ops = [ gacc1.assign(beta1 * gacc1 + (1.0 - beta1) * grad), gacc2.assign(beta2 * gacc2 + (1.0 - beta2) * grad**2), train_step.assign_add(tf.ones([])) ] with tf.control_dependencies(update_ops): lrn_rate = FLAGS.cpr_lstsq_lrn_rate \ * tf.sqrt(1.0 - tf.pow(beta2, train_step)) / (1.0 - tf.pow(beta1, train_step)) train_op = w_mat.assign_add(-lrn_rate * gacc1 / (tf.sqrt(gacc2) + epsilon)) init_op = tf.variables_initializer( [x_mat, y_mat, w_mat, gacc1, gacc2, train_step]) # pack placeholders and variables into dict meta_lstsq = { 'x_mat_ph': x_mat_ph, 'y_mat_ph': y_mat_ph, 'w_mat_ph': w_mat_ph, 'gacc1_ph': gacc1_ph, 'gacc2_ph': gacc2_ph, 'w_mat': w_mat, 'loss_reg': loss_reg, 'loss_dcy': loss_dcy, 'init_op': init_op, 'train_op': train_op, } return meta_lstsq def __calc_grads_pruned(self, grads_origin): """Calculate the mask-pruned gradients. 
Args: * grads_origin: list of original gradients Returns: * grads_pruned: list of mask-pruned gradients """ grads_pruned = [] conv_krnl_names = [var.name for var in self.vars_prnd['conv_krnl']] for grad in grads_origin: if grad[1].name not in conv_krnl_names: grads_pruned += [grad] else: idx_mask = conv_krnl_names.index(grad[1].name) grads_pruned += [(grad[0] * self.masks[idx_mask], grad[1])] return grads_pruned def __choose_channels(self): # pylint: disable=too-many-locals """Choose channels for all convolutional layers.""" # configure each layer's pruning ratio nb_layers = len(self.conv_info_list) prune_ratios = [FLAGS.cpr_prune_ratio] * nb_layers if FLAGS.cpr_skip_frst_layer: prune_ratios[0] = 0.0 if FLAGS.cpr_skip_last_layer: prune_ratios[-1] = 0.0 # skip channel pruning at certain layers skip_names = FLAGS.cpr_skip_op_names.split( ',') if FLAGS.cpr_skip_op_names is not None else [] for idx_layer in range(nb_layers): # if self.conv_info_list[idx_layer]['input_full'].shape[2] == 8: # prune_ratios[idx_layer] = 0.0 conv_krnl_prnd_name = self.conv_info_list[idx_layer][ 'conv_krnl_prnd'].name for skip_name in skip_names: if skip_name in conv_krnl_prnd_name: prune_ratios[idx_layer] = 0.0 tf.logging.info('skip %s since no pruning is required' % conv_krnl_prnd_name) break # cache multiple mini-batches of images for channel selection def __build_feed_dict(images_np): if not isinstance(self.images_prune, dict): feed_dict = {self.images_prune_ph: images_np} else: feed_dict = {} for key in self.images_prune: feed_dict[self.images_prune_ph[key]] = images_np[key] return feed_dict nb_mbtcs = int(math.ceil(FLAGS.cpr_nb_smpls / FLAGS.batch_size)) images_cached = [] for __ in range(nb_mbtcs): images_cached += [self.sess_prune.run(self.images_prune)] # select channels for all the convolutional layers self.sess_prune.run(self.init_op_prune) for idx_layer in range(nb_layers): # display the layer information prune_ratio = prune_ratios[idx_layer] conv_info = 
self.conv_info_list[idx_layer] if self.is_primary_worker('global'): tf.logging.info('layer #%d: pr = %.2f (target)' % (idx_layer, prune_ratio)) tf.logging.info('kernel name = {}'.format( conv_info['conv_krnl_prnd'].name)) tf.logging.info('kernel shape = {}'.format( conv_info['conv_krnl_prnd'].shape)) # extract the current layer's information conv_krnl_full = self.sess_prune.run(conv_info['conv_krnl_full']) conv_krnl_prnd = self.sess_prune.run(conv_info['conv_krnl_prnd']) conv_krnl_prnd_ph = conv_info['conv_krnl_prnd_ph'] update_op = conv_info['update_op'] input_full_tf = conv_info['input_full'] input_prnd_tf = conv_info['input_prnd'] output_full_tf = conv_info['output_full'] output_prnd_tf = conv_info['output_prnd'] strides = conv_info['strides'] padding = conv_info['padding'] nb_chns_input = conv_krnl_prnd.shape[2] # sample inputs & outputs through multiple mini-batches tf.logging.info( 'sampling inputs & outputs through multiple mini-batches') time_beg = timer() nb_insts = 0 # number of sampled instances (for regression) collected so far nb_insts_min = FLAGS.cpr_nb_crops_per_smpl * FLAGS.cpr_nb_smpls # minimal requirement inputs_list = [[] for __ in range(nb_chns_input)] outputs_list = [] for idx_mbtc in range(nb_mbtcs): inputs_full, inputs_prnd, outputs_full, outputs_prnd = \ self.sess_prune.run([input_full_tf, input_prnd_tf, output_full_tf, output_prnd_tf], feed_dict=__build_feed_dict(images_cached[idx_mbtc])) inputs_smpl, outputs_smpl = self.__smpl_inputs_n_outputs( conv_krnl_full, conv_krnl_prnd, inputs_full, inputs_prnd, outputs_full, outputs_prnd, strides, padding) nb_insts += outputs_smpl.shape[0] for idx_chn_input in range(nb_chns_input): inputs_list[idx_chn_input] += [inputs_smpl[idx_chn_input]] outputs_list += [outputs_smpl] if nb_insts > nb_insts_min: break idxs_inst = np.random.choice(nb_insts, size=(nb_insts_min), replace=False) inputs_np_list = [np.vstack(x)[idxs_inst] for x in inputs_list] outputs_np = np.vstack(outputs_list)[idxs_inst] 
tf.logging.info('time elapsed (sampling): %.4f (s)' % (timer() - time_beg)) # choose channels via solving the sparsity-constrained regression problem tf.logging.info( 'choosing channels via solving the sparsity-constrained regression problem' ) time_beg = timer() conv_krnl_prnd = self.__solve_sparse_regression( inputs_np_list, outputs_np, conv_krnl_prnd, prune_ratio) self.sess_prune.run(update_op, feed_dict={conv_krnl_prnd_ph: conv_krnl_prnd}) tf.logging.info('time elapsed (selection): %.4f (s)' % (timer() - time_beg)) # compute the overall pruning ratios pr_trn, pr_krn = self.sess_prune.run( [self.pr_trn_prune, self.pr_krn_prune]) tf.logging.info('pruning ratios: %e (trn) / %e (krn)' % (pr_trn, pr_krn)) # save the temporary model containing channel pruned weights if self.is_primary_worker('global'): save_path = self.saver_prune.save(self.sess_prune, FLAGS.cpr_save_path_ws) tf.logging.info('model saved to ' + save_path) self.auto_barrier() def __smpl_inputs_n_outputs(self, conv_krnl_full, conv_krnl_prnd, inputs_full, inputs_prnd, outputs_full, outputs_prnd, strides, padding): """Sample inputs & outputs of sub-regions from full feature maps. 
Args: Returns: """ # obtain parameters bs = inputs_full.shape[0] kh, kw = conv_krnl_full.shape[0], conv_krnl_full.shape[1] ih, iw, ic = inputs_full.shape[1], inputs_full.shape[ 2], inputs_full.shape[3] oh, ow, oc = outputs_full.shape[1], outputs_full.shape[ 2], outputs_full.shape[3] sh, sw = strides[1], strides[2] if padding == 'VALID': pt, pb, pl, pr = 0, 0, 0, 0 # padding - top / bottom / left / right else: # ref link: https://www.tensorflow.org/api_guides/python/nn#Convolution ph = max(kh - (sh if ih % sh == 0 else ih % sh), 0) pw = max(kw - (sw if iw % sw == 0 else iw % sw), 0) pt, pb = ph // 2, ph % 2 pl, pr = pw // 2, pw % 2 # sample inputs & outputs of sub-regions inputs_smpl_full_list = [] inputs_smpl_prnd_list = [] outputs_smpl_full_list = [] outputs_smpl_prnd_list = [] for idx_iter in range(FLAGS.cpr_nb_crops_per_smpl): idx_oh = np.random.randint(oh) idx_ow = np.random.randint(ow) idx_ih_low = idx_oh * strides[ 1] - pt # uncropped indices of input feature maps idx_ih_hgh = idx_ih_low + kh idx_iw_low = idx_ow * strides[2] - pl idx_iw_hgh = idx_iw_low + kw idx_sh_low = max(-idx_ih_low, 0) # cropped indices of sampled feature maps idx_sh_hgh = kh - max(idx_ih_hgh - ih, 0) idx_sw_low = max(-idx_iw_low, 0) idx_sw_hgh = kw - max(idx_iw_hgh - iw, 0) idx_ih_low = max(idx_ih_low, 0) # cropped indices of input feature maps idx_ih_hgh = min(idx_ih_hgh, ih) idx_iw_low = max(idx_iw_low, 0) idx_iw_hgh = min(idx_iw_hgh, iw) inputs_smpl_full = np.zeros((bs, kh, kw, ic)) inputs_smpl_prnd = np.zeros((bs, kh, kw, ic)) inputs_smpl_full[:, idx_sh_low:idx_sh_hgh, idx_sw_low:idx_sw_hgh, :] = \ inputs_full[:, idx_ih_low:idx_ih_hgh, idx_iw_low:idx_iw_hgh, :] inputs_smpl_prnd[:, idx_sh_low:idx_sh_hgh, idx_sw_low:idx_sw_hgh, :] = \ inputs_prnd[:, idx_ih_low:idx_ih_hgh, idx_iw_low:idx_iw_hgh, :] inputs_smpl_full_list += [inputs_smpl_full] inputs_smpl_prnd_list += [inputs_smpl_prnd] outputs_smpl_full_list += [ np.reshape(outputs_full[:, idx_oh, idx_ow, :], [bs, -1]) ] 
outputs_smpl_prnd_list += [ np.reshape(outputs_prnd[:, idx_oh, idx_ow, :], [bs, -1]) ] # concatenate samples into a single np.array inputs_smpl_full = np.concatenate(inputs_smpl_full_list, axis=0) inputs_smpl_prnd = np.concatenate(inputs_smpl_prnd_list, axis=0) outputs_smpl_full = np.vstack(outputs_smpl_full_list) outputs_smpl_prnd = np.vstack(outputs_smpl_prnd_list) # concatenate sampled inputs & outputs arrays inputs_smpl = [ np.reshape(x, [-1, kh * kw]) for x in np.split(inputs_smpl_prnd, ic, axis=3) ] outputs_smpl = outputs_smpl_full # validate inputs & outputs wei_mat_full = np.reshape(conv_krnl_full, [-1, oc]) wei_mat_prnd = np.reshape(conv_krnl_prnd, [-1, oc]) preds_smpl_full = np.matmul( np.reshape(inputs_smpl_full, [-1, kh * kw * ic]), wei_mat_full) preds_smpl_prnd = np.matmul( np.reshape(inputs_smpl_prnd, [-1, kh * kw * ic]), wei_mat_prnd) err_full = norm(outputs_smpl_full - preds_smpl_full)**2 / outputs_smpl_full.size err_prnd = norm(outputs_smpl_prnd - preds_smpl_prnd)**2 / outputs_smpl_prnd.size assert err_full < 1e-6, 'unable to recover output feature maps - full (%e)' % err_full assert err_prnd < 1e-6, 'unable to recover output feature maps - prnd (%e)' % err_prnd return inputs_smpl, outputs_smpl def __solve_sparse_regression(self, inputs_np_list, outputs_np, conv_krnl, prune_ratio): """Solve the sparsity-constrained regression problem. Args: * inputs_np_list: list of input feature maps (one per input channel, N x k^2) * outputs_np: output feature maps (N x c_o) * conv_krnl: initial convolutional kernel (k * k * c_i * c_o) * prune_ratio: pruning ratio Returns: * conv_krnl: updated convolutional kernel (k * k * c_i * c_o) """ # obtain parameters bs = outputs_np.shape[0] kh, kw, ic, oc = conv_krnl.shape[0], conv_krnl.shape[ 1], conv_krnl.shape[2], conv_krnl.shape[3] nb_chns_nnz_target = int(ic * (1.0 - prune_ratio)) tf.logging.info('[sparse regression]') tf.logging.info( '\tinputs: {} / outputs: {} / conv_krnl: {} / pr: {} / nnz: {}'. 
format(inputs_np_list[0].shape, outputs_np.shape, conv_krnl.shape, prune_ratio, nb_chns_nnz_target)) # compute the feature matrix & response vector tf.logging.info('computing the feature matrix & response vector') time_beg = timer() bs_rdc = int(math.ceil(min(bs, bs / oc * 10.0))) tf.logging.info('secondary sampling: %d -> %d' % (bs, bs_rdc)) idxs_inst = np.random.choice(bs, size=(bs_rdc), replace=False) rspn_vec_np = np.reshape(outputs_np[idxs_inst], [-1, 1]) # N' x 1 (N' = N * c_o) feat_mat_np = np.zeros((ic, bs_rdc * oc)) # c_i x N' for idx in range(ic): wei_mat = np.reshape(conv_krnl[:, :, idx, :], [kh * kw, oc]) feat_mat_np[idx] = np.matmul(inputs_np_list[idx][idxs_inst], wei_mat).ravel() feat_mat_np = np.transpose(feat_mat_np) tf.logging.info('time elapsed: %.4f (s)' % (timer() - time_beg)) # compute <X^T * X> & <X^T * y> in advance tf.logging.info('computing <X^T * X> & <X^T * y> in advance') time_beg = timer() xt_x_np = np.matmul(feat_mat_np.T, feat_mat_np) xt_y_np = np.matmul(feat_mat_np.T, rspn_vec_np) xt_x_norm = norm( xt_x_np ) # normalize <xt_x> to unit norm, and adjust <xt_y> correspondingly xt_x_np /= xt_x_norm xt_y_np /= xt_x_norm mask_np_init = np.random.uniform(size=(ic, 1)) tf.logging.info('time elapsed: %.4f (s)' % (timer() - time_beg)) # solve the LASSO problem def __solve_lasso(x): self.sess_prune.run(self.meta_lasso['init_op'], feed_dict={ self.meta_lasso['xt_x_ph']: xt_x_np, self.meta_lasso['xt_y_ph']: xt_y_np, self.meta_lasso['mask_ph']: mask_np_init, }) for __ in range(FLAGS.cpr_ista_nb_iters): self.sess_prune.run(self.meta_lasso['train_op'], feed_dict={self.meta_lasso['gamma']: x}) mask_np = self.sess_prune.run(self.meta_lasso['mask']) nb_chns_nnz = np.count_nonzero(mask_np) tf.logging.info('x = %e -> nb_chns_nnz = %d' % (x, nb_chns_nnz)) return mask_np, nb_chns_nnz # determine <gamma>'s upper bound tf.logging.info('determining <gamma>\'s upper bound') time_beg = timer() ubnd = 0.1 while True: mask_np, nb_chns_nnz = __solve_lasso(ubnd) if 
nb_chns_nnz <= nb_chns_nnz_target: break else: ubnd *= 2.0 tf.logging.info('time elapsed: %.4f (s)' % (timer() - time_beg)) # determine <gamma> via binary search tf.logging.info('determining <gamma> via binary search') time_beg = timer() lbnd = 0.0 while nb_chns_nnz != nb_chns_nnz_target and ubnd - lbnd > 1e-8: val = (lbnd + ubnd) / 2.0 mask_np, nb_chns_nnz = __solve_lasso(val) if nb_chns_nnz < nb_chns_nnz_target: ubnd = val elif nb_chns_nnz > nb_chns_nnz_target: lbnd = val else: break tf.logging.info('time elapsed: %.4f (s)' % (timer() - time_beg)) # construct a least-square regression problem tf.logging.info('constructing a least-square regression problem') time_beg = timer() bnry_vec_np = (np.abs(mask_np) > 0.0).astype(np.float32) rspn_mat_np = outputs_np feat_tns_np = np.concatenate( [np.expand_dims(x, axis=-1) for x in inputs_np_list], axis=-1) feat_mat_np = np.reshape( feat_tns_np * np.reshape(bnry_vec_np, [1, 1, -1]), [bs, -1]) w_mat_np_init = np.reshape(conv_krnl, [-1, oc]) gacc1_np = np.zeros_like(w_mat_np_init) gacc2_np = np.zeros_like(w_mat_np_init) self.sess_prune.run(self.meta_lstsq['init_op'], feed_dict={ self.meta_lstsq['x_mat_ph']: feat_mat_np, self.meta_lstsq['y_mat_ph']: rspn_mat_np, self.meta_lstsq['w_mat_ph']: w_mat_np_init, self.meta_lstsq['gacc1_ph']: gacc1_np, self.meta_lstsq['gacc2_ph']: gacc2_np, }) loss_reg, loss_dcy = self.sess_prune.run( [self.meta_lstsq['loss_reg'], self.meta_lstsq['loss_dcy']]) tf.logging.info('losses: %e (reg) / %e (dcy)' % (loss_reg, loss_dcy)) for __ in range(FLAGS.cpr_lstsq_nb_iters): self.sess_prune.run(self.meta_lstsq['train_op']) w_mat_np, loss_reg, loss_dcy = self.sess_prune.run([ self.meta_lstsq['w_mat'], self.meta_lstsq['loss_reg'], self.meta_lstsq['loss_dcy'] ]) tf.logging.info('losses: %e (reg) / %e (dcy)' % (loss_reg, loss_dcy)) conv_krnl = np.reshape(w_mat_np, conv_krnl.shape) * np.reshape( bnry_vec_np, [1, 1, -1, 1]) tf.logging.info('time elapsed: %.4f (s)' % (timer() - time_beg)) return conv_krnl def 
__save_model(self, is_train): """Save the current model for training or evaluation. Args: * is_train: whether to save a model for training """ if is_train: save_path = self.saver_prnd_train.save(self.sess_train, FLAGS.cpr_save_path, self.global_step) else: save_path = self.saver_prnd_eval.save(self.sess_eval, FLAGS.cpr_save_path_eval) tf.logging.info('model saved to ' + save_path) def __restore_model(self, is_train): """Restore a model from the latest checkpoint files. Args: * is_train: whether to restore a model for training """ save_path = tf.train.latest_checkpoint( os.path.dirname(FLAGS.cpr_save_path)) if is_train: self.saver_prnd_train.restore(self.sess_train, save_path) else: self.saver_prnd_eval.restore(self.sess_eval, save_path) tf.logging.info('model restored from ' + save_path) def __monitor_progress(self, summary, log_rslt, idx_iter, time_step): """Monitor the training progress. Args: * summary: summary protocol buffer * log_rslt: logging operations' results * idx_iter: index of the training iteration * time_step: time step between two summary operations """ # write summaries for TensorBoard visualization self.sm_writer.add_summary(summary, idx_iter) # compute the training speed speed = FLAGS.batch_size * FLAGS.summ_step / time_step if FLAGS.enbl_multi_gpu: speed *= mgw.size() # display monitored statistics log_str = ' | '.join([ '%s = %.4e' % (name, value) for name, value in zip(self.log_op_names, log_rslt) ]) tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))
class FullPrecLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
  """Full-precision learner (no model compression applied)."""

  def __init__(self, sm_writer, model_helper, model_scope=None, enbl_dst=None):
    """Constructor function.

    Args:
    * sm_writer: TensorFlow's summary writer
    * model_helper: model helper with definitions of model & dataset
    * model_scope: name scope in which to define the model
    * enbl_dst: whether to create a model with distillation loss
    """

    # class-independent initialization
    super(FullPrecLearner, self).__init__(sm_writer, model_helper)
    # NOTE(review): the caller-supplied <model_scope> is unconditionally replaced here, so the
    # argument currently has no effect - confirm whether this override is still intended
    model_scope = 'quan_model'

    # over-ride the model scope and distillation loss switch
    if model_scope is not None:
      self.model_scope = model_scope
    self.enbl_dst = enbl_dst if enbl_dst is not None else FLAGS.enbl_dst

    # class-dependent initialization
    if self.enbl_dst:
      self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)
    self.__build(is_train=True)
    self.__build(is_train=False)

  def train(self):
    """Train a model and periodically produce checkpoint files."""

    # initialization
    self.sess_train.run(self.init_op)
    if FLAGS.enbl_multi_gpu:
      self.sess_train.run(self.bcast_op)

    # train the model through iterations and periodically save & evaluate the model
    time_prev = timer()
    for idx_iter in range(self.nb_iters_train):
      # train the model
      if (idx_iter + 1) % FLAGS.summ_step != 0:
        self.sess_train.run(self.train_op)
      else:
        __, summary, log_rslt = self.sess_train.run([self.train_op, self.summary_op, self.log_op])
        if self.is_primary_worker('global'):
          time_step = timer() - time_prev
          self.__monitor_progress(summary, log_rslt, idx_iter, time_step)
          time_prev = timer()

      # save & evaluate the model at certain steps
      if self.is_primary_worker('global') and (idx_iter + 1) % FLAGS.save_step == 0:
        self.__save_model(is_train=True)
        self.evaluate()

    # save the final model
    if self.is_primary_worker('global'):
      self.__save_model(is_train=True)
      self.__restore_model(is_train=False)
      self.__save_model(is_train=False)
      self.evaluate()

  def evaluate(self):
    """Restore a model from the latest checkpoint files and then evaluate it.

    In the factory mode (<FLAGS.factory_mode>), the model's outputs are stitched into a single
    image, which is saved to 'out_example/output.jpg', and no metrics are computed. Otherwise,
    evaluation metrics are logged, the inference time is measured, a few example images are
    saved under 'out_example/', and one line of results is appended to 'log.txt'.
    """

    self.__restore_model(is_train=False)
    if FLAGS.factory_mode:
      self.__evaluate_factory()
      return

    # compute & display evaluation metrics
    nb_iters = int(np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
    eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
    for idx_iter in range(nb_iters):
      eval_rslts[idx_iter] = self.sess_eval.run(self.eval_op)
    for idx, name in enumerate(self.eval_op_names):
      tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

    # measure the time consumption of the forward pass
    t = time.time()
    for __ in range(nb_iters):
      self.sess_eval.run(self.time_op)
    t = time.time() - t

    # save a few examples: bicubic-interpolated input / input / model's output / label
    images, outputs, labels = self.sess_eval.run(self.out_op)
    output_size = FLAGS.sr_scale * FLAGS.input_size
    self.__ensure_dir('out_example')
    for i in range(min(8, FLAGS.batch_size_eval)):
      img_bic = scipy.misc.imresize(images[i], (output_size, output_size), 'bicubic')
      img_bic = np.clip(img_bic, 0, 255)
      img_bic = np.array(img_bic, np.uint8)
      img_bic = Image.fromarray(img_bic, 'RGB')
      img = Image.fromarray(images[i], 'RGB')
      out = Image.fromarray(outputs[i], 'RGB')
      label = Image.fromarray(labels[i], 'RGB')
      img_bic.save(('out_example/' + str(i) + 'bic.jpg'))
      img.save('out_example/' + str(i) + 'image.jpg')
      out.save('out_example/' + str(i) + 'output.jpg')
      label.save('out_example/' + str(i) + 'label.jpg')
    tf.logging.info('time = %.4e' % (t / FLAGS.nb_smpls_eval))

    # append evaluation results to the log file
    log_items = ["full"]
    log_items += [self.model_name]
    for idx, name in enumerate(self.eval_op_names):
      log_items += [name + ": " + str(np.mean(eval_rslts[:, idx]))]
    log_items += ["eval_batch_size: " + str(FLAGS.batch_size_eval)]
    log_items += ["time/pic: " + str(t / FLAGS.nb_smpls_eval)]
    with open("log.txt", "a") as txt:  # the file is closed even if writing fails
      txt.write(str(log_items))
      txt.write('\n')

  def __evaluate_factory(self):
    """Run the model over all patches of one image and stitch the outputs into a single image.

    The image <FLAGS.image_name> is only read to determine how many patches of size
    <FLAGS.input_size> fit along each axis; the patches themselves are provided by the
    evaluation graph's input pipeline.
    """

    tmp_image = scipy.misc.imread(FLAGS.data_dir_local + "/images/" + FLAGS.image_name)
    x, y, __ = tmp_image.shape
    print(tmp_image.shape)
    size_low = FLAGS.input_size
    size_high = FLAGS.sr_scale * size_low
    coordx = x // size_low
    coordy = y // size_low
    nb_iters = int(np.ceil(float(coordy * coordx) / FLAGS.batch_size_eval))
    image = np.zeros([size_high * coordx, size_high * coordy, 3], dtype=np.uint8)
    print(image.shape)
    print(nb_iters)

    # collect the model's outputs of all the patches
    outputs = []
    for __ in range(nb_iters):
      output = self.sess_eval.run(self.factory_op)
      for img in output[0]:
        outputs.append(img)
    print(np.array(outputs).shape)

    # stitch patches into a single image (patches are consumed in the row-major order)
    index = 0
    for i in range(coordx):
      for j in range(coordy):
        image[i * size_high:(i + 1) * size_high,
              j * size_high:(j + 1) * size_high, :] = np.array(outputs[index])
        index += 1
    out = Image.fromarray(image, 'RGB')
    self.__ensure_dir('out_example')
    out.save('out_example/' + 'output.jpg')

  @staticmethod
  def __ensure_dir(path):
    """Create the directory <path> if it does not exist yet.

    Args:
    * path: directory path
    """

    if not os.path.isdir(path):
      try:
        os.makedirs(path)
      except OSError:
        # another worker may have created the directory in the meantime
        if not os.path.isdir(path):
          raise

  def __build(self, is_train):  # pylint: disable=too-many-locals
    """Build the training / evaluation graph.

    Args:
    * is_train: whether to create the training graph
    """

    with tf.Graph().as_default():
      # TensorFlow session
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = \
        str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      sess = tf.Session(config=config)

      # data input pipeline
      with tf.variable_scope(self.data_scope):
        iterator = self.build_dataset_train() if is_train else self.build_dataset_eval()
        images, labels = iterator.get_next()
        tf.add_to_collection('images_final', images)

      # model definition - distilled model
      if self.enbl_dst:
        logits_dst = self.helper_dst.calc_logits(sess, images)

      # model definition - primary model
      with tf.variable_scope(self.model_scope):
        # forward pass
        logits = self.forward_train(images) if is_train else self.forward_eval(images)
        tf.add_to_collection('logits_final', logits)

        # loss & extra evalution metrics
        loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
        if self.enbl_dst:
          loss += self.helper_dst.calc_loss(logits, logits_dst)
        tf.summary.scalar('loss', loss)
        for key, value in metrics.items():
          tf.summary.scalar(key, value)

        # optimizer & gradients
        if is_train:
          self.global_step = tf.train.get_or_create_global_step()
          lrn_rate, self.nb_iters_train = setup_lrn_rate(
            self.global_step, self.model_name, self.dataset_name)
          optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
          if FLAGS.enbl_multi_gpu:
            optimizer = mgw.DistributedOptimizer(optimizer)
          grads = optimizer.compute_gradients(loss, self.trainable_vars)

      # TF operations & model saver
      if is_train:
        self.sess_train = sess
        with tf.control_dependencies(self.update_ops):
          self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)
        self.summary_op = tf.summary.merge_all()
        self.log_op = [lrn_rate, loss] + list(metrics.values())
        self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
        self.init_op = tf.variables_initializer(self.vars)
        if FLAGS.enbl_multi_gpu:
          self.bcast_op = mgw.broadcast_global_variables(0)
        self.saver_train = tf.train.Saver(self.vars)
      else:
        self.sess_eval = sess
        self.factory_op = [tf.cast(logits, tf.uint8)]  # outputs only (factory mode)
        self.time_op = [logits]  # forward pass only (time measurement)
        self.out_op = [
          tf.cast(images, tf.uint8),
          tf.cast(logits, tf.uint8),
          tf.cast(labels, tf.uint8)
        ]
        self.eval_op = [loss] + list(metrics.values())
        self.eval_op_names = ['loss'] + list(metrics.keys())
        self.saver_eval = tf.train.Saver(self.vars)

  def __save_model(self, is_train):
    """Save the model to checkpoint files for training or evaluation.

    Args:
    * is_train: whether to save a model for training
    """

    if is_train:
      save_path = self.saver_train.save(self.sess_train, FLAGS.save_path, self.global_step)
    else:
      save_path = self.saver_eval.save(self.sess_eval, FLAGS.save_path_eval)
    tf.logging.info('model saved to ' + save_path)

  def __restore_model(self, is_train):
    """Restore a model from the latest checkpoint files.

    The latest checkpoint is looked up in the directory of <FLAGS.save_path>.

    Args:
    * is_train: whether to restore a model for training
    """

    save_path = tf.train.latest_checkpoint(os.path.dirname(FLAGS.save_path))
    if is_train:
      self.saver_train.restore(self.sess_train, save_path)
    else:
      self.saver_eval.restore(self.sess_eval, save_path)
    tf.logging.info('model restored from ' + save_path)

  def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
    """Monitor the training progress.

    Args:
    * summary: summary protocol buffer
    * log_rslt: logging operations' results
    * idx_iter: index of the training iteration
    * time_step: time step between two summary operations
    """

    # write summaries for TensorBoard visualization
    self.sm_writer.add_summary(summary, idx_iter)

    # compute the training speed
    speed = FLAGS.batch_size * FLAGS.summ_step / time_step
    if FLAGS.enbl_multi_gpu:
      speed *= mgw.size()

    # display monitored statistics
    log_str = ' | '.join(['%s = %.4e' % (name, value)
                          for name, value in zip(self.log_op_names, log_rslt)])
    tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))
class UniformQuantLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Uniform quantization learner for weights and activations.

    Two graphs (training & evaluation) are built with quantization ops
    inserted. The bit width of every quantized op is fed through
    placeholders, using the two lists returned by `BitOptimizer.run()`.
    """

    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Builds both graphs, makes sure the pre-trained model is available and
        runs the bit optimizer to obtain the bit widths used afterwards by
        `train()` and `evaluate()`.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """
        # class-independent initialization
        super(UniformQuantLearner, self).__init__(sm_writer, model_helper)

        # class-dependent initialization
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper,
                                                 self.mpi_comm)

        # initialize class attributes
        self.ops = {}
        self.bit_placeholders = {}
        self.statistics = {}

        self.__build_train()  # for train
        self.__build_eval()  # for eval

        if self.is_primary_worker('local'):
            self.download_model()  # pre-trained model is required
        self.auto_barrier()

        # determine the optimal policy.
        bit_optimizer = BitOptimizer(self.dataset_name, self.weights,
                                     self.statistics, self.bit_placeholders,
                                     self.ops, self.layerwise_tune_list,
                                     self.sess_train, self.sess_eval,
                                     self.saver_train, self.saver_eval,
                                     self.auto_barrier)
        self.optimal_w_bit_list, self.optimal_a_bit_list = bit_optimizer.run()
        self.auto_barrier()

    def train(self):
        """Fine-tune the quantized model; periodically save & evaluate it."""
        # initialization
        self.sess_train.run(self.ops['init'])

        total_iters = self.finetune_steps
        if FLAGS.enbl_warm_start:
            # use the latest model for warm start
            self.__restore_model(is_train=True)
        self.auto_barrier()
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.ops['bcast'])

        time_prev = timer()

        # build the quantization bits
        feed_dict = {
            self.bit_placeholders['w_train']: self.optimal_w_bit_list,
            self.bit_placeholders['a_train']: self.optimal_a_bit_list
        }

        for idx_iter in range(total_iters):
            # train the model
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.ops['train'], feed_dict=feed_dict)
            else:
                _, summary, log_rslt = self.sess_train.run(
                    [self.ops['train'], self.ops['summary'], self.ops['log']],
                    feed_dict=feed_dict)
                time_prev = self.__monitor_progress(summary, log_rslt,
                                                    time_prev, idx_iter)

            # save & evaluate the model at certain steps
            if (idx_iter + 1) % FLAGS.save_step == 0:
                self.__save_model()
                self.evaluate()
                self.auto_barrier()

        # save the final model
        self.__save_model()
        self.evaluate()

    def evaluate(self):
        """Restore the latest quantized checkpoint and evaluate it.

        Logs the averaged loss, plus the averaged accuracy for every dataset
        except 'coco2017-pose'. Non-primary workers return immediately.
        """
        # early break for non-primary workers
        if not self.is_primary_worker():
            return

        is_openpose = self.dataset_name == 'coco2017-pose'

        # evaluate the model
        self.__restore_model(is_train=False)
        losses, accuracies = [], []
        nb_iters = int(
            np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))

        # build the quantization bits
        feed_dict = {
            self.bit_placeholders['w_eval']: self.optimal_w_bit_list,
            self.bit_placeholders['a_eval']: self.optimal_a_bit_list
        }

        for _ in range(nb_iters):
            eval_rslt = self.sess_eval.run(self.ops['eval'],
                                           feed_dict=feed_dict)
            losses.append(eval_rslt[0])
            # second entry is acc_top1, or the all-layer loss for
            # 'coco2017-pose' (collected but not reported in that case)
            accuracies.append(eval_rslt[1])
        tf.logging.info('loss: {}'.format(np.mean(np.array(losses))))
        if not is_openpose:
            tf.logging.info('accuracy: {}'.format(
                np.mean(np.array(accuracies))))
        tf.logging.info("Optimal Weight Quantization:{}".format(
            self.optimal_w_bit_list))

        if FLAGS.uql_use_buckets:
            bucket_storage = self.sess_eval.run(self.ops['bucket_storage'],
                                                feed_dict=feed_dict)
            self.__show_bucket_storage(bucket_storage)

        if is_openpose and FLAGS.calculate_map:
            # imported lazily: only needed for the pose-estimation model
            from examples.openpose_eval_helper import calculate_map
            tensor_image = self.sess_eval.graph.get_tensor_by_name(
                'model/MobilenetV2/input:0')
            tensor_output = self.sess_eval.graph.get_tensor_by_name(
                'model/Openpose/concat_stage7:0')
            calculate_map(lambda img: self.sess_eval.run(
                [tensor_output],
                feed_dict={tensor_image: img, **feed_dict})[0])

    def __build_train(self):
        """Build the training graph, with quantization ops inserted.

        Fills `self.ops` with the 'train', 'summary', 'log', 'reset_ft_step',
        'init' and 'bcast' entries and creates the training session & savers.

        Raises:
        * ValueError: if the dataset name is not recognized
        """
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_train = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                images.set_shape((FLAGS.batch_size, images.shape[1],
                                  images.shape[2], images.shape[3]))

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_train, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_train(images)

                self.weights = [
                    v for v in self.trainable_vars
                    if 'kernel' in v.name or 'weight' in v.name
                ]
                if not FLAGS.uql_quantize_all_layers:
                    # leave the first & last weight tensors un-quantized
                    self.weights = self.weights[1:-1]
                self.statistics['num_weights'] = \
                    [tf.reshape(v, [-1]).shape[0].value for v in self.weights]

                self.__quantize_train_graph()

                # loss & accuracy
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                elif self.dataset_name == 'coco2017-pose':
                    total_loss = metrics['total_loss_all_layers']
                    total_loss_ll_paf = metrics['total_loss_last_layer_paf']
                    total_loss_ll_heat = metrics['total_loss_last_layer_heat']
                    total_loss_ll = metrics['total_loss_last_layer']
                else:
                    raise ValueError("Unrecognized dataset name")

                model_loss = loss
                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss
                    tf.summary.scalar('dst_loss', dst_loss)
                tf.summary.scalar('model_loss', model_loss)
                tf.summary.scalar('loss', loss)
                if self.dataset_name == 'coco2017-pose':
                    tf.summary.scalar('total_loss_all_layers', total_loss)
                    tf.summary.scalar('total_loss_last_layer_paf',
                                      total_loss_ll_paf)
                    tf.summary.scalar('total_loss_last_layer_heat',
                                      total_loss_ll_heat)
                    tf.summary.scalar('total_loss_last_layer', total_loss_ll)
                else:
                    tf.summary.scalar('acc_top1', acc_top1)
                    tf.summary.scalar('acc_top5', acc_top5)

                # created before the fine-tuning step variable below
                self.saver_train = tf.train.Saver(self.vars)

                self.ft_step = tf.get_variable('finetune_step',
                                               shape=[],
                                               dtype=tf.int32,
                                               trainable=False)

            # optimizer & gradients
            init_lr, bnds, decay_rates, self.finetune_steps = setup_bnds_decay_rates(
                self.model_name, self.dataset_name)
            lrn_rate = tf.train.piecewise_constant(
                self.ft_step, list(bnds),
                [init_lr * decay_rate for decay_rate in decay_rates])

            optimizer = tf.train.AdamOptimizer(learning_rate=lrn_rate)
            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(optimizer)
            grads = optimizer.compute_gradients(loss, self.trainable_vars)

            # sm write graph
            self.sm_writer.add_graph(self.sess_train.graph)

            with tf.control_dependencies(self.update_ops):
                self.ops['train'] = optimizer.apply_gradients(
                    grads, global_step=self.ft_step)
            self.ops['summary'] = tf.summary.merge_all()

            # NOTE: __monitor_progress() relies on the order of these entries
            if FLAGS.enbl_dst:
                self.ops['log'] = [
                    lrn_rate,
                    dst_loss,
                    model_loss,
                    loss,
                ]
            else:
                self.ops['log'] = [
                    lrn_rate,
                    model_loss,
                    loss,
                ]
            if self.dataset_name == 'coco2017-pose':
                self.ops['log'] += [
                    total_loss, total_loss_ll_paf, total_loss_ll_heat,
                    total_loss_ll
                ]
            else:
                self.ops['log'] += [acc_top1, acc_top5]

            self.ops['reset_ft_step'] = tf.assign(
                self.ft_step, tf.constant(0, dtype=tf.int32))
            self.ops['init'] = tf.global_variables_initializer()
            self.ops['bcast'] = mgw.broadcast_global_variables(
                0) if FLAGS.enbl_multi_gpu else None
            self.saver_quant = tf.train.Saver(self.vars)

    def __build_eval(self):
        """Build the evaluation graph, with quantization ops inserted.

        Fills `self.ops['eval']` and creates the evaluation session & saver.

        Raises:
        * ValueError: if the dataset name is not recognized
        """
        with tf.Graph().as_default():
            # TensorFlow session
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()
                images.set_shape((FLAGS.batch_size_eval, images.shape[1],
                                  images.shape[2], images.shape[3]))
                self.images_eval = images

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_eval(images)

                self.__quantize_eval_graph()

                # loss & accuracy
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                elif self.dataset_name == 'coco2017-pose':
                    total_loss = metrics['total_loss_all_layers']
                    total_loss_ll_paf = metrics['total_loss_last_layer_paf']
                    total_loss_ll_heat = metrics['total_loss_last_layer_heat']
                    total_loss_ll = metrics['total_loss_last_layer']
                else:
                    raise ValueError("Unrecognized dataset name")

                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss

                # TF operations & model saver
                if self.dataset_name == 'coco2017-pose':
                    self.ops['eval'] = [
                        loss, total_loss, total_loss_ll_paf,
                        total_loss_ll_heat, total_loss_ll
                    ]
                else:
                    self.ops['eval'] = [loss, acc_top1, acc_top5]
                self.saver_eval = tf.train.Saver(self.vars)

    def __quantize_train_graph(self):
        """Insert quantization nodes to the training graph.

        Also records the number of matmul / activation ops found and creates
        the 'w_train' / 'a_train' bit-width placeholders.
        """
        uni_quant = UniformQuantization(self.sess_train,
                                        FLAGS.uql_bucket_size,
                                        FLAGS.uql_use_buckets,
                                        FLAGS.uql_bucket_type)

        # Find Conv2d Op
        matmul_ops = uni_quant.search_matmul_op(FLAGS.uql_quantize_all_layers)
        act_ops = uni_quant.search_activation_op()

        self.statistics['nb_matmuls'] = len(matmul_ops)
        self.statistics['nb_activations'] = len(act_ops)

        # Replace Conv2d Op with quantized weights
        matmul_op_names = [op.name for op in matmul_ops]
        act_op_names = [op.name for op in act_ops]

        # build the placeholders for training: one bit width per op
        self.bit_placeholders['w_train'] = tf.placeholder(
            tf.int64, shape=[self.statistics['nb_matmuls']], name="w_bit_list")
        self.bit_placeholders['a_train'] = tf.placeholder(
            tf.int64,
            shape=[self.statistics['nb_activations']],
            name="a_bit_list")
        w_bit_dict_train = self.__build_quant_dict(
            matmul_op_names, self.bit_placeholders['w_train'])
        a_bit_dict_train = self.__build_quant_dict(
            act_op_names, self.bit_placeholders['a_train'])

        uni_quant.insert_quant_op_for_weights(w_bit_dict_train)
        uni_quant.insert_quant_op_for_activations(a_bit_dict_train)

        # add layerwise finetuning. TODO: working not very well
        self.layerwise_tune_list = uni_quant.get_layerwise_tune_op(self.weights) \
            if FLAGS.uql_enbl_rl_layerwise_tune else (None, None)

    def __quantize_eval_graph(self):
        """Insert quantization nodes to the evaluation graph.

        Must run after `__quantize_train_graph()`: the op counts recorded
        there are checked against the ones found in the evaluation graph.
        """
        uni_quant = UniformQuantization(self.sess_eval,
                                        FLAGS.uql_bucket_size,
                                        FLAGS.uql_use_buckets,
                                        FLAGS.uql_bucket_type)

        # Find matmul ops
        matmul_ops = uni_quant.search_matmul_op(FLAGS.uql_quantize_all_layers)
        act_ops = uni_quant.search_activation_op()
        assert self.statistics['nb_matmuls'] == len(matmul_ops), \
            'the length of matmul_ops on train and eval graphs does not match'
        assert self.statistics['nb_activations'] == len(act_ops), \
            'the length of act_ops on train and eval graphs does not match'

        # Replace Conv2d Op with quantized weights
        matmul_op_names = [op.name for op in matmul_ops]
        act_op_names = [op.name for op in act_ops]

        # build the placeholder for eval
        self.bit_placeholders['w_eval'] = tf.placeholder(
            tf.int64, shape=[self.statistics['nb_matmuls']], name="w_bit_list")
        self.bit_placeholders['a_eval'] = tf.placeholder(
            tf.int64,
            shape=[self.statistics['nb_activations']],
            name="a_bit_list")

        w_bit_dict_eval = self.__build_quant_dict(
            matmul_op_names, self.bit_placeholders['w_eval'])
        a_bit_dict_eval = self.__build_quant_dict(
            act_op_names, self.bit_placeholders['a_eval'])

        uni_quant.insert_quant_op_for_weights(w_bit_dict_eval)
        uni_quant.insert_quant_op_for_activations(a_bit_dict_eval)

        self.ops['bucket_storage'] = uni_quant.bucket_storage

    def __save_model(self):
        """Save the quantized model from the training session.

        Non-primary workers return without saving.
        """
        # early break for non-primary workers
        if not self.is_primary_worker():
            return

        save_quant_model_path = self.saver_quant.save(
            self.sess_train, FLAGS.uql_save_quant_model_path, self.ft_step)
        tf.logging.info('quantized model saved to ' + save_quant_model_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Args:
        * is_train: True to load the latest checkpoint next to
            <FLAGS.save_path> into the training session; False to load the
            latest quantized checkpoint next to
            <FLAGS.uql_save_quant_model_path> into the evaluation session

        Raises:
        * ValueError: if no checkpoint can be found
        """
        if is_train:
            ckpt_dir = os.path.dirname(FLAGS.save_path)
        else:
            ckpt_dir = os.path.dirname(FLAGS.uql_save_quant_model_path)

        # latest_checkpoint() returns None when there is nothing to restore;
        # fail here with the directory name instead of deeper down
        save_path = tf.train.latest_checkpoint(ckpt_dir)
        if save_path is None:
            raise ValueError('no checkpoint found in %r' % ckpt_dir)

        if is_train:
            # list the checkpoint directory to ease debugging of warm starts
            for item in os.listdir(os.path.dirname(save_path)):
                tf.logging.info('Print directory: ' + item)
            self.saver_train.restore(self.sess_train, save_path)
        else:
            self.saver_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, time_prev, idx_iter):
        """Write summaries and log the fine-tuning statistics.

        Args:
        * summary: summary protocol buffer
        * log_rslt: results of `self.ops['log']`
        * time_prev: time stamp taken at the previous call
        * idx_iter: index of the training iteration

        Returns:
        * the current time stamp, or None on non-primary workers
        """
        # early break for non-primary workers
        if not self.is_primary_worker():
            return None

        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # compute the training speed
        speed = FLAGS.batch_size * FLAGS.summ_step / (timer() - time_prev)
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # names of the values that follow the learning rate in <log_rslt>,
        # in the order used for self.ops['log']; None marks a value that is
        # fetched but not displayed
        names = ['dst_loss'] if FLAGS.enbl_dst else []
        names += ['model_loss', 'loss']
        if self.dataset_name == 'coco2017-pose':
            names += [None, 'll_paf', 'll_heat', 'll']
        else:
            # NOTE: for cifar-10, acc_top5 is 0.
            names += ['acc_top1', 'acc_top5']

        # display monitored statistics
        fields = ['lr = %e' % (log_rslt[0],)]
        fields += [
            '%s = %.4f' % (name, value)
            for name, value in zip(names, log_rslt[1:]) if name is not None
        ]
        fields.append('speed = %.2f pics / sec' % speed)
        tf.logging.info('iter #%d: %s' % (idx_iter + 1, ' | '.join(fields)))

        return timer()

    def __show_bucket_storage(self, bucket_storage):
        """Log the bucket storage, the weight storage and their ratio.

        Args:
        * bucket_storage: bucket storage (in bits)
        """
        # bits per weight: equivalent bits if the RL agent is enabled
        bits_per_weight = FLAGS.uql_equivalent_bits \
            if FLAGS.uql_enbl_rl_agent else FLAGS.uql_weight_bits
        weight_storage = sum(self.statistics['num_weights']) * bits_per_weight
        tf.logging.info(
            'bucket storage: %d bit / %.3f kb | weight storage: %d bit / %.3f kb | ratio: %.3f'
            % (bucket_storage, bucket_storage / (8. * 1024.), weight_storage,
               weight_storage / (8. * 1024.),
               bucket_storage * 1. / weight_storage))

    @staticmethod
    def __build_quant_dict(keys, values):
        """Bind keys and values to dictionaries.

        Args:
        * keys: A list of op_names
        * values: A Tensor with len(op_names) elements

        Returns:
        * dict: (key, value) for weight name and quant bits respectively
        """
        return {name: values[idx] for idx, name in enumerate(keys)}