Пример #1
0
def bootstrap(working_dir, params):
    """Initialize a tf.Estimator run with freshly initialized weights.

    Builds the model graph in PREDICT mode, runs the variable initializers
    and writes the result to a checkpoint named ``model.ckpt-1`` inside
    ``working_dir``, the name a subsequently created Estimator expects.

    Args:
      working_dir: The directory where tf.estimator will drop logs,
        checkpoints, and so on.
      params: hyperparams of the model.
    """
    # Forge an initial checkpoint with the name that subsequent Estimator will
    # expect to find.
    estimator_initial_checkpoint_name = 'model.ckpt-1'
    save_file = os.path.join(working_dir, estimator_initial_checkpoint_name)
    # The session is only needed to produce the checkpoint. Using it as a
    # context manager guarantees it is closed (releasing its resources) even
    # if graph construction or saving raises.
    with tf.Session() as sess, sess.graph.as_default():
        input_features, labels = get_inference_input(params)
        dualnet_model.model_fn(input_features, labels,
                               tf.estimator.ModeKeys.PREDICT, params)
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, save_file)
Пример #2
0
def bootstrap(working_dir, params):
  """Initialize a tf.Estimator run with freshly initialized weights.

  Builds the model graph in PREDICT mode, runs the variable initializers
  and writes the result to a checkpoint named ``model.ckpt-1`` inside
  ``working_dir``, the name a subsequently created Estimator expects.

  Args:
    working_dir: The directory where tf.estimator will drop logs,
      checkpoints, and so on.
    params: hyperparams of the model.
  """
  # Forge an initial checkpoint with the name that subsequent Estimator will
  # expect to find.
  estimator_initial_checkpoint_name = 'model.ckpt-1'
  save_file = os.path.join(working_dir,
                           estimator_initial_checkpoint_name)
  # The session is only needed to produce the checkpoint. Using it as a
  # context manager guarantees it is closed (releasing its resources) even
  # if graph construction or saving raises.
  with tf.Session() as sess, sess.graph.as_default():
    input_features, labels = get_inference_input(params)
    dualnet_model.model_fn(
        input_features, labels, tf.estimator.ModeKeys.PREDICT, params)
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, save_file)
 def initialize_graph(self):
   """Build the PREDICT-mode model graph and set up its weights.

   The model is constructed on ``self.sess.graph``. Its input features and
   prediction outputs are stored as ``self.inference_input`` and
   ``self.inference_output``. When ``self.save_file`` is set, weights are
   loaded via ``self.initialize_weights``. Otherwise the initializers of all
   global variables are run.
   """
   graph = self.sess.graph
   with graph.as_default():
     features, labels = get_inference_input(self.hparams)
     spec = dualnet_model.model_fn(
         features, labels, tf.estimator.ModeKeys.PREDICT, self.hparams)
     self.inference_input = features
     self.inference_output = spec.predictions
     if self.save_file is None:
       # No save file configured: fall back to the variables' own
       # initializers.
       self.sess.run(tf.global_variables_initializer())
     else:
       self.initialize_weights(self.save_file)
Пример #4
0
 def initialize_graph(self):
   """Build the PREDICT-mode model graph and set up its weights.

   The model is constructed on ``self.sess.graph``. Its input features and
   prediction outputs are stored as ``self.inference_input`` and
   ``self.inference_output``. When ``self.save_file`` is set, weights are
   loaded via ``self.initialize_weights``. Otherwise the initializers of all
   global variables are run.
   """
   graph = self.sess.graph
   with graph.as_default():
     features, labels = get_inference_input(self.hparams)
     spec = dualnet_model.model_fn(
         features, labels, tf.estimator.ModeKeys.PREDICT, self.hparams)
     self.inference_input = features
     self.inference_output = spec.predictions
     if self.save_file is None:
       # No save file configured: fall back to the variables' own
       # initializers.
       self.sess.run(tf.global_variables_initializer())
     else:
       self.initialize_weights(self.save_file)