def var_avg(self, global_step=None):
  ''' Average model variables; add the average op to GraphKeys.UPDATE_OPS. '''
  model_avg_conf = self.config['solver']['model_average']
  var_avg_model = model_avg_conf['enable']
  if var_avg_model:
    var_avg_decay = model_avg_conf['var_avg_decay']
    # Maintain an exponential moving average (EMA) of all trainable variables.
    variable_averages = self.get_var_avg_ema(var_avg_decay, global_step)
    apply_op = variable_averages.apply(tf.trainable_variables())
    # Registering in UPDATE_OPS lets the training loop run the EMA update
    # together with the train op.
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, apply_op)
    utils.log_vars('Avg Trainable Vars', tf.trainable_variables())
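# A minimal usage sketch (names like `optimizer` and `loss` are hypothetical;
# this assumes `var_avg` was called while building the graph):
#
#     self.var_avg(global_step)
#     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#     with tf.control_dependencies(update_ops):  # run the EMA apply op each step
#       train_op = optimizer.minimize(loss, global_step=global_step)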
def l2_loss(self, tvars=None):
  ''' L2 regularization loss over trainable variables, excluding biases. '''
  _l2_loss = 0.0
  weight_decay = self.config['solver']['optimizer'].get('weight_decay', None)
  if weight_decay:
    logging.info(f"add L2 Loss with decay: {weight_decay}")
    with tf.name_scope('l2_loss'):
      tvars = tvars if tvars else tf.trainable_variables()
      # Biases are conventionally not regularized.
      tvars = [v for v in tvars if 'bias' not in v.name]
      _l2_loss = weight_decay * tf.add_n([tf.nn.l2_loss(v) for v in tvars])
      summary_lib.scalar('l2_loss', _l2_loss)
  return _l2_loss
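# Sketch of how the regularizer is typically folded into the training loss
# (`self.model_loss()` and the optimizer plumbing are hypothetical, not part
# of this module):
#
#     total_loss = self.model_loss() + self.l2_loss()
#     train_op = optimizer.minimize(total_loss, global_step=global_step)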
def make_restore_average_vars_dict(self, global_step=None):
  ''' Map averaged (EMA) variable names to variables, to restore vars from their averages. '''
  model_avg_conf = self.config['solver']['model_average']
  var_avg_decay = model_avg_conf['var_avg_decay']
  var_restore_dict = {}
  variable_averages = self.get_var_avg_ema(var_avg_decay, global_step)
  for var in tf.global_variables():
    if var in tf.trainable_variables():
      # Trainable variables are restored from their EMA shadow values.
      name = variable_averages.average_name(var)
    else:
      name = var.op.name
    var_restore_dict[name] = var
  return var_restore_dict
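# Sketch: evaluating with averaged weights by restoring EMA shadow values into
# the live variables (`ckpt_path` is assumed):
#
#     saver = tf.train.Saver(self.make_restore_average_vars_dict())
#     with tf.Session() as sess:
#       saver.restore(sess, ckpt_path)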
def init_from_checkpoint(self):
  ''' Do transfer learning by initializing a subset of variables from another checkpoint. '''
  if 'transfer' not in self.config['solver']:
    return
  transfer_cfg = self.config['solver']['transfer']
  enable = transfer_cfg['enable']
  if not enable:
    return
  init_checkpoint = transfer_cfg['ckpt_path']
  exclude = transfer_cfg['exclude_reg']
  include = transfer_cfg['include_reg']
  logging.info(f"Transfer from checkpoint: {init_checkpoint}")
  logging.info(f"Transfer exclude: {exclude}")
  logging.info(f"Transfer include: {include}")

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  if init_checkpoint:

    def _filter_by_reg(tvars, include, exclude):
      include = include if include else []
      exclude = exclude if exclude else []
      outs = []
      for var in tvars:
        name = var.name
        # Keep a variable if it matches some include pattern and no exclude
        # pattern; an empty include list keeps all variables. This also avoids
        # appending the same variable once per matching pattern.
        if include and not any(re.match(reg_str, name) for reg_str in include):
          logging.debug(f"var: {name} not matched by include regs")
          continue
        if any(re.match(reg_str, name) for reg_str in exclude):
          logging.debug(f"var: {name} matched by exclude regs")
          continue
        outs.append(var)
      return outs

    tvars = _filter_by_reg(tvars, include, exclude)
    assignment_map, initialized_variable_names = \
        self.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string)
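# Illustrative `solver.transfer` config this method expects (values are made
# up; regexes are matched against `var.name` with `re.match`):
#
#     transfer:
#       enable: true
#       ckpt_path: /path/to/pretrained/model.ckpt
#       include_reg: ['.*encoder.*']
#       exclude_reg: ['.*logits.*']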