def replace_variable_values_with_moving_averages(graph, current_checkpoint_file, new_checkpoint_file, no_ema_collection=None):
  """Writes a checkpoint in which each variable holds its EMA value.

  When the input checkpoint contains shadow variables (moving averages) for
  the variables in `graph`, this produces a new checkpoint whose variables
  are loaded from those shadow values instead of the raw variable values.

  Args:
    graph: a tf.Graph object.
    current_checkpoint_file: a checkpoint containing both original variables
      and their moving averages.
    new_checkpoint_file: file path to write a new checkpoint.
    no_ema_collection: A list of namescope substrings to match the variables
      to eliminate EMA.
  """
  with graph.as_default():
    # decay=0.0 is irrelevant here; the averager is only used to build the
    # variable -> shadow-variable name mapping for restoring.
    averager = tf.train.ExponentialMovingAverage(0.0)
    restore_map = averager.variables_to_restore()
    # Strip EMA mapping for any variable whose name matches a substring in
    # no_ema_collection (those restore from their plain values instead).
    restore_map = config_util.remove_unecessary_ema(restore_map, no_ema_collection)
    with tf.Session() as sess:
      # Load shadow (EMA) values into the graph's variables, then save the
      # result as an ordinary checkpoint.
      tf.train.Saver(restore_map).restore(sess, current_checkpoint_file)
      tf.train.Saver().save(sess, new_checkpoint_file)
def testRemoveUnecessaryEma(self):
  """Checks that EMA suffixes are stripped only for matching variables."""
  # Mix of plain variables and ExponentialMovingAverage shadow variables.
  ema_map = {
      "expanded_conv_10/project/act_quant/min": 1,
      "FeatureExtractor/MobilenetV2_2/expanded_conv_5/expand/act_quant/min": 2,
      "expanded_conv_10/expand/BatchNorm/gamma/min/ExponentialMovingAverage": 3,
      "expanded_conv_3/depthwise/BatchNorm/beta/max/ExponentialMovingAverage": 4,
      "BoxPredictor_1/ClassPredictor_depthwise/act_quant": 5
  }
  # Only variables whose names contain one of these substrings should have
  # their EMA suffix removed.
  filters = ["/min", "/max"]
  # Matching entries lose the "/ExponentialMovingAverage" suffix; all
  # values and non-matching keys are preserved unchanged.
  expected = {
      "expanded_conv_10/project/act_quant/min": 1,
      "FeatureExtractor/MobilenetV2_2/expanded_conv_5/expand/act_quant/min": 2,
      "expanded_conv_10/expand/BatchNorm/gamma/min": 3,
      "expanded_conv_3/depthwise/BatchNorm/beta/max": 4,
      "BoxPredictor_1/ClassPredictor_depthwise/act_quant": 5
  }
  result = config_util.remove_unecessary_ema(ema_map, filters)
  self.assertEqual(expected, result)