Ejemplo n.º 1
0
    def create_init_fn_to_restore(self,
                                  master_checkpoint,
                                  inception_checkpoint=None):
        """Builds a callable that restores weights from the given checkpoints.

    Args:
      master_checkpoint: path to a checkpoint which contains all weights for the
        whole model.
      inception_checkpoint: path to a checkpoint which contains weights for the
        inception part only.

    Returns:
      a function to run initialization ops.
    """
        restore_ops = []
        restore_feeds = {}

        def queue_restore(variables, checkpoint):
            # Record the assign ops and placeholder feeds needed to load
            # `variables` from `checkpoint`; they are all executed together
            # in a single sess.run inside the returned init function.
            logging.info('Request to re-store %d weights from %s',
                         len(variables), checkpoint)
            if not variables:
                logging.error('Can\'t find any variables to restore.')
                sys.exit(1)
            op, feeds = slim.assign_from_checkpoint(checkpoint, variables)
            restore_ops.append(op)
            restore_feeds.update(feeds)

        # For the reduced charset restore all variables except
        # AttentionOcr_v1/sequence_logit_fn/SQLR/softmax_w.
        logging.info('\n variables_to_restore:\n%s',
                     utils.variables_to_restore().keys())
        logging.info(
            'moving_average_variables:\n%s',
            [v.op.name for v in tf.compat.v1.moving_average_variables()])
        logging.info('trainable_variables:\n%s',
                     [v.op.name for v in tf.compat.v1.trainable_variables()])

        if master_checkpoint:
            queue_restore(utils.variables_to_restore(), master_checkpoint)

        if inception_checkpoint:
            inception_vars = utils.variables_to_restore(
                'AttentionOcr_v1/conv_tower_fn/INCE', strip_scope=True)
            queue_restore(inception_vars, inception_checkpoint)

        def init_assign_fn(sess):
            logging.info('Restoring checkpoint(s)')
            sess.run(restore_ops, restore_feeds)

        return init_assign_fn
Ejemplo n.º 2
0
  def create_init_fn_to_restore(self, master_checkpoint,
                                caps_checkpoint=None,
                                trainable_base=True):
    """Creates an init operations to restore weights from various checkpoints.

    Args:
      master_checkpoint: path to a checkpoint which contains all weights for
        the whole model.
      caps_checkpoint: path to a checkpoint which contains weights for the
        capsule part only.
      trainable_base: if False, the restored capsule variables are passed to
        utils.variables_to_freeze before their restore ops are created.

    Returns:
      a function to run initialization ops.
    """
    all_assign_ops = []
    all_feed_dict = {}

    def assign_from_checkpoint(variables, checkpoint):
      # Accumulate the assign op and placeholder feeds needed to restore
      # `variables`; everything is run in one sess.run by init_assign_fn.
      logging.info('Request to re-store %d weights from %s',
                   len(variables), checkpoint)
      if not variables:
        logging.error('Can\'t find any variables to restore.')
        sys.exit(1)
      assign_op, feed_dict = slim.assign_from_checkpoint(checkpoint, variables)
      all_assign_ops.append(assign_op)
      all_feed_dict.update(feed_dict)

    # Lazy %-style logging args: the (potentially large) variable listings
    # are only formatted when INFO logging is actually enabled.
    logging.info('variables_to_restore:\n%s',
                 utils.variables_to_restore().keys())
    logging.info('moving_average_variables:\n%s',
                 [v.op.name for v in tf.moving_average_variables()])
    logging.info('trainable_variables:\n%s',
                 [v.op.name for v in tf.trainable_variables()])
    if master_checkpoint:
      assign_from_checkpoint(utils.variables_to_restore(), master_checkpoint)

    if caps_checkpoint:
      variables = utils.variables_to_restore(
          'AttentionOcr_v1/caps_fn/CAPS', strip_scope=True)
      if not trainable_base:
        utils.variables_to_freeze(variables)
      assign_from_checkpoint(variables, caps_checkpoint)

    def init_assign_fn(sess):
      logging.info('Restoring checkpoint(s)')
      sess.run(all_assign_ops, all_feed_dict)

    return init_assign_fn
Ejemplo n.º 3
0
  def create_init_fn_to_restore(self, master_checkpoint,
                                inception_checkpoint=None):
    """Builds a session callback that restores model weights.

    Args:
      master_checkpoint: path to a checkpoint which contains all weights for
        the whole model.
      inception_checkpoint: path to a checkpoint which contains weights for the
        inception part only.

    Returns:
      a function to run initialization ops.
    """
    pending_ops = []
    pending_feeds = {}

    def schedule_restore(variables, checkpoint):
      # Collect the assign op and placeholder feeds; everything is run in a
      # single sess.run inside the returned init function.
      logging.info('Request to re-store %d weights from %s',
                   len(variables), checkpoint)
      if not variables:
        logging.error('Can\'t find any variables to restore.')
        sys.exit(1)
      op, feeds = slim.assign_from_checkpoint(checkpoint, variables)
      pending_ops.append(op)
      pending_feeds.update(feeds)

    logging.info('variables_to_restore:\n%s' %
                 utils.variables_to_restore().keys())
    logging.info('moving_average_variables:\n%s' %
                 [v.op.name for v in tf.moving_average_variables()])
    logging.info('trainable_variables:\n%s' %
                 [v.op.name for v in tf.trainable_variables()])

    if master_checkpoint:
      schedule_restore(utils.variables_to_restore(), master_checkpoint)

    if inception_checkpoint:
      inception_vars = utils.variables_to_restore(
          'AttentionOcr_v1/conv_tower_fn/INCE', strip_scope=True)
      schedule_restore(inception_vars, inception_checkpoint)

    def init_assign_fn(sess):
      logging.info('Restoring checkpoint(s)')
      sess.run(pending_ops, pending_feeds)

    return init_assign_fn
Ejemplo n.º 4
0
    def create_init_fn_to_restore(self,
                                  master_checkpoint,
                                  inception_checkpoint=None):
        """Returns a function that restores model weights inside a session.

    Args:
      master_checkpoint: path to a checkpoint which contains all weights for
        the whole model.
      inception_checkpoint: path to a checkpoint which contains weights for the
        inception part only.

    Returns:
      a function to run initialization ops.
    """
        ops_to_run = []
        feeds_to_run = {}

        def stage(variables, checkpoint):
            # Prepare (but do not yet run) the assign ops that restore
            # `variables` from `checkpoint`.
            logging.info('Request to re-store %d weights from %s',
                         len(variables), checkpoint)
            if not variables:
                logging.error('Can\'t find any variables to restore.')
                sys.exit(1)
            op, feeds = slim.assign_from_checkpoint(checkpoint, variables)
            ops_to_run.append(op)
            feeds_to_run.update(feeds)

        if master_checkpoint:
            stage(utils.variables_to_restore(), master_checkpoint)

        if inception_checkpoint:
            stage(
                utils.variables_to_restore(
                    'AttentionOcr_v1/conv_tower_fn/INCE', strip_scope=True),
                inception_checkpoint)

        def init_assign_fn(sess):
            logging.info('Restoring checkpoint(s)')
            sess.run(ops_to_run, feeds_to_run)

        return init_assign_fn