def restore_from_classification_checkpoint_fn(
        self,
        first_stage_feature_extractor_scope,
        second_stage_feature_extractor_scope):
    """Builds the checkpoint-name -> graph-variable map for a foreign checkpoint.

    This overrides the default implementation in
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor, which does not work for
    PNASNet checkpoints.

    Every global variable whose op name starts with one of the two feature
    extractor scopes is registered under a checkpoint key formed by removing
    `<scope>/` from the op name and appending `/ExponentialMovingAverage`.
    (Presumably the classification checkpoint stores the moving-average
    copies of the weights under that suffix -- confirm against the
    checkpoint being loaded.)

    Args:
      first_stage_feature_extractor_scope: A scope name for the first stage
        feature extractor.
      second_stage_feature_extractor_scope: A scope name for the second stage
        feature extractor.

    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables
      in the model graph.
    """
    checkpoint_to_graph = {}
    for graph_var in variables_helper.get_global_variables_safely():
        op_name = graph_var.op.name
        # The first stage is handled before the second, so if a variable
        # matches both scopes the second-stage entry is written last.
        for scope in (first_stage_feature_extractor_scope,
                      second_stage_feature_extractor_scope):
            if not op_name.startswith(scope):
                continue
            checkpoint_key = op_name.replace(scope + '/', '')
            checkpoint_key += '/ExponentialMovingAverage'
            checkpoint_to_graph[checkpoint_key] = graph_var
    return checkpoint_to_graph
def restore_from_classification_checkpoint_fn(
        self,
        first_stage_feature_extractor_scope,
        second_stage_feature_extractor_scope):
    """Builds the checkpoint-name -> graph-variable map for a foreign checkpoint.

    This overrides the default implementation in
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor, which does not work for
    InceptionResnetV2 checkpoints.

    First-stage variables are keyed by their op name with
    `<first_scope>/` removed. Second-stage variables additionally have
    `<second_scope>/InceptionResnetV2/Repeat` rewritten to
    `InceptionResnetV2/Repeat_2` before `<second_scope>/` is removed.

    TODO(jonathanhuang,rathodv): revisit whether it's possible to force the
    `Repeat` namescope as created in `_extract_box_classifier_features` to
    start counting at 2 (e.g. `Repeat_2`) so that the default restore_fn can
    be used.

    Args:
      first_stage_feature_extractor_scope: A scope name for the first stage
        feature extractor.
      second_stage_feature_extractor_scope: A scope name for the second stage
        feature extractor.

    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables
      in the model graph.
    """
    checkpoint_to_graph = {}
    for graph_var in variables_helper.get_global_variables_safely():
        op_name = graph_var.op.name

        if op_name.startswith(first_stage_feature_extractor_scope):
            first_key = op_name.replace(
                first_stage_feature_extractor_scope + '/', '')
            checkpoint_to_graph[first_key] = graph_var

        if op_name.startswith(second_stage_feature_extractor_scope):
            # The `Repeat` rename must run before the generic scope strip:
            # it consumes the scope prefix itself, so the second replace then
            # only affects names that did not contain the `Repeat` pattern.
            renamed = op_name.replace(
                second_stage_feature_extractor_scope
                + '/InceptionResnetV2/Repeat',
                'InceptionResnetV2/Repeat_2')
            second_key = renamed.replace(
                second_stage_feature_extractor_scope + '/', '')
            checkpoint_to_graph[second_key] = graph_var

    return checkpoint_to_graph