コード例 #1
0
ファイル: lib_tfsampling.py プロジェクト: scienart200/magenta
  def __init__(self, chkpt_path, placeholders=None):
    """Loads hparams, sets up placeholders, builds the graph and a session.

    Unlike a lazy constructor, everything happens eagerly here: the sampling
    graph is built and `instantiate_sess_and_restore_checkpoint()` is called
    before the constructor returns.

    Args:
      chkpt_path: Checkpoint location; stored on `self.chkpt_path` and passed
          to `lib_hparams.load_hparams`.
      placeholders: Optional placeholders. When None, `self.get_placeholders()`
          is used to create them.
    """
    self.chkpt_path = chkpt_path
    self.hparams = lib_hparams.load_hparams(chkpt_path)
    # Resolved only after self.hparams is set, in case get_placeholders()
    # reads it.
    self.placeholders = (
        self.get_placeholders() if placeholders is None else placeholders)

    # The graph is built first; the session is created afterwards from
    # instantiate_sess_and_restore_checkpoint().
    self.build_sample_graph()
    self.sess = self.instantiate_sess_and_restore_checkpoint()
コード例 #2
0
ファイル: lib_graph.py プロジェクト: umbanhowar/coconet
def load_checkpoint(path, instantiate_sess=True):
    """Builds graph, loads checkpoint, and returns wrapped model.

    Args:
      path: Checkpoint directory. Hyperparameters are read from it with
          `lib_hparams.load_hparams`, and weights are restored from
          `<path>/best_model.ckpt`.
      instantiate_sess: If False, return the wrapped model as soon as the
          graph is built, without creating a session or restoring weights.
          Defaults to True, which matches the previous behavior.

    Returns:
      A `lib_tfutil.WrappedModel`. When `instantiate_sess` is True, its `sess`
      attribute is a `tf.Session` into which the checkpoint was restored.

    Raises:
      Any exception raised while creating the `tf.train.Saver` or restoring
      the checkpoint is re-raised after the new session has been closed.
    """
    # Use tf.logging (not print) so output obeys the logging verbosity and is
    # consistent with the rest of the module's logging.
    tf.logging.info('Loading checkpoint from %s', path)
    hparams = lib_hparams.load_hparams(path)
    model = build_graph(is_training=False, hparams=hparams)
    wmodel = lib_tfutil.WrappedModel(model, model.loss.graph, hparams)
    if not instantiate_sess:
        return wmodel
    with wmodel.graph.as_default():
        sess = tf.Session()
        try:
            saver = tf.train.Saver()
            tf.logging.info('loading checkpoint %s', path)
            chkpt_path = os.path.join(path, 'best_model.ckpt')
            saver.restore(sess, chkpt_path)
        except Exception:
            # Without this, a failed restore would leave an open session
            # (and its resources) behind with no owner.
            sess.close()
            raise
        wmodel.sess = sess
    return wmodel
コード例 #3
0
ファイル: lib_graph.py プロジェクト: czhuang/magenta-autofill
def load_checkpoint(path, instantiate_sess=True):
  """Builds graph, loads checkpoint, and returns wrapped model.

  Args:
    path: Checkpoint directory. Hyperparameters are read from it with
        `lib_hparams.load_hparams`, and weights are restored from
        `<path>/best_model.ckpt`.
    instantiate_sess: If False, return the wrapped model as soon as the graph
        is built, without creating a session or restoring weights.

  Returns:
    A `lib_tfutil.WrappedModel`. When `instantiate_sess` is True, its `sess`
    attribute is a `tf.Session` into which the checkpoint was restored.

  Raises:
    Any exception raised while creating the `tf.train.Saver` or restoring the
    checkpoint is re-raised after the new session has been closed.
  """
  tf.logging.info('Loading checkpoint from %s', path)
  hparams = lib_hparams.load_hparams(path)
  model = build_graph(is_training=False, hparams=hparams)
  wmodel = lib_tfutil.WrappedModel(model, model.loss.graph, hparams)
  if not instantiate_sess:
    return wmodel
  with wmodel.graph.as_default():
    sess = tf.Session()
    try:
      saver = tf.train.Saver()
      tf.logging.info('loading checkpoint %s', path)
      chkpt_path = os.path.join(path, 'best_model.ckpt')
      saver.restore(sess, chkpt_path)
    except Exception:
      # Close the session on failure so a bad checkpoint does not leak it.
      sess.close()
      raise
    wmodel.sess = sess
  return wmodel
コード例 #4
0
ファイル: lib_tfsampling.py プロジェクト: jrysnrt/magenta
  def __init__(self, chkpt_path, placeholders=None):
    """Sets up the inputs of the Coconet sampling graph, without building it.

    Building the graph and restoring the checkpoint are deferred. They happen
    lazily when run() is called, or explicitly via
    instantiate_sess_and_restore_checkpoint.

    Args:
      chkpt_path: Checkpoint directory to load the model from; the latest
          checkpoint in it is used.
      placeholders: Optional placeholders. If omitted, get_placeholders()
          creates them.
    """
    self.chkpt_path = chkpt_path
    self.hparams = lib_hparams.load_hparams(chkpt_path)
    self.placeholders = (
        placeholders if placeholders is not None
        else self.get_placeholders())
    # Nothing is built yet. These presumably get populated once the graph is
    # built and the session is restored (not visible here; confirm).
    self.samples = None
    self.sess = None
コード例 #5
0
  def __init__(self, chkpt_path, placeholders=None):
    """Records configuration for the Coconet sampling graph; builds nothing.

    No graph is built and no checkpoint is restored here. That work happens
    lazily inside run(), or on demand through
    instantiate_sess_and_restore_checkpoint.

    Args:
      chkpt_path: Directory holding the model checkpoint; the latest
          checkpoint is the one used.
      placeholders: Optional placeholders; get_placeholders() supplies them
          when this is None.
    """
    self.chkpt_path = chkpt_path
    self.hparams = lib_hparams.load_hparams(chkpt_path)
    # Fall back to freshly created placeholders only after hparams is set,
    # since get_placeholders() may depend on it.
    if placeholders is None:
      placeholders = self.get_placeholders()
    self.placeholders = placeholders
    # Deferred state: stays None until the graph/session are set up elsewhere.
    self.samples = self.sess = None