Example #1
 def var_avg(self, global_step=None):
   ''' average model variables; add the average op to UPDATE_OPS '''
   model_avg_conf = self.config['solver']['model_average']
   var_avg_model = model_avg_conf['enable']
   if var_avg_model:
     var_avg_decay = model_avg_conf['var_avg_decay']
     variable_averages = self.get_var_avg_ema(var_avg_decay, global_step)
     apply_op = variable_averages.apply(tf.trainable_variables())
     tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, apply_op)
     utils.log_vars('Avg Trainable Vars', tf.trainable_variables())
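The apply_op registered here only has an effect if something actually runs the UPDATE_OPS collection. Below is a minimal TF1-style sketch of how that collection is usually tied into the training step; the toy loss, the optimizer, and the assumption that get_var_avg_ema wraps tf.train.ExponentialMovingAverage are illustrative, not part of the snippet:

import tensorflow as tf

# Toy variable and loss, stand-ins for the real model.
x = tf.get_variable('x', initializer=1.0)
loss = tf.square(x)
global_step = tf.train.get_or_create_global_step()

# Presumably what get_var_avg_ema builds: an EMA over the trainable variables.
ema = tf.train.ExponentialMovingAverage(0.999, num_updates=global_step)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema.apply(tf.trainable_variables()))

# Run everything in UPDATE_OPS together with each optimizer step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=global_step)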
Example #2
 def l2_loss(self, tvars=None):
     _l2_loss = 0.0
     weight_decay = self.config['solver']['optimizer'].get(
         'weight_decay', None)
     if weight_decay:
         logging.info(f"add L2 Loss with decay: {weight_decay}")
         with tf.name_scope('l2_loss'):
             tvars = tvars if tvars else tf.trainable_variables()
             tvars = [v for v in tvars if 'bias' not in v.name]
             _l2_loss = weight_decay * tf.add_n(
                 [tf.nn.l2_loss(v) for v in tvars])
             summary_lib.scalar('l2_loss', _l2_loss)
     return _l2_loss
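As a point of reference, a small standalone sketch of the same weight-decay pattern outside the solver class; the variable names and the decay value are illustrative:

import tensorflow as tf

weight_decay = 1e-4  # illustrative value for config['solver']['optimizer']['weight_decay']
w = tf.get_variable('dense/kernel', shape=[128, 10])
b = tf.get_variable('dense/bias', shape=[10])

# Same filtering as above: skip bias variables, penalize the rest.
tvars = [v for v in tf.trainable_variables() if 'bias' not in v.name]
l2 = weight_decay * tf.add_n([tf.nn.l2_loss(v) for v in tvars])
# l2 would then be added to the task loss before optimization.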
Example #3
  def make_restore_average_vars_dict(self, global_step=None):
    ''' use variable averages to restore vars '''
    model_avg_conf = self.config['solver']['model_average']
    var_avg_decay = model_avg_conf['var_avg_decay']

    var_restore_dict = {}
    variable_averages = self.get_var_avg_ema(var_avg_decay, global_step)
    for var in tf.global_variables():
      if var in tf.trainable_variables():
        name = variable_averages.average_name(var)
      else:
        name = var.op.name
      var_restore_dict[name] = var
    return var_restore_dict
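A hedged sketch of how the returned dict is typically consumed: hand it to a tf.train.Saver so that, at evaluation time, the shadow (averaged) values stored in the checkpoint are loaded into the live variables. The single-variable setup below is a standalone illustration, not code from the project:

import tensorflow as tf

v = tf.get_variable('w', initializer=0.0)
ema = tf.train.ExponentialMovingAverage(0.999)
ema.apply([v])

# Map the averaged (shadow) name in the checkpoint back onto the live variable,
# which is what make_restore_average_vars_dict does for every trainable variable.
restore_dict = {ema.average_name(v): v}
saver = tf.train.Saver(restore_dict)
# saver.restore(sess, ckpt_path) would then load 'w/ExponentialMovingAverage' into w.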
Example #4
    def init_from_checkpoint(self):
        ''' do transfer learning by initializing a subset of variables from another checkpoint. '''
        if 'transfer' not in self.config['solver']:
            return
        transfer_cfg = self.config['solver']['transfer']
        enable = transfer_cfg['enable']
        if not enable:
            return
        init_checkpoint = transfer_cfg['ckpt_path']
        exclude = transfer_cfg['exclude_reg']
        include = transfer_cfg['include_reg']
        logging.info(f"Transfer from checkpoint: {init_checkpoint}")
        logging.info(f"Transfer exclude: {exclude}")
        logging.info(f"Transfer include: {include}")

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        if init_checkpoint:

            def _filter_by_reg(tvars, include, exclude):
                include = include if include else []
                exclude = exclude if exclude else []
                outs = []
                for var in tvars:
                    name = var.name
                    for reg_str in include:
                        logging.debug(f"var:{name}, reg: {reg_str}")
                        m = re.match(reg_str, name)
                        if m is not None:
                            outs.append(var)
                    for reg_str in exclude:
                        logging.debug(f"var:{name}, reg: {reg_str}")
                        m = re.match(reg_str, name)
                        if m is None:
                            outs.append(var)
                return outs

            tvars = _filter_by_reg(tvars, include, exclude)
            assignment_map, initialized_variable_names = \
              self.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                         init_string)
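For reference, an illustrative shape of the solver.transfer config block this method reads; the key names come from the code above, the values are placeholders:

transfer_cfg = {
    'enable': True,
    'ckpt_path': '/path/to/pretrained/model.ckpt',  # placeholder path
    'include_reg': [r'encoder/.*'],   # regexes matched against variable names via re.match
    'exclude_reg': [r'.*logits.*'],
}
config = {'solver': {'transfer': transfer_cfg}}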