Example #1
0
    def restore_from_classification_checkpoint_fn(
            self, first_stage_feature_extractor_scope,
            second_stage_feature_extractor_scope):
        """Returns a map of variables to load from a foreign checkpoint.

        Note that this overrides the default implementation in
        faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
        PNASNet checkpoints.

        Every global variable whose name lies under
        ``<first_stage_feature_extractor_scope>/`` or
        ``<second_stage_feature_extractor_scope>/`` is keyed by its name with
        that leading scope removed and ``/ExponentialMovingAverage`` appended.
        Each graph variable is therefore loaded from the checkpoint entry of
        that suffixed name.

        Args:
          first_stage_feature_extractor_scope: A scope name for the first stage
            feature extractor (without a trailing '/').
          second_stage_feature_extractor_scope: A scope name for the second
            stage feature extractor (without a trailing '/').

        Returns:
          A dict mapping variable names (to load from a checkpoint) to variables
          in the model graph.
        """
        # Match on the full "<scope>/" prefix so sibling scopes such as
        # "<scope>_1/..." are not picked up. Strip only that leading prefix;
        # str.replace would also remove any later occurrence inside the name.
        # Both scopes are checked independently, as before, so a variable under
        # both (only possible if one scope nests the other) yields two entries.
        scope_prefixes = (
            first_stage_feature_extractor_scope + '/',
            second_stage_feature_extractor_scope + '/',
        )
        variables_to_restore = {}
        for variable in variables_helper.get_global_variables_safely():
            full_name = variable.op.name
            for prefix in scope_prefixes:
                if full_name.startswith(prefix):
                    var_name = full_name[len(prefix):]
                    var_name += '/ExponentialMovingAverage'
                    variables_to_restore[var_name] = variable
        return variables_to_restore
Example #2
0
  def restore_from_classification_checkpoint_fn(
      self,
      first_stage_feature_extractor_scope,
      second_stage_feature_extractor_scope):
    """Returns a map of variables to load from a foreign checkpoint.

    Note that this overrides the default implementation in
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
    InceptionResnetV2 checkpoints.

    Every global variable whose name lies under
    ``<first_stage_feature_extractor_scope>/`` or
    ``<second_stage_feature_extractor_scope>/`` is keyed by its name with that
    leading scope removed. For second-stage variables, a leading
    ``InceptionResnetV2/Repeat`` in the stripped name is further rewritten to
    ``InceptionResnetV2/Repeat_2``. The second stage's `Repeat` name scope
    corresponds to `Repeat_2` in the classification checkpoint (see TODO
    below).

    TODO(jonathanhuang,rathodv): revisit whether it's possible to force the
    `Repeat` namescope as created in `_extract_box_classifier_features` to
    start counting at 2 (e.g. `Repeat_2`) so that the default restore_fn can
    be used.

    Args:
      first_stage_feature_extractor_scope: A scope name for the first stage
        feature extractor (without a trailing '/').
      second_stage_feature_extractor_scope: A scope name for the second stage
        feature extractor (without a trailing '/').

    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables in
      the model graph.
    """
    # Match on the full "<scope>/" prefix so sibling scopes such as
    # "<scope>_1/..." are not picked up. Strip only that leading prefix;
    # str.replace would also remove any later occurrence inside the name.
    first_prefix = first_stage_feature_extractor_scope + '/'
    second_prefix = second_stage_feature_extractor_scope + '/'
    repeat_scope = 'InceptionResnetV2/Repeat'
    checkpoint_repeat_scope = 'InceptionResnetV2/Repeat_2'

    variables_to_restore = {}
    for variable in variables_helper.get_global_variables_safely():
      full_name = variable.op.name
      if full_name.startswith(first_prefix):
        variables_to_restore[full_name[len(first_prefix):]] = variable
      if full_name.startswith(second_prefix):
        var_name = full_name[len(second_prefix):]
        # Rename the second stage's `Repeat` scope to the checkpoint's
        # `Repeat_2`; the remainder of the name is kept unchanged.
        if var_name.startswith(repeat_scope):
          var_name = checkpoint_repeat_scope + var_name[len(repeat_scope):]
        variables_to_restore[var_name] = variable
    return variables_to_restore