def __build_train(self):  # pylint: disable=too-many-locals
  """Build the training graph."""

  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    if FLAGS.enbl_multi_gpu:
      config.gpu_options.visible_device_list = str(mgw.local_rank())  # pylint: disable=no-member
    else:
      config.gpu_options.visible_device_list = '0'  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(sess, images)

    # model definition - weight-sparsified model
    with tf.variable_scope(self.model_scope):
      # loss & extra evaluation metrics
      logits = self.forward_train(images)
      self.maskable_var_names = [var.name for var in self.maskable_vars]
      loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
      if FLAGS.enbl_dst:
        loss += self.helper_dst.calc_loss(logits, logits_dst)
      tf.summary.scalar('loss', loss)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # learning rate schedule
      self.global_step = tf.train.get_or_create_global_step()
      lrn_rate, self.nb_iters_train = setup_lrn_rate(
        self.global_step, self.model_name, self.dataset_name)

      # overall pruning ratios of trainable & maskable variables
      pr_trainable = calc_prune_ratio(self.trainable_vars)
      pr_maskable = calc_prune_ratio(self.maskable_vars)
      tf.summary.scalar('pr_trainable', pr_trainable)
      tf.summary.scalar('pr_maskable', pr_maskable)

      # build masks and corresponding operations for weight sparsification
      self.masks, self.prune_op = self.__build_masks()

      # optimizer & gradients
      optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
      if not FLAGS.enbl_multi_gpu:
        optimizer = optimizer_base
      else:
        optimizer = mgw.DistributedOptimizer(optimizer_base)
      grads_origin = optimizer.compute_gradients(loss, self.trainable_vars)
      grads_pruned = self.__calc_grads_pruned(grads_origin)

    # TF operations & model saver
    self.sess_train = sess
    with tf.control_dependencies(self.update_ops):
      self.train_op = optimizer.apply_gradients(grads_pruned, global_step=self.global_step)
    self.summary_op = tf.summary.merge_all()
    self.log_op = [lrn_rate, loss, pr_trainable, pr_maskable] + list(metrics.values())
    self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
    self.init_op = tf.variables_initializer(self.vars)
    self.init_opt_op = tf.variables_initializer(optimizer_base.variables())
    if FLAGS.enbl_multi_gpu:
      self.bcast_op = mgw.broadcast_global_variables(0)
    self.saver_train = tf.train.Saver(self.vars)
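# NOTE: a minimal sketch of the `__calc_grads_pruned` helper called above, which
# is not defined in this section. It assumes `self.masks` is ordered to match
# `self.maskable_var_names` (as built above). Each maskable variable's gradient
# is multiplied by its binary mask, so weights that have been pruned to zero
# receive no further updates.
def __calc_grads_pruned(self, grads_origin):
  """Apply pruning masks to the original gradients (sketch)."""
  grads_pruned = []
  for grad, var in grads_origin:
    if var.name not in self.maskable_var_names:
      # non-maskable variables keep their original gradients
      grads_pruned += [(grad, var)]
    else:
      idx_mask = self.maskable_var_names.index(var.name)
      grads_pruned += [(grad * self.masks[idx_mask], var)]
  return grads_pruned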
def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
  """Build the training graph."""

  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(sess, images)

    # model definition - full model
    with tf.variable_scope(self.model_scope_full):
      __ = self.forward_train(images)
      self.vars_full = get_vars_by_scope(self.model_scope_full)
      self.saver_full = tf.train.Saver(self.vars_full['all'])
      self.save_path_full = FLAGS.save_path

    # model definition - channel-pruned model
    with tf.variable_scope(self.model_scope_prnd):
      logits_prnd = self.forward_train(images)
      self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
      self.maskable_var_names = [var.name for var in self.vars_prnd['maskable']]
      self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])

      # loss & extra evaluation metrics
      loss_bsc, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
      if not FLAGS.enbl_dst:
        loss_fnl = loss_bsc
      else:
        loss_fnl = loss_bsc + self.helper_dst.calc_loss(logits_prnd, logits_dst)
      tf.summary.scalar('loss_bsc', loss_bsc)
      tf.summary.scalar('loss_fnl', loss_fnl)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # learning rate schedule
      self.global_step = tf.train.get_or_create_global_step()
      lrn_rate, self.nb_iters_train = setup_lrn_rate(
        self.global_step, self.model_name, self.dataset_name)

      # overall pruning ratios of trainable & maskable variables
      pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
      pr_maskable = calc_prune_ratio(self.vars_prnd['maskable'])
      tf.summary.scalar('pr_trainable', pr_trainable)
      tf.summary.scalar('pr_maskable', pr_maskable)

      # create masks and corresponding operations for channel pruning
      self.masks = []
      self.mask_deltas = []
      self.mask_init_ops = []
      self.mask_updt_ops = []
      self.prune_ops = []
      for idx, var in enumerate(self.vars_prnd['maskable']):
        name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
        self.masks += [tf.get_variable(name, initializer=tf.ones(var.shape), trainable=False)]
        name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask_delta')
        self.mask_deltas += [tf.placeholder(tf.float32, shape=var.shape, name=name)]
        self.mask_init_ops += [self.masks[idx].assign(tf.zeros(var.shape))]
        self.mask_updt_ops += [self.masks[idx].assign_add(self.mask_deltas[idx])]
        self.prune_ops += [var.assign(var * self.masks[idx])]

      # build extra losses for regression & discrimination
      self.reg_losses, self.dis_losses, self.idxs_layer_to_block = \
        self.__build_extra_losses(labels)
      self.dis_losses += [loss_bsc]  # append discrimination-aware loss for the last block
      self.nb_layers = len(self.reg_losses)
      self.nb_blocks = len(self.dis_losses)
      for idx, reg_loss in enumerate(self.reg_losses):
        tf.summary.scalar('reg_loss_%d' % idx, reg_loss)
      for idx, dis_loss in enumerate(self.dis_losses):
        tf.summary.scalar('dis_loss_%d' % idx, dis_loss)

      # obtain the full list of trainable variables & update operations
      self.vars_all = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_prnd)
      self.trainable_vars_all = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.model_scope_prnd)
      self.update_ops_all = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)

    # TF operations for initializing the channel-pruned model
    init_ops = []
    with tf.control_dependencies([tf.variables_initializer(self.vars_all)]):
      for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all']):
        init_ops += [var_prnd.assign(var_full)]
    self.init_op = tf.group(init_ops)

    # TF operations for layer-wise, block-wise, and whole-network fine-tuning
    self.layer_train_ops, self.layer_init_opt_ops, self.grad_norms = self.__build_layer_ops()
    self.block_train_ops, self.block_init_opt_ops = self.__build_block_ops()
    self.train_op, self.init_opt_op = self.__build_network_ops(loss_fnl, lrn_rate)

    # TF operations for logging & summarizing
    self.sess_train = sess
    self.summary_op = tf.summary.merge_all()
    self.log_op = [lrn_rate, loss_fnl, pr_trainable, pr_maskable] + list(metrics.values())
    self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
    if FLAGS.enbl_multi_gpu:
      self.bcast_op = mgw.broadcast_global_variables(0)
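# NOTE: `get_vars_by_scope` is used throughout this section but not defined in
# it. A minimal sketch, under the assumption that it groups a scope's variables
# into a dict with 'all' / 'trainable' / 'maskable' keys, where 'maskable'
# holds the convolution kernels eligible for channel pruning; the key names
# match the call sites above, but the kernel-selection rule here is an
# assumption, not the repository's exact logic.
def get_vars_by_scope(scope):
  """Group variables in the given scope into all / trainable / maskable (sketch)."""
  vars_all = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
  vars_trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
  # assumption: maskable variables are the 4-D convolution kernels
  vars_maskable = [var for var in vars_trainable
                   if 'kernel' in var.name and len(var.shape) == 4]
  return {'all': vars_all, 'trainable': vars_trainable, 'maskable': vars_maskable}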
def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
  """Build the training graph."""

  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()
      images.set_shape((FLAGS.batch_size, images.shape[1], images.shape[2], images.shape[3]))

    # model definition - uniform quantized model - part 1
    with tf.variable_scope(self.model_scope_quan):
      logits_quan = self.forward_train(images)
      tf.contrib.quantize.experimental_create_training_graph(
        weight_bits=FLAGS.uqtf_weight_bits,
        activation_bits=FLAGS.uqtf_activation_bits,
        quant_delay=FLAGS.uqtf_quant_delay,
        freeze_bn_delay=FLAGS.uqtf_freeze_bn_delay,
        scope=self.model_scope_quan)
      self.vars_quan = get_vars_by_scope(self.model_scope_quan)
      self.saver_quan_train = tf.train.Saver(self.vars_quan['all'])

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(sess, images)

    # model definition - full model
    with tf.variable_scope(self.model_scope_full):
      __ = self.forward_train(images)
      self.vars_full = get_vars_by_scope(self.model_scope_full)
      self.saver_full = tf.train.Saver(self.vars_full['all'])
      self.save_path_full = FLAGS.save_path

    # model definition - uniform quantized model - part 2
    with tf.variable_scope(self.model_scope_quan):
      # loss & extra evaluation metrics
      loss_bsc, metrics = self.calc_loss(labels, logits_quan, self.vars_quan['trainable'])
      if not FLAGS.enbl_dst:
        loss_fnl = loss_bsc
      else:
        loss_fnl = loss_bsc + self.helper_dst.calc_loss(logits_quan, logits_dst)
      tf.summary.scalar('loss_bsc', loss_bsc)
      tf.summary.scalar('loss_fnl', loss_fnl)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # learning rate schedule
      self.global_step = tf.train.get_or_create_global_step()
      lrn_rate, self.nb_iters_train = setup_lrn_rate(
        self.global_step, self.model_name, self.dataset_name)
      lrn_rate *= FLAGS.uqtf_lrn_rate_dcy  # decrease the learning rate by a constant factor
      #if self.dataset_name == 'cifar_10':
      #  lrn_rate *= 1e-3
      #elif self.dataset_name == 'ilsvrc_12':
      #  lrn_rate *= 1e-4
      #else:
      #  raise ValueError('unrecognized dataset\'s name: ' + self.dataset_name)

      # obtain the full list of trainable variables & update operations
      self.vars_all = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_quan)
      self.trainable_vars_all = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.model_scope_quan)
      self.update_ops_all = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_quan)

    # TF operations for initializing the uniform quantized model
    init_ops = []
    with tf.control_dependencies([tf.variables_initializer(self.vars_all)]):
      for var_full, var_quan in zip(self.vars_full['all'], self.vars_quan['all']):
        init_ops += [var_quan.assign(var_full)]
    self.init_op = tf.group(init_ops)

    # TF operations for fine-tuning
    optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
    if not FLAGS.enbl_multi_gpu:
      optimizer = optimizer_base
    else:
      optimizer = mgw.DistributedOptimizer(optimizer_base)
    grads = optimizer.compute_gradients(loss_fnl, self.trainable_vars_all)
    with tf.control_dependencies(self.update_ops_all):
      self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)
    self.init_opt_op = tf.variables_initializer(optimizer_base.variables())

    # TF operations for logging & summarizing
    self.sess_train = sess
    self.summary_op = tf.summary.merge_all()
    self.log_op = [lrn_rate, loss_fnl] + list(metrics.values())
    self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
    if FLAGS.enbl_multi_gpu:
      self.bcast_op = mgw.broadcast_global_variables(0)
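# NOTE: the `uqtf_*` flags consumed above are defined elsewhere in the repo.
# A hypothetical set of definitions, with default values chosen for
# illustration only (8-bit weights/activations is the common TF fake-quant
# setting; the two delays are measured in training steps):
tf.app.flags.DEFINE_integer('uqtf_weight_bits', 8, '# of bits for weight quantization')
tf.app.flags.DEFINE_integer('uqtf_activation_bits', 8, '# of bits for activation quantization')
tf.app.flags.DEFINE_integer('uqtf_quant_delay', 0, '# of steps before fake quantization is enabled')
tf.app.flags.DEFINE_integer('uqtf_freeze_bn_delay', None, '# of steps before batch-norm stats are frozen')
tf.app.flags.DEFINE_float('uqtf_lrn_rate_dcy', 1e-2, 'constant decay factor applied to the learning rate')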
def __build_pruned_train_model(self, path=None, finetune=False):  # pylint: disable=too-many-locals
  """Build a training model from the pruned model."""

  if path is None:
    path = FLAGS.save_path

  with tf.Graph().as_default():
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
    self.sess_train = tf.Session(config=config)
    self.saver_train = tf.train.import_meta_graph(path + '.meta')
    self.saver_train.restore(self.sess_train, path)
    logits = tf.get_collection('logits')[0]
    train_images = tf.get_collection('train_images')[0]
    train_labels = tf.get_collection('train_labels')[0]
    mem_images = tf.get_collection('mem_images')[0]
    mem_labels = tf.get_collection('mem_labels')[0]

    self.sess_train.close()

    # reroute the graph's input tensors to the in-memory placeholders
    graph_editor.reroute_ts(train_images, mem_images)
    graph_editor.reroute_ts(train_labels, mem_labels)

    self.sess_train = tf.Session(config=config)
    self.saver_train.restore(self.sess_train, path)

    trainable_vars = self.trainable_vars
    loss, accuracy = self.calc_loss(train_labels, logits, trainable_vars)
    self.accuracy_keys = list(accuracy.keys())

    if FLAGS.enbl_dst:
      logits_dst = self.learner_dst.calc_logits(self.sess_train, train_images)
      loss += self.learner_dst.calc_loss(logits, logits_dst)

    tf.summary.scalar('loss', loss)
    for key in accuracy.keys():
      tf.summary.scalar(key, accuracy[key])
    self.summary_op = tf.summary.merge_all()

    global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, trainable=False)
    self.global_step = global_step
    lrn_rate, self.nb_iters_train = setup_lrn_rate(
      self.global_step, self.model_name, self.dataset_name)

    if finetune and not FLAGS.cp_retrain:
      mom_optimizer = tf.train.AdamOptimizer(FLAGS.cp_lrn_rate_ft)
      self.log_op = [tf.constant(FLAGS.cp_lrn_rate_ft), loss, list(accuracy.values())]
    else:
      mom_optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
      self.log_op = [lrn_rate, loss, list(accuracy.values())]

    if FLAGS.enbl_multi_gpu:
      optimizer = mgw.DistributedOptimizer(mom_optimizer)
    else:
      optimizer = mom_optimizer
    grads_origin = optimizer.compute_gradients(loss, trainable_vars)
    grads_pruned, masks = self.__calc_grads_pruned(grads_origin)

    with tf.control_dependencies(self.update_ops):
      self.train_op = optimizer.apply_gradients(grads_pruned, global_step=global_step)

    self.sm_writer.add_graph(tf.get_default_graph())
    # tf.variables_initializer replaces the deprecated tf.initialize_variables
    self.train_init_op = tf.variables_initializer(
      mom_optimizer.variables() + [global_step] + masks)

    if FLAGS.enbl_multi_gpu:
      self.bcast_op = mgw.broadcast_global_variables(0)
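# NOTE: in this learner, `__calc_grads_pruned` returns `(grads_pruned, masks)`,
# unlike the weight-sparsification variant. A minimal sketch, assuming a
# channel-wise binary mask is derived from each convolution kernel's restored
# values (input channels that were pruned to zero stay at zero) and applied to
# that kernel's gradient; the kernel-detection rule and mask naming here are
# assumptions for illustration only.
def __calc_grads_pruned(self, grads_origin):
  """Mask gradients of pruned channels and return the created masks (sketch)."""
  grads_pruned = []
  masks = []
  for grad, var in grads_origin:
    if 'kernel' not in var.name:
      grads_pruned += [(grad, var)]
    else:
      # channels whose [h, w, :, out] kernel slice is all-zero count as pruned
      nonzero = tf.cast(
        tf.reduce_sum(tf.square(var), axis=[0, 1, 3], keepdims=True) > 0.0, tf.float32)
      mask = tf.get_variable(
        var.name.split(':')[0] + '_mask', initializer=nonzero, trainable=False)
      masks += [mask]
      grads_pruned += [(grad * mask, var)]  # mask broadcasts over h, w, out
  return grads_pruned, masks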
def __build(self, is_train):  # pylint: disable=too-many-locals
  """Build the training / evaluation graph.

  Args:
  * is_train: whether to create the training graph
  """

  with tf.Graph().as_default():
    # TensorFlow session
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train() if is_train else self.build_dataset_eval()
      images, labels = iterator.get_next()
      tf.add_to_collection('images_final', images)

    # model definition - distilled model
    if self.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(sess, images)

    # model definition - primary model
    with tf.variable_scope(self.model_scope):
      # forward pass
      logits = self.forward_train(images) if is_train else self.forward_eval(images)
      # out = self.forward_eval(images)
      tf.add_to_collection('logits_final', logits)

      # loss & extra evaluation metrics
      loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
      if self.enbl_dst:
        loss += self.helper_dst.calc_loss(logits, logits_dst)
      tf.summary.scalar('loss', loss)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # optimizer & gradients
      if is_train:
        self.global_step = tf.train.get_or_create_global_step()
        lrn_rate, self.nb_iters_train = setup_lrn_rate(
          self.global_step, self.model_name, self.dataset_name)
        optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        # optimizer = tf.train.AdamOptimizer(lrn_rate)
        if FLAGS.enbl_multi_gpu:
          optimizer = mgw.DistributedOptimizer(optimizer)
        grads = optimizer.compute_gradients(loss, self.trainable_vars)

    # TF operations & model saver
    if is_train:
      self.sess_train = sess
      with tf.control_dependencies(self.update_ops):
        self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)
      self.summary_op = tf.summary.merge_all()
      self.log_op = [lrn_rate, loss] + list(metrics.values())
      self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
      self.init_op = tf.variables_initializer(self.vars)
      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
      self.saver_train = tf.train.Saver(self.vars)
    else:
      self.sess_eval = sess
      self.factory_op = [tf.cast(logits, tf.uint8)]
      self.time_op = [logits]
      self.out_op = [
        tf.cast(images, tf.uint8),
        tf.cast(logits, tf.uint8),
        tf.cast(labels, tf.uint8),
      ]
      self.eval_op = [loss] + list(metrics.values())
      self.eval_op_names = ['loss'] + list(metrics.keys())
      self.saver_eval = tf.train.Saver(self.vars)
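# NOTE: a hypothetical sketch (not from the repository) of how the evaluation
# ops built above might be consumed once `__build(is_train=False)` has run; the
# method name `evaluate`, the iteration count, and the numpy-based averaging
# are assumptions for illustration.
import numpy as np  # assumed available alongside tensorflow

def evaluate(self, nb_iters=100):
  """Run the evaluation ops and log averaged metrics (sketch)."""
  # each sess.run returns one value per op in self.eval_op ([loss] + metrics)
  results = [self.sess_eval.run(self.eval_op) for __ in range(nb_iters)]
  for name, value in zip(self.eval_op_names, np.mean(np.array(results), axis=0)):
    tf.logging.info('%s = %.4e' % (name, value))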
def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
  """Build the training graph."""

  with tf.Graph().as_default():
    # create a TF session for the current graph
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=no-member
    config.gpu_options.visible_device_list = \
      str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
    sess = tf.Session(config=config)

    # data input pipeline
    with tf.variable_scope(self.data_scope):
      iterator = self.build_dataset_train()
      images, labels = iterator.get_next()

    # model definition - distilled model
    if FLAGS.enbl_dst:
      logits_dst = self.helper_dst.calc_logits(sess, images)

    # model definition - channel-pruned model
    with tf.variable_scope(self.model_scope_prnd):
      logits_prnd = self.forward_train(images)
      self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
      self.global_step = tf.train.get_or_create_global_step()
      self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'] + [self.global_step])

      # loss & extra evaluation metrics
      loss, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
      if FLAGS.enbl_dst:
        loss += self.helper_dst.calc_loss(logits_prnd, logits_dst)
      tf.summary.scalar('loss', loss)
      for key, value in metrics.items():
        tf.summary.scalar(key, value)

      # learning rate schedule
      # lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)
      lrn_rate, self.nb_iters_train = setup_lrn_rate(
        self.global_step, self.model_name, self.dataset_name)

      # calculate pruning ratios
      pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
      pr_conv_krnl = calc_prune_ratio(self.vars_prnd['conv_krnl'])
      tf.summary.scalar('pr_trainable', pr_trainable)
      tf.summary.scalar('pr_conv_krnl', pr_conv_krnl)

      # create masks and corresponding operations for channel pruning
      self.masks = []
      for var in self.vars_prnd['conv_krnl']:
        tf.logging.info('creating a pruning mask for {} of size {}'.format(var.name, var.shape))
        mask_name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
        var_norm = tf.reduce_sum(tf.square(var), axis=[0, 1, 3], keepdims=True)
        mask_init = tf.cast(var_norm > 0.0, tf.float32)
        mask = tf.get_variable(mask_name, initializer=mask_init, trainable=False)
        self.masks += [mask]

      # optimizer & gradients
      optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
      if not FLAGS.enbl_multi_gpu:
        optimizer = optimizer_base
      else:
        optimizer = mgw.DistributedOptimizer(optimizer_base)
      grads_origin = optimizer.compute_gradients(loss, self.vars_prnd['trainable'])
      grads_pruned = self.__calc_grads_pruned(grads_origin)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)
      with tf.control_dependencies(update_ops):
        self.train_op = optimizer.apply_gradients(grads_pruned, global_step=self.global_step)

    # TF operations for logging & summarizing
    self.sess_train = sess
    self.summary_op = tf.summary.merge_all()
    self.init_op = tf.group(
      tf.variables_initializer([self.global_step] + self.masks + optimizer_base.variables()))
    self.log_op = [lrn_rate, loss, pr_trainable, pr_conv_krnl] + list(metrics.values())
    self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_krn'] + list(metrics.keys())
    if FLAGS.enbl_multi_gpu:
      self.bcast_op = mgw.broadcast_global_variables(0)
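# NOTE: a minimal sketch of the `__calc_grads_pruned` helper used above,
# assuming `self.masks[idx]` corresponds to `self.vars_prnd['conv_krnl'][idx]`
# (the order in which the masks were created). Each convolution kernel's
# gradient is multiplied by its channel mask of shape [1, 1, n_in, 1], which
# broadcasts over the spatial and output-channel axes, so pruned input
# channels stay at zero throughout fine-tuning.
def __calc_grads_pruned(self, grads_origin):
  """Apply channel masks to convolution-kernel gradients (sketch)."""
  conv_krnl_names = [var.name for var in self.vars_prnd['conv_krnl']]
  grads_pruned = []
  for grad, var in grads_origin:
    if var.name not in conv_krnl_names:
      grads_pruned += [(grad, var)]
    else:
      idx_mask = conv_krnl_names.index(var.name)
      grads_pruned += [(grad * self.masks[idx_mask], var)]
  return grads_pruned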