コード例 #1
0
def load_checkpoint(use_placeholder=False, session=None):
  """Build the data and model graphs, then restore the latest checkpoint.

  Args:
    use_placeholder: If True, the model is evaluated on
      `dataset.get_placeholders()`; otherwise on the tensors returned by
      calling `dataset()`.
    session: Optional session to restore into. When None, a new
      `tf.Session()` is created.

  Returns:
    A dict with keys "session", "model", "info", "inputs" and "dataset".

  Raises:
    ValueError: If no checkpoint is found in the checkpoint directory.
  """
  dataset = build("data")
  model = build("model")
  if use_placeholder:
    inputs = dataset.get_placeholders()
  else:
    inputs = dataset()

  # NOTE(review): model.eval presumably creates the model's variables; it must
  # run before tf.train.Saver() below, which only captures variables that
  # already exist when it is constructed.
  info = model.eval(inputs)

  # Resolve the checkpoint before creating a session so that a missing
  # checkpoint does not leave a freshly created session behind.
  checkpoint_dir = get_checkpoint_dir()
  checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
  if checkpoint_file is None:
    # latest_checkpoint returns None when the directory has no checkpoint;
    # report which directory was searched instead of failing inside restore.
    raise ValueError(
        'No checkpoint found in directory "{}"'.format(checkpoint_dir))

  if session is None:
    session = tf.Session()
  saver = tf.train.Saver()
  saver.restore(session, checkpoint_file)

  print('Successfully restored Checkpoint "{}"'.format(checkpoint_file))
  # Print a table of every global and local variable in the graph.
  variables = tf.global_variables() + tf.local_variables()
  for row in snt.format_variables(variables, join_lines=False):
    print(row)

  return {
      "session": session,
      "model": model,
      "info": info,
      "inputs": inputs,
      "dataset": dataset,
  }
コード例 #2
0
def print_model(model: snt.Module):
    """Print a model's variable table followed by its parameter counts.

    Args:
      model: Sonnet module whose `variables` and `trainable_variables` are
        reported.
    """
    print(f'{model.__class__.__name__} : {model.name}\n')
    print(snt.format_variables(model.variables))

    def _count_params(variables):
        # dtype=np.int64 keeps the per-variable product integral: np.prod of an
        # empty (scalar) shape would otherwise be the float 1.0. The outer
        # int() also turns an empty variable list into 0 rather than 0.0.
        return int(sum(np.prod(v.shape, dtype=np.int64) for v in variables))

    n_params = _count_params(model.variables)
    trainable_params = _count_params(model.trainable_variables)
    print(f'\nParams: {trainable_params} trainable out of {n_params}')
コード例 #3
0
 def testFormatVariables(self):
     """Checks format_variables on a global variable and a local one on /gpu."""
     with tf.variable_scope("m1"):
         global_var = tf.get_variable("v1", shape=[3, 4])
     # Nested context managers written as a single `with` statement.
     with tf.device("/gpu"), tf.variable_scope("m2"):
         local_var = tf.get_local_variable("v2", shape=[5, 6])
     formatted = snt.format_variables([local_var, global_var])
     self.assertEqual(formatted, _EXPECTED_FORMATTED_VARIABLE_LIST)
コード例 #4
0
ファイル: util_test.py プロジェクト: ccchang0111/sonnet
 def testFormatVariables(self, use_resource, expected):
   """Compares format_variables output with `expected`.

   Both variables honour the `use_resource` flag: "v1" is a global variable in
   scope "m1", and "v2" is a local variable in scope "m2" placed on "/gpu".
   """
   with tf.variable_scope("m1"):
     first = tf.get_variable("v1", shape=[3, 4], use_resource=use_resource)
   with tf.device("/gpu"), tf.variable_scope("m2"):
     second = tf.get_local_variable("v2", shape=[5, 6],
                                    use_resource=use_resource)
   actual = snt.format_variables([second, first])
   self.assertEqual(actual, expected)
コード例 #5
0
ファイル: util_test.py プロジェクト: geniusjiqing/sonnet
 def testFormatVariables(self):
   """Checks the formatted table for one global and one /gpu local variable."""
   with tf.variable_scope("m1"):
     var_a = tf.get_variable("v1", shape=[3, 4])
   with tf.device("/gpu"), tf.variable_scope("m2"):
     var_b = tf.get_local_variable("v2", shape=[5, 6])
   table = snt.format_variables([var_b, var_a])
   self.assertEqual(table, _EXPECTED_FORMATTED_VARIABLE_LIST)
コード例 #6
0
 def testFormatVariables(self, use_resource, expected):
   """Checks format_variables against `expected` for the given variable kind.

   `use_resource` is forwarded to both variable constructors.
   """
   with tf.variable_scope("m1"):
     global_v = tf.get_variable("v1", shape=[3, 4],
                                use_resource=use_resource)
   with tf.device("/gpu"), tf.variable_scope("m2"):
     local_v = tf.get_local_variable(
         "v2", shape=[5, 6], use_resource=use_resource)
   rendered = snt.format_variables([local_v, global_v])
   self.assertEqual(rendered, expected)
コード例 #7
0
ファイル: util_test.py プロジェクト: zxshinxz/sonnet
 def testFormatVariables(self):
   """Checks the table format_variables renders for two global variables."""
   with tf.variable_scope("m1"):
     v1 = tf.get_variable("v1", shape=[3, 4])
   with tf.variable_scope("m2"):
     v2 = tf.get_variable("v2", shape=[5, 6])
   # assertEquals is a deprecated alias that was removed in Python 3.12.
   # The variables are passed as [v2, v1], but the expected table lists
   # m1/v1 first.
   self.assertEqual(snt.format_variables([v2, v1]),
                    ("Variable  Shape  Type\n"
                     "m1/v1:0   3x4    tf.float32\n"
                     "m2/v2:0   5x6    tf.float32"))