Example #1
0
def export(checkpoint, destination, use_tf_sampling):
    """Load a model from a checkpoint and export it as a SavedModel.

    Args:
      checkpoint: Checkpoint location handed to the model loader.
      destination: Target passed through to
          `lib_saved_model.export_saved_model`.
      use_tf_sampling: When truthy, the model is a
          `lib_tfsampling.CoconetSampleGraph` whose session is instantiated
          and restored from the checkpoint; otherwise the model comes from
          `lib_graph.load_checkpoint`. The flag is also forwarded to the
          exporter.
    """
    if not use_tf_sampling:
        model = lib_graph.load_checkpoint(checkpoint)
    else:
        model = lib_tfsampling.CoconetSampleGraph(checkpoint)
        model.instantiate_sess_and_restore_checkpoint()
    tf.logging.info('Loaded graph.')

    # The export is always tagged for serving.
    serving_tags = [tf.saved_model.tag_constants.SERVING]
    lib_saved_model.export_saved_model(
        model, destination, serving_tags, use_tf_sampling)
Example #2
0
  def __init__(self, checkpoint_path):
    """Sets up the sampler, its hparams and the pianoroll encoder/decoder.

    Args:
      checkpoint_path: A string that gives the full path to the folder that
          holds the checkpoint.
    """
    sampler = lib_tfsampling.CoconetSampleGraph(checkpoint_path)
    hparams = sampler.hparams

    self.sampler = sampler
    self.hparams = hparams
    # The encoder/decoder is derived from the sampler's hparams.
    self.endecoder = lib_pianoroll.get_pianoroll_encoder_decoder(hparams)

    # Both start out unset (None) at construction time.
    self._time_taken = None
    self._pianorolls = None
Example #3
0
def main(unused_argv):
    """Export the model named by the command-line flags as a SavedModel.

    Reads `FLAGS.checkpoint`, `FLAGS.destination` and
    `FLAGS.use_tf_sampling`, loads the corresponding model and hands it to
    `lib_saved_model.export_saved_model` with the SERVING tag.

    Args:
      unused_argv: Ignored.

    Raises:
      ValueError: If `FLAGS.checkpoint` or `FLAGS.destination` is None or
          otherwise falsy (e.g. an empty string).
    """
    checkpoint = FLAGS.checkpoint
    if not checkpoint:  # Covers both None and the empty string.
        raise ValueError('Need to provide a path to checkpoint directory.')

    destination = FLAGS.destination
    if not destination:
        raise ValueError(
            'Need to provide a destination directory for the SavedModel.')

    use_tf_sampling = FLAGS.use_tf_sampling
    if not use_tf_sampling:
        model = lib_graph.load_checkpoint(checkpoint)
    else:
        model = lib_tfsampling.CoconetSampleGraph(checkpoint)
        model.instantiate_sess_and_restore_checkpoint()
    tf.logging.info('Loaded graph.')

    serving_tags = [tf.saved_model.tag_constants.SERVING]
    lib_saved_model.export_saved_model(
        model, destination, serving_tags, use_tf_sampling)
    tf.logging.info('Exported SavedModel to %s.', destination)