def __build_eval(self): """Build the evaluation graph.""" with tf.Graph().as_default() as graph: # create a TF session for the current graph config = tf.ConfigProto() config.gpu_options.visible_device_list = str( mgw.local_rank() if FLAGS.enbl_multi_gpu else 0) # pylint: disable=no-member config.gpu_options.allow_growth = True # pylint: disable=no-member self.sess_eval = tf.Session(config=config) # data input pipeline with tf.variable_scope(self.data_scope): iterator = self.build_dataset_eval() images, labels = iterator.get_next() # model definition - uniform quantized model - part 1 with tf.variable_scope(self.model_scope_quan): logits = self.forward_eval(images) if not isinstance(logits, dict): outputs = tf.nn.softmax(logits) else: outputs = tf.nn.softmax(logits['cls_pred']) tf.contrib.quantize.experimental_create_eval_graph( weight_bits=FLAGS.uqtf_weight_bits, activation_bits=FLAGS.uqtf_activation_bits, scope=self.model_scope_quan) for node_name in self.unquant_node_names: insert_quant_op(graph, node_name, is_train=False) vars_quan = get_vars_by_scope(self.model_scope_quan) # model definition - distilled model if FLAGS.enbl_dst: logits_dst = self.helper_dst.calc_logits( self.sess_eval, images) # model definition - uniform quantized model -part 2 with tf.variable_scope(self.model_scope_quan): # loss & extra evaluation metrics loss, metrics = self.calc_loss(labels, logits, vars_quan['trainable']) if FLAGS.enbl_dst: loss += self.helper_dst.calc_loss(logits, logits_dst) # TF operations for evaluation vars_quan = get_vars_by_scope(self.model_scope_quan) self.eval_op = [loss] + list(metrics.values()) self.eval_op_names = ['loss'] + list(metrics.keys()) self.outputs_eval = logits self.saver_quan_eval = tf.train.Saver(vars_quan['all']) # add input & output tensors to certain collections if not isinstance(images, dict): tf.add_to_collection('images_final', images) else: tf.add_to_collection('images_final', images['image']) if not isinstance(logits, dict): tf.add_to_collection('logits_final', logits) else: tf.add_to_collection('logits_final', logits['cls_pred'])
def __build_eval(self, model_helper):
  """Build the evaluation graph for the 'optimal' protocol.

  Args:
  * model_helper: model helper with definitions of model & dataset
  """
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    self.sess_eval = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      __, iterator = model_helper.build_dataset_train(enbl_trn_val_split=True)
      images, labels = iterator.get_next()

    # model definition - weight sparsified network
    with tf.variable_scope(self.model_scope_prnd):
      logits = model_helper.forward_eval(images)
      vars_prnd = get_vars_by_scope(self.model_scope_prnd)
      self.loss_eval, self.metrics_eval = \
        model_helper.calc_loss(labels, logits, vars_prnd['trainable'])
      self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'])
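# Usage sketch (not part of the original sources): loss_eval / metrics_eval are
# plain tensors, so evaluation reduces to a single session run after restoring the
# pruned weights. 'learner' and 'ckpt_path' are hypothetical stand-ins.
def _demo_run_eval(learner, ckpt_path):
  learner.saver_prnd_eval.restore(learner.sess_eval, ckpt_path)
  loss, metrics = learner.sess_eval.run([learner.loss_eval, learner.metrics_eval])
  tf.logging.info('loss: {} / metrics: {}'.format(loss, metrics))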
def __build_eval(self):
  """Build the evaluation graph."""
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
    self.sess_eval = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_eval()
      images, labels = iterator.get_next()
      images.set_shape((FLAGS.batch_size_eval, images.shape[1],
                        images.shape[2], images.shape[3]))
      self.images_eval = images

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(self.sess_eval, images)

    # model definition
    with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
      # forward pass
      logits = self.forward_eval(images)
      self.__quantize_eval_graph()

      # loss & accuracy
      loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
      if self.dataset_name == 'cifar_10':
        acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
      elif self.dataset_name == 'ilsvrc_12':
        acc_top1, acc_top5 = metrics['acc_top1'], metrics['acc_top5']
      elif self.dataset_name == 'coco2017-pose':
        total_loss = metrics['total_loss_all_layers']
        total_loss_ll_paf = metrics['total_loss_last_layer_paf']
        total_loss_ll_heat = metrics['total_loss_last_layer_heat']
        total_loss_ll = metrics['total_loss_last_layer']
      else:
        raise ValueError('unrecognized dataset name: ' + self.dataset_name)
      if FLAGS.enbl_dst:
        dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
        loss += dst_loss

      # TF operations & model saver
      if self.dataset_name == 'coco2017-pose':
        self.ops['eval'] = [loss, total_loss, total_loss_ll_paf,
                            total_loss_ll_heat, total_loss_ll]
      else:
        self.ops['eval'] = [loss, acc_top1, acc_top5]
      self.saver_eval = tf.train.Saver(self.vars)
def __build_eval(self): """Build the evaluation graph.""" with tf.Graph().as_default(): # create a TF session for the current graph config = tf.ConfigProto() config.gpu_options.visible_device_list = str( mgw.local_rank() if FLAGS.enbl_multi_gpu else 0) # pylint: disable=no-member self.sess_eval = tf.Session(config=config) # data input pipeline with tf.variable_scope(self.data_scope): iterator = self.build_dataset_eval() images, labels = iterator.get_next() # model definition - distilled model if FLAGS.enbl_dst: logits_dst = self.helper_dst.calc_logits( self.sess_eval, images) # model definition - channel-pruned model with tf.variable_scope(self.model_scope_prnd): # loss & extra evaluation metrics logits = self.forward_eval(images) vars_prnd = get_vars_by_scope(self.model_scope_prnd) loss, metrics = self.calc_loss(labels, logits, vars_prnd['trainable']) if FLAGS.enbl_dst: loss += self.helper_dst.calc_loss(logits, logits_dst) # overall pruning ratios of trainable & maskable variables pr_trainable = calc_prune_ratio(vars_prnd['trainable']) pr_maskable = calc_prune_ratio(vars_prnd['maskable']) # TF operations for evaluation self.factory_op = [tf.cast(logits, tf.uint8)] self.time_op = [logits] self.out_op = [ tf.cast(images, tf.uint8), tf.cast(logits, tf.uint8), tf.cast(labels, tf.uint8) ] self.eval_op = [loss, pr_trainable, pr_maskable] + list( metrics.values()) self.eval_op_names = ['loss', 'pr_trn', 'pr_msk'] + list( metrics.keys()) self.saver_prnd_eval = tf.train.Saver(vars_prnd['all']) # add input & output tensors to certain collections tf.add_to_collection('images_final', images) tf.add_to_collection('logits_final', logits)
def create_session():
  """Create a TensorFlow session.

  Returns:
  * sess: TensorFlow session
  """
  config = tf.ConfigProto()
  config.gpu_options.visible_device_list = str(
    mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
  config.gpu_options.allow_growth = True  # pylint: disable=no-member
  sess = tf.Session(config=config)
  return sess
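# Usage sketch (not from the original sources): create_session() is meant to be
# called inside a graph context so the returned session binds to that graph.
def _demo_create_session():
  with tf.Graph().as_default():
    sess = create_session()
    one = tf.constant(1.0)
    assert sess.run(one) == 1.0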
def is_primary_worker(scope='global'):
  """Check whether this is the primary worker across all nodes (global) or
    within the current node (local).

  Args:
  * scope: check scope ('global' OR 'local')

  Returns:
  * flag: whether this is the primary worker
  """
  if scope == 'global':
    return True if not FLAGS.enbl_multi_gpu else mgw.rank() == 0
  elif scope == 'local':
    return True if not FLAGS.enbl_multi_gpu else mgw.local_rank() == 0
  else:
    raise ValueError('unrecognized worker scope: ' + scope)
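# Usage sketch: is_primary_worker() typically guards filesystem side effects so
# that only one process writes checkpoints under multi-GPU training. 'saver',
# 'sess', and 'save_path' are hypothetical stand-ins for objects built elsewhere.
def _demo_guarded_save(saver, sess, save_path):
  if is_primary_worker('global'):
    saver.save(sess, save_path)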
def __build_eval(self): """Build the evaluation graph.""" with tf.Graph().as_default(): # create a TF session for the current graph config = tf.ConfigProto() if FLAGS.enbl_multi_gpu: config.gpu_options.visible_device_list = str(mgw.local_rank()) # pylint: disable=no-member else: config.gpu_options.visible_device_list = '0' # pylint: disable=no-member self.sess_eval = tf.Session(config=config) # data input pipeline with tf.variable_scope(self.data_scope): iterator = self.build_dataset_eval() images, labels = iterator.get_next() # model definition - distilled model if FLAGS.enbl_dst: logits_dst = self.helper_dst.calc_logits( self.sess_eval, images) # model definition - weight-sparsified model with tf.variable_scope(self.model_scope): # loss & extra evaluation metrics logits = self.forward_eval(images) loss, metrics = self.calc_loss(labels, logits, self.trainable_vars) if FLAGS.enbl_dst: loss += self.helper_dst.calc_loss(logits, logits_dst) # overall pruning ratios of trainable & maskable variables pr_trainable = calc_prune_ratio(self.trainable_vars) pr_maskable = calc_prune_ratio(self.maskable_vars) # TF operations for evaluation self.eval_op = [loss, pr_trainable, pr_maskable] + list( metrics.values()) self.eval_op_names = ['loss', 'pr_trn', 'pr_msk'] + list( metrics.keys()) self.saver_eval = tf.train.Saver(self.vars)
def __build_pruned_evaluate_model(self, path=None):
  """Build an evaluation model from the pruned model."""
  # early break for non-primary workers
  if not self.__is_primary_worker():
    return

  if path is None:
    path = FLAGS.save_path
  if not tf.train.checkpoint_exists(path):
    return

  with tf.Graph().as_default():
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
    self.sess_eval = tf.Session(config=config)
    self.saver_eval = tf.train.import_meta_graph(path + '.meta')
    self.saver_eval.restore(self.sess_eval, path)

    eval_logits = tf.get_collection('logits')[0]
    tf.add_to_collection('logits_final', eval_logits)
    eval_images = tf.get_collection('eval_images')[0]
    tf.add_to_collection('images_final', eval_images)
    eval_labels = tf.get_collection('eval_labels')[0]
    mem_images = tf.get_collection('mem_images')[0]
    mem_labels = tf.get_collection('mem_labels')[0]

    # re-route the memory-backed input tensors into the evaluation graph
    self.sess_eval.close()
    graph_editor.reroute_ts(eval_images, mem_images)
    graph_editor.reroute_ts(eval_labels, mem_labels)
    self.sess_eval = tf.Session(config=config)
    self.saver_eval.restore(self.sess_eval, path)

    trainable_vars = self.trainable_vars
    loss, accuracy = self.calc_loss(eval_labels, eval_logits, trainable_vars)
    self.eval_op = [loss] + list(accuracy.values())
    self.sm_writer.add_graph(self.sess_eval.graph)
def __build_minimal(self, model_helper):
  """Build the minimal graph for the 'uniform' & 'heurist' protocols.

  Args:
  * model_helper: model helper with definitions of model & dataset
  """
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    self.sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = model_helper.build_dataset_train()
      images, __ = iterator.get_next()

    # model definition - full-precision network
    with tf.variable_scope(self.model_scope_full):
      __ = model_helper.forward_eval(images)  # DO NOT USE forward_train() HERE!!!
      self.vars_full = get_vars_by_scope(self.model_scope_full)
def __build_train(self):
  """Build the training graph."""
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
    self.sess_train = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()
      images.set_shape((FLAGS.batch_size, images.shape[1],
                        images.shape[2], images.shape[3]))

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(self.sess_train, images)

    # model definition
    with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
      # forward pass
      logits = self.forward_train(images)

      self.weights = [v for v in self.trainable_vars
                      if 'kernel' in v.name or 'weight' in v.name]
      if not FLAGS.uql_quantize_all_layers:
        self.weights = self.weights[1:-1]  # skip the first & last layers
      self.statistics['num_weights'] = \
        [tf.reshape(v, [-1]).shape[0].value for v in self.weights]

      self.__quantize_train_graph()

      # loss & accuracy
      loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
      if self.dataset_name == 'cifar_10':
        acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
      elif self.dataset_name == 'ilsvrc_12':
        acc_top1, acc_top5 = metrics['acc_top1'], metrics['acc_top5']
      else:
        raise ValueError('unrecognized dataset name: ' + self.dataset_name)

      model_loss = loss
      if FLAGS.enbl_dst:
        dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
        loss += dst_loss
        tf.summary.scalar('dst_loss', dst_loss)
      tf.summary.scalar('model_loss', model_loss)
      tf.summary.scalar('loss', loss)
      tf.summary.scalar('acc_top1', acc_top1)
      tf.summary.scalar('acc_top5', acc_top5)

      self.saver_train = tf.train.Saver(self.vars)
      self.ft_step = tf.get_variable(
        'finetune_step', shape=[], dtype=tf.int32, trainable=False)

      # optimizer & gradients
      init_lr, bnds, decay_rates, self.finetune_steps = \
        setup_bnds_decay_rates(self.model_name, self.dataset_name)
      lrn_rate = tf.train.piecewise_constant(
        self.ft_step, [i for i in bnds],
        [init_lr * decay_rate for decay_rate in decay_rates])
      optimizer = tf.train.AdamOptimizer(learning_rate=lrn_rate)
      if FLAGS.enbl_multi_gpu:
        optimizer = mgw.DistributedOptimizer(optimizer)
      grads = optimizer.compute_gradients(loss, self.trainable_vars)

      # write the graph to the summary writer
      self.sm_writer.add_graph(self.sess_train.graph)

      # TF operations
      with tf.control_dependencies(self.update_ops):
        self.ops['train'] = optimizer.apply_gradients(grads, global_step=self.ft_step)
      self.ops['summary'] = tf.summary.merge_all()
      if FLAGS.enbl_dst:
        self.ops['log'] = [lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5]
      else:
        self.ops['log'] = [lrn_rate, model_loss, loss, acc_top1, acc_top5]
      self.ops['reset_ft_step'] = tf.assign(self.ft_step, tf.constant(0, dtype=tf.int32))
      self.ops['init'] = tf.global_variables_initializer()
      self.ops['bcast'] = mgw.broadcast_global_variables(0) if FLAGS.enbl_multi_gpu else None
      self.saver_quant = tf.train.Saver(self.vars)
def export_tflite_model(input_coll, output_coll, images_shape, images_name):
  """Export a *.tflite model from checkpoint files.

  Args:
  * input_coll: input collection's name
  * output_coll: output collection's name
  * images_shape: input images' shape
  * images_name: input images' tensor name

  Returns:
  * unquant_node_name: unquantized activation node name
    (None if the conversion succeeds, i.e. no unquantized node remains)
  """
  # remove previously generated *.pb & *.tflite models
  model_dir = os.path.dirname(FLAGS.uqtf_save_path_probe_eval)
  idx_worker = mgw.local_rank() if FLAGS.enbl_multi_gpu else 0
  pb_path = os.path.join(model_dir, 'model_%d.pb' % idx_worker)
  tflite_path = os.path.join(model_dir, 'model_%d.tflite' % idx_worker)
  if os.path.exists(pb_path):
    os.remove(pb_path)
  if os.path.exists(tflite_path):
    os.remove(tflite_path)

  # convert checkpoint files to a *.pb model
  images_name_ph = 'images'
  with tf.Graph().as_default() as graph:
    # create a TensorFlow session
    sess = create_session()

    # restore the graph with inputs replaced
    ckpt_path = tf.train.latest_checkpoint(model_dir)
    meta_path = ckpt_path + '.meta'
    images = tf.placeholder(tf.float32, shape=images_shape, name=images_name_ph)
    saver = tf.train.import_meta_graph(meta_path, input_map={images_name: images})
    saver.restore(sess, ckpt_path)

    # obtain input & output nodes
    net_inputs = tf.get_collection(input_coll)
    net_logits = tf.get_collection(output_coll)[0]
    net_outputs = [tf.nn.softmax(net_logits)]
    for node in net_inputs:
      tf.logging.info('inputs: {} / {}'.format(node.name, node.shape))
    for node in net_outputs:
      tf.logging.info('outputs: {} / {}'.format(node.name, node.shape))

    # write the original graph to a *.pb file
    graph_def = tf.graph_util.convert_variables_to_constants(
      sess, graph.as_graph_def(), [node.name.replace(':0', '') for node in net_outputs])
    tf.train.write_graph(graph_def, model_dir, os.path.basename(pb_path), as_text=False)
    assert os.path.exists(pb_path), 'failed to generate a *.pb model'

  # convert the *.pb model to a *.tflite model and detect the unquantized activation node (if any)
  tf.logging.info(pb_path + ' -> ' + tflite_path)
  converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
    pb_path, [images_name_ph], [node.name.replace(':0', '') for node in net_outputs])
  converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
  converter.quantized_input_stats = {images_name_ph: (0., 1.)}
  unquant_node_name = None
  try:
    tflite_model = converter.convert()
    with open(tflite_path, 'wb') as o_file:
      o_file.write(tflite_model)
  except Exception as err:
    # parse the unquantized node's name from the conversion error message
    err_msg = str(err)
    flag_str = 'tensorflow/contrib/lite/toco/tooling_util.cc:1634]'
    for sub_line in err_msg.split('\\n'):
      if flag_str in sub_line:
        sub_strs = sub_line.replace(',', ' ').split()
        unquant_node_name = sub_strs[sub_strs.index(flag_str) + 2] + ':0'
        break
    assert unquant_node_name is not None, 'unable to locate the unquantized node'
  return unquant_node_name
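# Usage sketch: probing for unquantized activations after building the eval
# graph. The collection names match those registered by __build_eval(); the shape
# and tensor name below are illustrative values, not taken from the original
# sources.
def _demo_export_tflite():
  unquant_node_name = export_tflite_model(
    input_coll='images_final', output_coll='logits_final',
    images_shape=[1, 224, 224, 3], images_name='images:0')
  if unquant_node_name is not None:
    tf.logging.info('unquantized activation node: ' + unquant_node_name)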
def __build(self, is_train):  # pylint: disable=too-many-locals
  """Build the training / evaluation graph.

  Args:
  * is_train: whether to create the training graph
  """
  with tf.Graph().as_default():
    # TensorFlow session
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train() if is_train else self.build_dataset_eval()
      images, labels = iterator.get_next()
      tf.add_to_collection('images_final', images)

    # model definition - distilled model
    if self.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(sess, images)

    # model definition - primary model
    with tf.variable_scope(self.model_scope):
      # forward pass
      logits = self.forward_train(images) if is_train else self.forward_eval(images)
      tf.add_to_collection('logits_final', logits)

      # loss & extra evaluation metrics
      loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
      if self.enbl_dst:
        loss += self.helper_dst.calc_loss(logits, logits_dst)
      tf.summary.scalar('loss', loss)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # optimizer & gradients
      if is_train:
        self.global_step = tf.train.get_or_create_global_step()
        lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)
        optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        if FLAGS.enbl_multi_gpu:
          optimizer = mgw.DistributedOptimizer(optimizer)
        grads = optimizer.compute_gradients(loss, self.trainable_vars)

    # TF operations & model saver
    if is_train:
      self.sess_train = sess
      with tf.control_dependencies(self.update_ops):
        self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)
      self.summary_op = tf.summary.merge_all()
      self.log_op = [lrn_rate, loss] + list(metrics.values())
      self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
      self.init_op = tf.variables_initializer(self.vars)
      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
      self.saver_train = tf.train.Saver(self.vars)
    else:
      self.sess_eval = sess
      self.eval_op = [loss] + list(metrics.values())
      self.eval_op_names = ['loss'] + list(metrics.keys())
      self.outputs_eval = logits
      self.saver_eval = tf.train.Saver(self.vars)
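# Usage sketch: eval_op and eval_op_names are parallel lists, so results can be
# averaged and reported generically. Assumes numpy is imported as np; 'learner'
# is a hypothetical instance exposing the attributes built above.
def _demo_report_eval(learner, nb_iters=10):
  results = [learner.sess_eval.run(learner.eval_op) for __ in range(nb_iters)]
  for name, value in zip(learner.eval_op_names, np.mean(np.array(results), axis=0)):
    tf.logging.info('{}: {}'.format(name, value))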
def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
  """Build the training graph."""
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(sess, images)

    # model definition - full model
    with tf.variable_scope(self.model_scope_full):
      __ = self.forward_train(images)
      self.vars_full = get_vars_by_scope(self.model_scope_full)
      self.saver_full = tf.train.Saver(self.vars_full['all'])
      self.save_path_full = FLAGS.save_path

    # model definition - channel-pruned model
    with tf.variable_scope(self.model_scope_prnd):
      logits_prnd = self.forward_train(images)
      self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
      self.maskable_var_names = [var.name for var in self.vars_prnd['maskable']]
      self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])

      # loss & extra evaluation metrics
      loss_bsc, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
      if not FLAGS.enbl_dst:
        loss_fnl = loss_bsc
      else:
        loss_fnl = loss_bsc + self.helper_dst.calc_loss(logits_prnd, logits_dst)
      tf.summary.scalar('loss_bsc', loss_bsc)
      tf.summary.scalar('loss_fnl', loss_fnl)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # learning rate schedule
      self.global_step = tf.train.get_or_create_global_step()
      lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)

      # overall pruning ratios of trainable & maskable variables
      pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
      pr_maskable = calc_prune_ratio(self.vars_prnd['maskable'])
      tf.summary.scalar('pr_trainable', pr_trainable)
      tf.summary.scalar('pr_maskable', pr_maskable)

      # create masks and corresponding operations for channel pruning
      self.masks = []
      self.mask_deltas = []
      self.mask_init_ops = []
      self.mask_updt_ops = []
      self.prune_ops = []
      for idx, var in enumerate(self.vars_prnd['maskable']):
        name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
        self.masks += [tf.get_variable(name, initializer=tf.ones(var.shape), trainable=False)]
        name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask_delta')
        self.mask_deltas += [tf.placeholder(tf.float32, shape=var.shape, name=name)]
        self.mask_init_ops += [self.masks[idx].assign(tf.zeros(var.shape))]
        self.mask_updt_ops += [self.masks[idx].assign_add(self.mask_deltas[idx])]
        self.prune_ops += [var.assign(var * self.masks[idx])]

      # build extra losses for regression & discrimination
      self.reg_losses, self.dis_losses, self.idxs_layer_to_block = \
        self.__build_extra_losses(labels)
      self.dis_losses += [loss_bsc]  # append discrimination-aware loss for the last block
      self.nb_layers = len(self.reg_losses)
      self.nb_blocks = len(self.dis_losses)
      for idx, reg_loss in enumerate(self.reg_losses):
        tf.summary.scalar('reg_loss_%d' % idx, reg_loss)
      for idx, dis_loss in enumerate(self.dis_losses):
        tf.summary.scalar('dis_loss_%d' % idx, dis_loss)

      # obtain the full list of trainable variables & update operations
      self.vars_all = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_prnd)
      self.trainable_vars_all = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.model_scope_prnd)
      self.update_ops_all = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)

      # TF operations for initializing the channel-pruned model
      init_ops = []
      with tf.control_dependencies([tf.variables_initializer(self.vars_all)]):
        for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all']):
          init_ops += [var_prnd.assign(var_full)]
      self.init_op = tf.group(init_ops)

      # TF operations for layer-wise, block-wise, and whole-network fine-tuning
      self.layer_train_ops, self.layer_init_opt_ops, self.grad_norms = self.__build_layer_ops()
      self.block_train_ops, self.block_init_opt_ops = self.__build_block_ops()
      self.train_op, self.init_opt_op = self.__build_network_ops(loss_fnl, lrn_rate)

    # TF operations for logging & summarizing
    self.sess_train = sess
    self.summary_op = tf.summary.merge_all()
    self.log_op = [lrn_rate, loss_fnl, pr_trainable, pr_maskable] + list(metrics.values())
    self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
    if FLAGS.enbl_multi_gpu:
      self.bcast_op = mgw.broadcast_global_variables(0)
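# Usage sketch: the per-layer mask ops built above support an incremental
# channel-selection loop - reset a mask to zeros, admit channels via a delta fed
# through the placeholder, then apply the mask to the kernel. 'learner', 'idx',
# and 'delta' are hypothetical stand-ins.
def _demo_mask_cycle(learner, idx, delta):
  learner.sess_train.run(learner.mask_init_ops[idx])
  learner.sess_train.run(learner.mask_updt_ops[idx],
                         feed_dict={learner.mask_deltas[idx]: delta})
  learner.sess_train.run(learner.prune_ops[idx])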
def __build_pruned_train_model(self, path=None, finetune=False):  # pylint: disable=too-many-locals
  """Build a training model from the pruned model."""
  if path is None:
    path = FLAGS.save_path

  with tf.Graph().as_default():
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
    self.sess_train = tf.Session(config=config)
    self.saver_train = tf.train.import_meta_graph(path + '.meta')
    self.saver_train.restore(self.sess_train, path)
    logits = tf.get_collection('logits')[0]
    train_images = tf.get_collection('train_images')[0]
    train_labels = tf.get_collection('train_labels')[0]
    mem_images = tf.get_collection('mem_images')[0]
    mem_labels = tf.get_collection('mem_labels')[0]

    # re-route the memory-backed input tensors into the training graph
    self.sess_train.close()
    graph_editor.reroute_ts(train_images, mem_images)
    graph_editor.reroute_ts(train_labels, mem_labels)
    self.sess_train = tf.Session(config=config)
    self.saver_train.restore(self.sess_train, path)

    trainable_vars = self.trainable_vars
    loss, accuracy = self.calc_loss(train_labels, logits, trainable_vars)
    self.accuracy_keys = list(accuracy.keys())
    if FLAGS.enbl_dst:
      logits_dst = self.learner_dst.calc_logits(self.sess_train, train_images)
      loss += self.learner_dst.calc_loss(logits, logits_dst)
    tf.summary.scalar('loss', loss)
    for key in accuracy.keys():
      tf.summary.scalar(key, accuracy[key])
    self.summary_op = tf.summary.merge_all()

    # optimizer & gradients
    global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, trainable=False)
    self.global_step = global_step
    lrn_rate, self.nb_iters_train = setup_lrn_rate(
      self.global_step, self.model_name, self.dataset_name)
    if finetune and not FLAGS.cp_retrain:
      mom_optimizer = tf.train.AdamOptimizer(FLAGS.cp_lrn_rate_ft)
      self.log_op = [tf.constant(FLAGS.cp_lrn_rate_ft), loss, list(accuracy.values())]
    else:
      mom_optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
      self.log_op = [lrn_rate, loss, list(accuracy.values())]
    if FLAGS.enbl_multi_gpu:
      optimizer = mgw.DistributedOptimizer(mom_optimizer)
    else:
      optimizer = mom_optimizer
    grads_origin = optimizer.compute_gradients(loss, trainable_vars)
    grads_pruned, masks = self.__calc_grads_pruned(grads_origin)

    # TF operations
    with tf.control_dependencies(self.update_ops):
      self.train_op = optimizer.apply_gradients(grads_pruned, global_step=global_step)
    self.sm_writer.add_graph(tf.get_default_graph())
    self.train_init_op = \
      tf.variables_initializer(mom_optimizer.variables() + [global_step] + masks)
    if FLAGS.enbl_multi_gpu:
      self.bcast_op = mgw.broadcast_global_variables(0)
def __build_train(self):  # pylint: disable=too-many-locals
  """Build the training graph."""
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    if FLAGS.enbl_multi_gpu:
      config.gpu_options.visible_device_list = str(mgw.local_rank())  # pylint: disable=no-member
    else:
      config.gpu_options.visible_device_list = '0'  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(sess, images)

    # model definition - weight-sparsified model
    with tf.variable_scope(self.model_scope):
      # loss & extra evaluation metrics
      logits = self.forward_train(images)
      self.maskable_var_names = [var.name for var in self.maskable_vars]
      loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
      if FLAGS.enbl_dst:
        loss += self.helper_dst.calc_loss(logits, logits_dst)
      tf.summary.scalar('loss', loss)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # learning rate schedule
      self.global_step = tf.train.get_or_create_global_step()
      lrn_rate, self.nb_iters_train = setup_lrn_rate(
        self.global_step, self.model_name, self.dataset_name)

      # overall pruning ratios of trainable & maskable variables
      pr_trainable = calc_prune_ratio(self.trainable_vars)
      pr_maskable = calc_prune_ratio(self.maskable_vars)
      tf.summary.scalar('pr_trainable', pr_trainable)
      tf.summary.scalar('pr_maskable', pr_maskable)

      # build masks and corresponding operations for weight sparsification
      self.masks, self.prune_op = self.__build_masks()

      # optimizer & gradients
      optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
      if not FLAGS.enbl_multi_gpu:
        optimizer = optimizer_base
      else:
        optimizer = mgw.DistributedOptimizer(optimizer_base)
      grads_origin = optimizer.compute_gradients(loss, self.trainable_vars)
      grads_pruned = self.__calc_grads_pruned(grads_origin)

      # TF operations & model saver
      self.sess_train = sess
      with tf.control_dependencies(self.update_ops):
        self.train_op = optimizer.apply_gradients(grads_pruned, global_step=self.global_step)
      self.summary_op = tf.summary.merge_all()
      self.log_op = [lrn_rate, loss, pr_trainable, pr_maskable] + list(metrics.values())
      self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
      self.init_op = tf.variables_initializer(self.vars)
      self.init_opt_op = tf.variables_initializer(optimizer_base.variables())
      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
      self.saver_train = tf.train.Saver(self.vars)
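# A minimal sketch (an assumption, not the verified implementation) of what a
# __calc_grads_pruned() helper does: multiply each maskable variable's gradient
# by its binary mask so pruned weights receive no updates and stay at zero.
# 'masks' is assumed to be index-aligned with 'maskable_var_names'.
def _sketch_calc_grads_pruned(grads_origin, maskable_var_names, masks):
  grads_pruned = []
  for grad, var in grads_origin:
    if var.name not in maskable_var_names:
      grads_pruned += [(grad, var)]
    else:
      idx = maskable_var_names.index(var.name)
      grads_pruned += [(grad * masks[idx], var)]
  return grads_pruned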
def __build_train(self):
  """Build the training graph."""
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
    self.sess_train = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()
      images.set_shape((FLAGS.batch_size, images.shape[1],
                        images.shape[2], images.shape[3]))

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(self.sess_train, images)

    # model definition
    with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
      # forward pass
      logits = self.forward_train(images)
      self.saver_train = tf.train.Saver(self.vars)

      self.weights = [v for v in self.trainable_vars
                      if 'kernel' in v.name or 'weight' in v.name]
      if not FLAGS.nuql_quantize_all_layers:
        self.weights = self.weights[1:-1]  # skip the first & last layers
      self.statistics['num_weights'] = \
        [tf.reshape(v, [-1]).shape[0].value for v in self.weights]

      self.__quantize_train_graph()

      # loss & accuracy
      # NOTE: strictly speaking, clusters should not be included for regularization
      loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
      if self.dataset_name == 'cifar_10':
        acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
      elif self.dataset_name == 'ilsvrc_12':
        acc_top1, acc_top5 = metrics['acc_top1'], metrics['acc_top5']
      else:
        raise ValueError('unrecognized dataset name: ' + self.dataset_name)

      model_loss = loss
      if FLAGS.enbl_dst:
        dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
        loss += dst_loss
        tf.summary.scalar('dst_loss', dst_loss)
      tf.summary.scalar('model_loss', model_loss)
      tf.summary.scalar('loss', loss)
      tf.summary.scalar('acc_top1', acc_top1)
      tf.summary.scalar('acc_top5', acc_top5)

      self.ft_step = tf.get_variable(
        'finetune_step', shape=[], dtype=tf.int32, trainable=False)

      # optimizer & gradients
      init_lr, bnds, decay_rates, self.finetune_steps = \
        setup_bnds_decay_rates(self.model_name, self.dataset_name)
      lrn_rate = tf.train.piecewise_constant(
        self.ft_step, [i for i in bnds],
        [init_lr * decay_rate for decay_rate in decay_rates])
      optimizer = tf.train.AdamOptimizer(lrn_rate)
      if FLAGS.enbl_multi_gpu:
        optimizer = mgw.DistributedOptimizer(optimizer)

      # acquire non-cluster variables and clusters
      clusters = [v for v in self.trainable_vars if 'clusters' in v.name]
      rest_trainable_vars = [v for v in self.trainable_vars if v not in clusters]

      # determine the list of variables to optimize
      if FLAGS.nuql_opt_mode in ['cluster', 'both']:
        if FLAGS.nuql_opt_mode == 'both':
          optimizable_vars = self.trainable_vars
        else:
          optimizable_vars = clusters
        if FLAGS.nuql_enbl_rl_agent:
          optimizer_fintune = tf.train.GradientDescentOptimizer(lrn_rate)
          if FLAGS.enbl_multi_gpu:
            optimizer_fintune = mgw.DistributedOptimizer(optimizer_fintune)
          grads_fintune = optimizer_fintune.compute_gradients(
            loss, var_list=optimizable_vars)
      elif FLAGS.nuql_opt_mode == 'weights':
        optimizable_vars = rest_trainable_vars
      else:
        raise ValueError('unknown optimization mode: ' + FLAGS.nuql_opt_mode)
      grads = optimizer.compute_gradients(loss, var_list=optimizable_vars)

      # write the graph to the summary writer
      self.sm_writer.add_graph(self.sess_train.graph)

      # define the ops
      with tf.control_dependencies(self.update_ops):
        self.ops['train'] = optimizer.apply_gradients(grads, global_step=self.ft_step)
      if FLAGS.nuql_opt_mode in ['both', 'cluster'] and FLAGS.nuql_enbl_rl_agent:
        self.ops['rl_fintune'] = optimizer_fintune.apply_gradients(
          grads_fintune, global_step=self.ft_step)
      else:
        self.ops['rl_fintune'] = self.ops['train']
      self.ops['summary'] = tf.summary.merge_all()
      if FLAGS.enbl_dst:
        self.ops['log'] = [lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5]
      else:
        self.ops['log'] = [lrn_rate, model_loss, loss, acc_top1, acc_top5]

      # NOTE: run 'non_cluster_init' first, then 'cluster_init';
      # pay attention to beta1_power & beta2_power of the Adam optimizer
      cluster_global_vars = [v for v in self.vars if 'clusters' in v.name]
      non_cluster_global_vars = [v for v in tf.global_variables()
                                 if v not in cluster_global_vars]
      self.ops['non_cluster_init'] = tf.variables_initializer(var_list=non_cluster_global_vars)
      self.ops['cluster_init'] = tf.variables_initializer(var_list=cluster_global_vars)
      self.ops['bcast'] = mgw.broadcast_global_variables(0) if FLAGS.enbl_multi_gpu else None
      self.ops['reset_ft_step'] = tf.assign(self.ft_step, tf.constant(0, dtype=tf.int32))
      self.saver_quant = tf.train.Saver(self.vars)
def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
  """Build the training graph."""
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=no-member
    config.gpu_options.visible_device_list = \
      str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(sess, images)

    # model definition - full model
    with tf.variable_scope(self.model_scope_full):
      __ = self.forward_train(images)
      self.vars_full = get_vars_by_scope(self.model_scope_full)
      self.saver_full = tf.train.Saver(self.vars_full['all'])
      self.save_path_full = FLAGS.save_path

    # model definition - channel-pruned model
    with tf.variable_scope(self.model_scope_prnd):
      logits_prnd = self.forward_train(images)
      self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
      self.conv_krnl_var_names = [var.name for var in self.vars_prnd['conv_krnl']]
      self.global_step = tf.train.get_or_create_global_step()
      self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'] + [self.global_step])

      # loss & extra evaluation metrics
      loss, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
      if FLAGS.enbl_dst:
        loss += self.helper_dst.calc_loss(logits_prnd, logits_dst)
      tf.summary.scalar('loss', loss)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # learning rate schedule
      lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)

      # calculate pruning ratios
      pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
      pr_conv_krnl = calc_prune_ratio(self.vars_prnd['conv_krnl'])
      tf.summary.scalar('pr_trainable', pr_trainable)
      tf.summary.scalar('pr_conv_krnl', pr_conv_krnl)

      # create masks and corresponding operations for channel pruning
      self.masks = []
      mask_updt_ops = []  # update each mask based on its convolutional kernel's values
      for var in self.vars_prnd['conv_krnl']:
        tf.logging.info('creating a pruning mask for {} of size {}'.format(var.name, var.shape))
        mask_name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
        mask_shape = [1, 1, var.shape[2], 1]  # 1 x 1 x c_in x 1
        mask = tf.get_variable(mask_name, initializer=tf.ones(mask_shape), trainable=False)
        var_norm = tf.reduce_sum(tf.square(var), axis=[0, 1, 3], keepdims=True)
        self.masks += [mask]
        mask_updt_ops += [mask.assign(tf.cast(var_norm > 0.0, tf.float32))]
      self.mask_updt_op = tf.group(mask_updt_ops)

      # build operations for channel selection
      self.__build_chn_select_ops()

      # optimizer & gradients
      optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
      if not FLAGS.enbl_multi_gpu:
        optimizer = optimizer_base
      else:
        optimizer = mgw.DistributedOptimizer(optimizer_base)
      grads_origin = optimizer.compute_gradients(loss, self.vars_prnd['trainable'])
      grads_pruned = self.__calc_grads_pruned(grads_origin)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)
      with tf.control_dependencies(update_ops):
        self.train_op = optimizer.apply_gradients(grads_pruned, global_step=self.global_step)

      # TF operations for initializing the channel-pruned model
      init_ops = []
      for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all']):
        init_ops += [var_prnd.assign(var_full)]
      init_ops += [self.global_step.initializer]  # initialize the global step
      init_ops += [tf.variables_initializer(optimizer_base.variables())]
      self.init_op = tf.group(init_ops)

    # TF operations for logging & summarizing
    self.sess_train = sess
    self.summary_op = tf.summary.merge_all()
    self.log_op = [lrn_rate, loss, pr_trainable, pr_conv_krnl] + list(metrics.values())
    self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_krn'] + list(metrics.keys())
    if FLAGS.enbl_multi_gpu:
      self.bcast_op = mgw.broadcast_global_variables(0)
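# A small numpy check (illustrative only) of the mask shape chosen above: a
# 1 x 1 x c_in x 1 mask broadcasts over a k x k x c_in x c_out kernel and zeroes
# entire input channels, which is exactly the channel pruning semantics. Assumes
# numpy is imported as np.
def _demo_channel_mask_shape():
  kernel = np.ones((3, 3, 4, 8), dtype=np.float32)  # k x k x c_in x c_out
  mask = np.array([1., 0., 1., 0.], dtype=np.float32).reshape((1, 1, 4, 1))
  pruned = kernel * mask
  assert pruned[:, :, 1, :].sum() == 0.0  # input channel 1 fully removed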
def __init__(self, dataset_name, weights, statistics, bit_placeholders, ops,
             layerwise_tune_list, sess_train, sess_eval, saver_train, saver_eval,
             barrier_fn):
  """By passing in the learner's ops, we do not need to re-build the graph
    for training and testing.

  Args:
  * dataset_name: a string that indicates which dataset to use
  * weights: a list of Tensors, the weights of the network to be quantized
  * statistics: a dict recording the number of weights, activations, etc.
  * bit_placeholders: a dict of placeholder Tensors, the input of bits
  * ops: a dict of ops, including train_op, eval_op, etc.
  * layerwise_tune_list: a tuple; [0] records the layerwise ops and
    [1] records the layerwise l2-norms
  * sess_train: session for the training graph
  * sess_eval: session for the evaluation graph
  * saver_train: TensorFlow Saver for the training graph
  * saver_eval: TensorFlow Saver for the evaluation graph
  * barrier_fn: a function that implements a barrier
  """
  self.dataset_name = dataset_name
  self.weights = weights
  self.statistics = statistics
  self.bit_placeholders = bit_placeholders
  self.ops = ops
  self.layerwise_tune_ops, self.layerwise_diff = \
    layerwise_tune_list[0], layerwise_tune_list[1]
  self.sess_train = sess_train
  self.sess_eval = sess_eval
  self.saver_train = saver_train
  self.saver_eval = saver_eval
  self.auto_barrier = barrier_fn

  self.total_num_weights = sum(self.statistics['num_weights'])
  self.total_bits = self.total_num_weights * FLAGS.uql_equivalent_bits
  self.w_rl_helper = RLHelper(self.sess_train, self.total_bits,
                              self.statistics['num_weights'], self.weights,
                              random_layers=FLAGS.uql_enbl_random_layers)
  self.mgw_size = int(mgw.size()) if FLAGS.enbl_multi_gpu else 1
  self.tune_global_steps = int(FLAGS.uql_tune_global_steps / self.mgw_size)
  self.tune_global_disp_steps = int(FLAGS.uql_tune_disp_steps / self.mgw_size)

  # build the RL training graph
  with tf.Graph().as_default():
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
    self.sess_rl = tf.Session(config=config)

    # train an RL agent through multiple roll-outs
    self.s_dims = self.w_rl_helper.s_dims
    self.a_dims = 1
    buff_size = len(self.weights) * int(FLAGS.uql_nb_rlouts // 4)
    self.agent = DdpgAgent(self.sess_rl, self.s_dims, self.a_dims,
                           FLAGS.uql_nb_rlouts, buff_size,
                           a_min=0., a_max=FLAGS.uql_w_bit_max - FLAGS.uql_w_bit_min)
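# Illustrative arithmetic (hypothetical numbers, not from the original sources)
# for the bit budget computed above: with uql_equivalent_bits = 4 and 1M weights,
# the RL agent must pick per-layer bit-widths whose weighted sum stays within the
# 4M-bit budget.
def _demo_bit_budget():
  num_weights = [250000, 500000, 250000]  # per-layer weight counts
  total_bits = sum(num_weights) * 4       # FLAGS.uql_equivalent_bits = 4
  alloc = [2, 4, 4]                       # candidate per-layer bit-widths
  assert sum(n * b for n, b in zip(num_weights, alloc)) <= total_bits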
def __build_train(self, model_helper):  # pylint: disable=too-many-locals
  """Build the training graph for the 'optimal' protocol.

  Args:
  * model_helper: model helper with definitions of model & dataset
  """
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator, __ = model_helper.build_dataset_train(enbl_trn_val_split=True)
      images, labels = iterator.get_next()

    # model definition - full-precision network
    with tf.variable_scope(self.model_scope_full):
      logits = model_helper.forward_eval(images)  # DO NOT USE forward_train() HERE!!!
      self.vars_full = get_vars_by_scope(self.model_scope_full)
      self.saver_full = tf.train.Saver(self.vars_full['all'])
      self.save_path_full = FLAGS.save_path

    # model definition - weight sparsified network
    with tf.variable_scope(self.model_scope_prnd):
      # forward pass & variables' saver
      logits = model_helper.forward_eval(images)  # DO NOT USE forward_train() HERE!!!
      self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
      self.maskable_var_names = [var.name for var in self.vars_prnd['maskable']]
      loss, __ = model_helper.calc_loss(labels, logits, self.vars_prnd['trainable'])
      self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])
      self.save_path_prnd = FLAGS.save_path.replace('models', 'models_pruned')

    # build masks for variable pruning
    self.masks, self.pr_all, self.pr_assign_op = self.__build_masks()

    # create operations for initializing the weight sparsified network
    init_ops = []
    for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all']):
      if var_full not in self.vars_full['maskable']:
        init_ops += [var_prnd.assign(var_full)]
      else:
        idx = self.vars_full['maskable'].index(var_full)
        init_ops += [var_prnd.assign(var_full * self.masks[idx])]
    self.init_op = tf.group(init_ops)

    # build operations for layerwise regression & network fine-tuning
    self.rg_init_op, self.rg_train_ops = self.__build_layer_rg_ops()
    self.ft_init_op, self.ft_train_op = self.__build_network_ft_ops(loss)
    if FLAGS.enbl_multi_gpu:
      self.bcast_op = mgw.broadcast_global_variables(0)

    # create RL helper & agent on the primary worker
    if is_primary_worker('global'):
      self.rl_helper, self.agent = self.__build_rl_helper_n_agent(sess)

    # TF operations
    self.sess_train = sess
def build_graph(self, is_train):
  """Build the training / evaluation graph.

  Args:
  * is_train: whether to create the training graph
  """
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.dataset_train.build() if is_train else self.dataset_eval.build()
      images, labels = iterator.get_next()
      if not isinstance(images, dict):
        tf.add_to_collection('images_final', images)
      else:
        tf.add_to_collection('images_final', images['image'])

    # model definition - primary model
    with tf.variable_scope(self.model_scope):
      # forward pass
      logits = self.hrnet.forward_train(images) if is_train else self.hrnet.forward_eval(images)
      if not isinstance(logits, dict):
        tf.add_to_collection('logits_final', logits)
      else:
        for value in logits.values():
          tf.add_to_collection('logits_final', value)

      # loss & extra evaluation metrics
      loss, metrics = self.hrnet.calc_loss(labels, logits, self.trainable_vars)
      tf.summary.scalar('loss', loss)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # optimizer & gradients
      if is_train:
        self.global_step = tf.train.get_or_create_global_step()
        lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)
        optimizer = tf.train.MomentumOptimizer(lrn_rate, self.hrnet.cfg['COMMON']['momentum'])
        if FLAGS.enbl_multi_gpu:
          optimizer = mgw.DistributedOptimizer(optimizer)
        grads = optimizer.compute_gradients(loss, self.trainable_vars)

    # TF operations & model saver
    if is_train:
      self.sess_train = sess
      with tf.control_dependencies(self.update_ops):
        self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)
      self.summary_op = tf.summary.merge_all()
      self.sm_writer = tf.summary.FileWriter(logdir=self.log_path)
      self.log_op = [lrn_rate, loss] + list(metrics.values())
      self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
      self.init_op = tf.variables_initializer(self.vars)
      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
      self.saver_train = tf.train.Saver(self.vars)
    else:
      self.sess_eval = sess
      self.eval_op = [loss] + list(metrics.values())
      self.eval_op_names = ['loss'] + list(metrics.keys())
      self.saver_eval = tf.train.Saver(self.vars)
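# Usage sketch: build_graph() is called once per phase; each call owns a private
# tf.Graph and tf.Session, so the training and evaluation graphs never collide.
# 'builder' is a hypothetical instance of the surrounding class.
def _demo_build_both_graphs(builder):
  builder.build_graph(is_train=True)   # populates builder.sess_train / train_op
  builder.build_graph(is_train=False)  # populates builder.sess_eval / eval_op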
def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
  """Build the training graph."""
  with tf.Graph().as_default() as graph:
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    config.gpu_options.allow_growth = True  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()

    # model definition - uniform quantized model - part 1
    with tf.variable_scope(self.model_scope_quan):
      logits_quan = self.forward_train(images)
      if not isinstance(logits_quan, dict):
        outputs = tf.nn.softmax(logits_quan)
      else:
        outputs = tf.nn.softmax(logits_quan['cls_pred'])
      tf.contrib.quantize.experimental_create_training_graph(
        weight_bits=FLAGS.uqtf_weight_bits,
        activation_bits=FLAGS.uqtf_activation_bits,
        quant_delay=FLAGS.uqtf_quant_delay,
        freeze_bn_delay=FLAGS.uqtf_freeze_bn_delay,
        scope=self.model_scope_quan)
      for node_name in self.unquant_node_names:
        insert_quant_op(graph, node_name, is_train=True)
      self.vars_quan = get_vars_by_scope(self.model_scope_quan)
      self.global_step = tf.train.get_or_create_global_step()
      self.saver_quan_train = tf.train.Saver(self.vars_quan['all'] + [self.global_step])

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(sess, images)

    # model definition - full model
    with tf.variable_scope(self.model_scope_full):
      __ = self.forward_train(images)
      self.vars_full = get_vars_by_scope(self.model_scope_full)
      self.saver_full = tf.train.Saver(self.vars_full['all'])
      self.save_path_full = FLAGS.save_path

    # model definition - uniform quantized model - part 2
    with tf.variable_scope(self.model_scope_quan):
      # loss & extra evaluation metrics
      loss_bsc, metrics = self.calc_loss(labels, logits_quan, self.vars_quan['trainable'])
      if not FLAGS.enbl_dst:
        loss_fnl = loss_bsc
      else:
        loss_fnl = loss_bsc + self.helper_dst.calc_loss(logits_quan, logits_dst)
      tf.summary.scalar('loss_bsc', loss_bsc)
      tf.summary.scalar('loss_fnl', loss_fnl)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # learning rate schedule
      lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)
      lrn_rate *= FLAGS.uqtf_lrn_rate_dcy  # decrease the learning rate by a constant factor

      # obtain the full list of trainable variables & update operations
      self.vars_all = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_quan)
      self.trainable_vars_all = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.model_scope_quan)
      self.update_ops_all = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_quan)

      # TF operations for initializing the uniform quantized model
      init_ops = []
      with tf.control_dependencies([tf.variables_initializer(self.vars_all)]):
        for var_full, var_quan in zip(self.vars_full['all'], self.vars_quan['all']):
          init_ops += [var_quan.assign(var_full)]
      init_ops += [self.global_step.initializer]
      self.init_op = tf.group(init_ops)

      # TF operations for fine-tuning
      optimizer_base = tf.train.AdamOptimizer(lrn_rate)
      if not FLAGS.enbl_multi_gpu:
        optimizer = optimizer_base
      else:
        optimizer = mgw.DistributedOptimizer(optimizer_base)
      grads = optimizer.compute_gradients(loss_fnl, self.trainable_vars_all)
      with tf.control_dependencies(self.update_ops_all):
        self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)
      self.init_opt_op = tf.variables_initializer(optimizer_base.variables())

    # TF operations for logging & summarizing
    self.sess_train = sess
    self.summary_op = tf.summary.merge_all()
    self.log_op = [lrn_rate, loss_fnl] + list(metrics.values())
    self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
    if FLAGS.enbl_multi_gpu:
      self.bcast_op = mgw.broadcast_global_variables(0)
def __build_prune(self):
  """Build the channel pruning graph."""
  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=no-member
    config.gpu_options.visible_device_list = \
      str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()
      if not isinstance(images, dict):
        images_ph = tf.placeholder(tf.float32, shape=images.shape, name='images_ph')
      else:
        images_ph = {}
        for key, value in images.items():
          images_ph[key] = tf.placeholder(value.dtype, shape=value.shape, name=(key + '_ph'))

    # restore a pre-trained model as the full model
    with tf.variable_scope(self.model_scope_full):
      __ = self.forward_train(images_ph)
      vars_full = get_vars_by_scope(self.model_scope_full)
      saver_full = tf.train.Saver(vars_full['all'])
      saver_full.restore(sess, tf.train.latest_checkpoint(os.path.dirname(FLAGS.save_path)))

    # restore a pre-trained model as the channel-pruned model
    with tf.variable_scope(self.model_scope_prnd):
      logits_prnd = self.forward_train(images_ph)
      vars_prnd = get_vars_by_scope(self.model_scope_prnd)
      global_step = tf.train.get_or_create_global_step()
      saver_prnd = tf.train.Saver(vars_prnd['all'] + [global_step])

      # loss & extra evaluation metrics
      loss, metrics = self.calc_loss(labels, logits_prnd, vars_prnd['trainable'])

      # calculate pruning ratios
      pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
      pr_conv_krnl = calc_prune_ratio(vars_prnd['conv_krnl'])

      # use the full model's weights to initialize the channel-pruned model
      init_ops = [global_step.initializer]
      for var_full, var_prnd in zip(vars_full['all'], vars_prnd['all']):
        init_ops += [var_prnd.assign(var_full)]
      self.init_op_prune = tf.group(init_ops)

      # build a list of Conv2D operations' information
      self.conv_info_list = self.__build_conv_info_list(vars_prnd['conv_krnl'])

      # build meta LASSO / least-squares optimization problems
      self.meta_lasso = self.__build_meta_lasso()
      self.meta_lstsq = self.__build_meta_lstsq()

    # TF operations for logging & summarizing
    self.sess_prune = sess
    self.images_prune = images
    self.images_prune_ph = images_ph
    self.saver_prune = saver_prnd
    self.pr_trn_prune = pr_trainable
    self.pr_krn_prune = pr_conv_krnl
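# A standalone numpy sketch (an assumption about the method, not the original
# meta-problem code) of the LASSO-then-least-squares recipe referenced above:
# input channels are first selected (e.g. by an L1-penalized fit), then the
# surviving channels are re-fitted by ordinary least squares to minimize the
# reconstruction error. Assumes numpy is imported as np; all names are
# hypothetical.
def _sketch_lstsq_refit(feat_maps, response, keep_mask):
  # feat_maps: (n_samples, c_in) per-channel responses; response: (n_samples,)
  feat_kept = feat_maps[:, keep_mask]  # keep only the selected input channels
  coef, _, _, _ = np.linalg.lstsq(feat_kept, response, rcond=None)
  return coef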