Beispiel #1
0
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file,
                                                 no_ema_collection=None):
    """Write a checkpoint whose variables hold their moving-average values.

    The variables of ``graph`` are restored from the exponential-moving-average
    shadow entries found in ``current_checkpoint_file``, and the resulting
    session state is written out as ``new_checkpoint_file``. Variables selected
    by ``no_ema_collection`` are restored from their plain (non-EMA) entries.

    Args:
      graph: a tf.Graph object whose variables are restored and then saved.
      current_checkpoint_file: path of a checkpoint containing both the original
        variables and their moving averages.
      new_checkpoint_file: file path to write the new checkpoint to.
      no_ema_collection: optional list of name-scope substrings. It is forwarded
        to ``config_util.remove_unecessary_ema``. Restore entries whose names
        match one of these substrings have their "/ExponentialMovingAverage"
        suffix dropped, so those variables keep their raw checkpoint values.
    """
    with graph.as_default():
        # Only variables_to_restore() is used on this object, so the decay
        # argument presumably has no effect here.
        moving_average = tf.train.ExponentialMovingAverage(0.0)
        restore_map = config_util.remove_unecessary_ema(
            moving_average.variables_to_restore(), no_ema_collection)
        with tf.Session() as session:
            # Load EMA shadow values into the graph's variables...
            tf.train.Saver(restore_map).restore(session,
                                                current_checkpoint_file)
            # ...then persist every variable under its ordinary name.
            tf.train.Saver().save(session, new_checkpoint_file)
    def testRemoveUnecessaryEma(self):
        """remove_unecessary_ema un-EMAs only names matching the collection.

        A key that contains "/ExponentialMovingAverage" together with one of
        the given substrings is expected to lose that suffix. All other keys,
        including ones that match a substring but carry no EMA suffix, and
        all values are expected to pass through unchanged.
        """
        ema_exempt_substrings = ["/min", "/max"]

        restore_names = {
            "expanded_conv_10/project/act_quant/min": 1,
            "FeatureExtractor/MobilenetV2_2/expanded_conv_5/expand/act_quant/min": 2,
            "expanded_conv_10/expand/BatchNorm/gamma/min/ExponentialMovingAverage": 3,
            "expanded_conv_3/depthwise/BatchNorm/beta/max/ExponentialMovingAverage": 4,
            "BoxPredictor_1/ClassPredictor_depthwise/act_quant": 5,
        }

        expected_names = {
            "expanded_conv_10/project/act_quant/min": 1,
            "FeatureExtractor/MobilenetV2_2/expanded_conv_5/expand/act_quant/min": 2,
            # The two EMA entries below have their suffix stripped.
            "expanded_conv_10/expand/BatchNorm/gamma/min": 3,
            "expanded_conv_3/depthwise/BatchNorm/beta/max": 4,
            "BoxPredictor_1/ClassPredictor_depthwise/act_quant": 5,
        }

        filtered_names = config_util.remove_unecessary_ema(
            restore_names, ema_exempt_substrings)
        self.assertEqual(expected_names, filtered_names)