Ejemplo n.º 1
0
def get_checkpoint_dumper(model_type, checkpoint_file, output_dir, remove_variables_regex):
  """Return a checkpoint dumper instance for the given model type.

  The framework-specific dumper module is imported lazily inside the
  matching branch, so only the module for the requested model type is
  loaded.

  Parameters
  ----------
  model_type : str
      Deep-learning framework of the checkpoint. Accepted values are
      exactly ``'tensorflow'`` and ``'pytorch'`` (case-sensitive).
  checkpoint_file : str
      Path to the checkpoint file.
  output_dir : str
      Path to the output directory.
  remove_variables_regex : str
      Regular expression matching variable names to be ignored.

  Returns
  -------
  TensorflowCheckpointDumper or PytorchCheckpointDumper
      Dumper built from ``checkpoint_file``, ``output_dir`` and
      ``remove_variables_regex`` for the corresponding model type.

  Raises
  ------
  ValueError
      If ``model_type`` is not one of the supported values.
  """
  if model_type == 'tensorflow':
    # Deferred import: avoids loading this module unless it is requested.
    from tensorflow_checkpoint_dumper import TensorflowCheckpointDumper

    return TensorflowCheckpointDumper(
      checkpoint_file, output_dir, remove_variables_regex)
  elif model_type == 'pytorch':
    # Deferred import: avoids loading this module unless it is requested.
    from pytorch_checkpoint_dumper import PytorchCheckpointDumper

    return PytorchCheckpointDumper(
      checkpoint_file, output_dir, remove_variables_regex)
  else:
    raise ValueError('Currently, "{}" models are not supported'.format(model_type))
Ejemplo n.º 2
0
import argparse
from tensorflow_checkpoint_dumper import TensorflowCheckpointDumper

if __name__ == '__main__':
  # Command-line entry point: parse flags, build a TensorflowCheckpointDumper
  # from them and call its build_and_dump_vars() method.
  # Local import keeps this entry-point block self-contained; sys.exit is used
  # instead of the site-provided exit() builtin, which is not guaranteed to
  # exist (e.g. under `python -S` or in embedded/frozen interpreters).
  import sys

  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--checkpoint_file',
      type=str,
      required=True,
      help='Path to the model checkpoint')
  parser.add_argument(
      '--output_dir',
      type=str,
      required=True,
      help='The output directory where to store the converted weights')
  parser.add_argument(
      '--remove_variables_regex',
      type=str,
      default='',
      help='A regular expression to match against variable names that should '
      'not be included')
  FLAGS, unparsed = parser.parse_known_args()

  if unparsed:
    # Unknown flags are rejected explicitly. Diagnostics go to stderr so they
    # are not mixed with normal output; the exit status (-1) is unchanged.
    parser.print_help(sys.stderr)
    print('Unrecognized flags: ', unparsed, file=sys.stderr)
    sys.exit(-1)

  checkpoint_dumper = TensorflowCheckpointDumper(
      FLAGS.checkpoint_file, FLAGS.output_dir, FLAGS.remove_variables_regex)
  checkpoint_dumper.build_and_dump_vars()