def convert_checkpoint(bert_config, output_path, v1_checkpoint):
    """Converts a V1 checkpoint into an OO V2 checkpoint.

    Conversion happens in two steps:

    1. The V1 checkpoint is rewritten into a temporary checkpoint at
       ``<dir of output_path>/temp_v1/ckpt``. This step uses the BERT V2 name
       replacements and permutations and drops variables matching "adam" or
       "Adam".
    2. A model is built from ``bert_config`` and used to write the V2
       checkpoint to ``output_path``.

    The temporary directory is removed afterwards, including when either step
    raises.

    Args:
      bert_config: BERT config object. ``num_attention_heads`` is read from it,
        and it is passed to ``_create_bert_model``.
      output_path: Path of the V2 checkpoint to write.
      v1_checkpoint: Path of the V1 checkpoint to convert.
    """
    output_dir, _ = os.path.split(output_path)
    # Both the temporary and the final checkpoint live under output_dir, so
    # make sure it exists. An empty output_dir (bare file name in output_path)
    # refers to the current directory, so there is nothing to create.
    if output_dir:
        tf.io.gfile.makedirs(output_dir)

    # Create a temporary V1 name-converted checkpoint in the output directory.
    temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
    temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
    try:
        tf1_checkpoint_converter_lib.convert(
            checkpoint_from_path=v1_checkpoint,
            checkpoint_to_path=temporary_checkpoint,
            num_heads=bert_config.num_attention_heads,
            name_replacements=(
                tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS),
            permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
            exclude_patterns=["adam", "Adam"])

        # Create a V2 checkpoint from the temporary checkpoint.
        model = _create_bert_model(bert_config)
        tf1_checkpoint_converter_lib.create_v2_checkpoint(
            model, temporary_checkpoint, output_path)
    finally:
        # Clean up the temporary checkpoint even if conversion failed, so
        # partial temp_v1 files are not left behind. If the directory was
        # never created, rmtree raises an OpError that is safe to ignore.
        try:
            tf.io.gfile.rmtree(temporary_checkpoint_dir)
        except tf.errors.OpError:
            pass
def main(_):
    """Converts a TF1 BERT checkpoint using settings taken from FLAGS.

    Reads these flags:
      - ``FLAGS.checkpoint_from_path`` and ``FLAGS.checkpoint_to_path``
      - ``FLAGS.num_heads``
      - ``FLAGS.exclude_patterns``: a comma-separated string, optional.
      - ``FLAGS.create_v2_checkpoint``: selects the BERT V2 name-replacement
        and permutation tables instead of the V1 ones.

    Args:
      _: Unused. Presumably the argv list supplied by the app runner; confirm.
    """
    exclude_patterns = None
    if FLAGS.exclude_patterns:
        # Strip whitespace and drop empty entries. Inputs such as "adam, Adam"
        # or a trailing comma would otherwise yield " Adam" or "" patterns.
        # NOTE(review): if the converter substring-matches patterns, an empty
        # pattern would exclude every variable; confirm in
        # tf1_checkpoint_converter_lib.
        candidates = [p.strip() for p in FLAGS.exclude_patterns.split(",")]
        exclude_patterns = [p for p in candidates if p] or None

    if FLAGS.create_v2_checkpoint:
        name_replacements = (
            tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS)
        permutations = tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS
    else:
        name_replacements = tf1_checkpoint_converter_lib.BERT_NAME_REPLACEMENTS
        permutations = tf1_checkpoint_converter_lib.BERT_PERMUTATIONS

    tf1_checkpoint_converter_lib.convert(
        checkpoint_from_path=FLAGS.checkpoint_from_path,
        checkpoint_to_path=FLAGS.checkpoint_to_path,
        num_heads=FLAGS.num_heads,
        name_replacements=name_replacements,
        permutations=permutations,
        exclude_patterns=exclude_patterns)
def convert_checkpoint(bert_config, output_path, v1_checkpoint):
    """Converts a V1 checkpoint into an OO V2 checkpoint.

    Conversion happens in two steps:

    1. The V1 checkpoint is rewritten into a temporary checkpoint at
       ``<dir of output_path>/temp_v1/ckpt``. This step uses the BERT V2 name
       replacements and permutations and drops variables matching "adam" or
       "Adam".
    2. A model is built from ``bert_config`` and used to write the V2
       checkpoint to ``output_path``.

    The temporary directory is removed afterwards, including when either step
    raises.

    Args:
      bert_config: BERT config object. ``num_attention_heads`` is read from it,
        and it is passed to ``_create_bert_model``.
      output_path: Path of the V2 checkpoint to write.
      v1_checkpoint: Path of the V1 checkpoint to convert.
    """
    output_dir, _ = os.path.split(output_path)
    # Both the temporary and the final checkpoint live under output_dir, so
    # make sure it exists. An empty output_dir (bare file name in output_path)
    # refers to the current directory, so there is nothing to create.
    if output_dir:
        tf.io.gfile.makedirs(output_dir)

    # Create a temporary V1 name-converted checkpoint in the output directory.
    temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
    temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
    try:
        tf1_checkpoint_converter_lib.convert(
            checkpoint_from_path=v1_checkpoint,
            checkpoint_to_path=temporary_checkpoint,
            num_heads=bert_config.num_attention_heads,
            name_replacements=(
                tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS),
            permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
            exclude_patterns=["adam", "Adam"])

        # Create a V2 checkpoint from the temporary checkpoint.
        model = _create_bert_model(bert_config)
        tf1_checkpoint_converter_lib.create_v2_checkpoint(
            model, temporary_checkpoint, output_path)
    finally:
        # Clean up the temporary checkpoint even if conversion failed, so
        # partial temp_v1 files are not left behind. If the directory was
        # never created, rmtree raises an OpError that is safe to ignore.
        try:
            tf.io.gfile.rmtree(temporary_checkpoint_dir)
        except tf.errors.OpError:
            pass
def main(args):
    """Runs ``convert`` with settings read from a parsed-arguments object.

    Args:
      args: Object exposing ``input_checkpoint``, ``output_checkpoint``,
        ``num_heads``, ``name_replacements`` and ``name_permutations``.
        These are forwarded positionally, in that order, to ``convert``.
        Presumably an ``argparse.Namespace``; confirm against the caller.
    """
    forwarded = (
        args.input_checkpoint,
        args.output_checkpoint,
        args.num_heads,
        args.name_replacements,
        args.name_permutations,
    )
    convert(*forwarded)