Code example #1
File: util_test.py Project: ccchang0111/sonnet
  def testCollectionGetVariableInScope(self):
    with tf.variable_scope("prefix") as s1:
      tf.get_variable("a", shape=[1], collections=["test"], trainable=False)

    self.assertEqual(len(snt.get_variables_in_scope(s1)), 0)
    self.assertEqual(len(snt.get_variables_in_scope(s1, collection="test2")), 0)
    self.assertEqual(len(snt.get_variables_in_scope(s1, collection="test")), 1)
Code example #2
File: util_test.py Project: geniusjiqing/sonnet
  def testScopeQuery(self):
    with tf.variable_scope("prefix") as s1:
      v1 = tf.get_variable("a", shape=[3, 4])
    with tf.variable_scope("prefix_with_more_stuff") as s2:
      v2 = tf.get_variable("b", shape=[5, 6])
      v3 = tf.get_variable("c", shape=[7])

    # get_variables_in_scope should add a "/" to only search that scope, not
    # any others which share the same prefix.
    self.assertEqual(snt.get_variables_in_scope(s1), (v1,))
    self.assertEqual(set(snt.get_variables_in_scope(s2)), {v2, v3})
    self.assertEqual(snt.get_variables_in_scope(s1.name), (v1,))
    self.assertEqual(set(snt.get_variables_in_scope(s2.name)), {v2, v3})
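The trailing "/" matters because a raw tf.get_collection query treats its scope argument as a pattern matched with re.match, so a bare "prefix" also picks up variables from "prefix_with_more_stuff". A hedged sketch of the contrast, with illustrative names:

import sonnet as snt
import tensorflow as tf

with tf.variable_scope("prefix"):
  v1 = tf.get_variable("a", shape=[1])
with tf.variable_scope("prefix_with_more_stuff"):
  v2 = tf.get_variable("b", shape=[1])

# Raw collection lookup leaks into sibling scopes that share the prefix.
leaky = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="prefix")
assert set(leaky) == {v1, v2}

# Sonnet appends "/" internally, so only the exact scope is searched.
assert snt.get_variables_in_scope("prefix") == (v1,)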
Code example #3
  def test_variable_creation(self):
    self._leo(self._problem)
    encoder_variables = snt.get_variables_in_scope("leo/encoder")
    self.assertNotEmpty(encoder_variables)
    relation_network_variables = snt.get_variables_in_scope(
        "leo/relation_network")
    self.assertNotEmpty(relation_network_variables)
    decoder_variables = snt.get_variables_in_scope("leo/decoder")
    self.assertNotEmpty(decoder_variables)
    inner_lr = snt.get_variables_in_scope("leo/leo_inner")
    self.assertNotEmpty(inner_lr)
    finetuning_lr = snt.get_variables_in_scope("leo/finetuning")
    self.assertNotEmpty(finetuning_lr)
    self.assertSameElements(
        encoder_variables + relation_network_variables +
        decoder_variables + inner_lr + finetuning_lr,
        self._leo.trainable_variables)
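This test checks that the per-scope variable groups of a model jointly cover its trainable variables. A minimal sketch of the same pattern with stock Sonnet modules, where the module and scope names are assumptions chosen for illustration:

import sonnet as snt
import tensorflow as tf

with tf.variable_scope("leo"):
  encoder = snt.Linear(8, name="encoder")
  decoder = snt.Linear(4, name="decoder")
  decoder(encoder(tf.zeros([1, 16])))

enc_vars = snt.get_variables_in_scope("leo/encoder")
dec_vars = snt.get_variables_in_scope("leo/decoder")
# Together the submodule groups account for every trainable variable under "leo".
assert set(enc_vars + dec_vars) == set(snt.get_variables_in_scope("leo"))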
Code example #4
File: run_eval.py Project: ALISCIFP/models
def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
  dataset_fn = datasets.mnist.TinyMnist
  w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
  theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

  meta_objectives = []
  meta_objectives.append(
      meta_objective.linear_regression.LinearRegressionMetaObjective)
  meta_objectives.append(meta_objective.sklearn.LogisticRegression)

  checkpoint_vars, train_one_step_op, (
      base_model, dataset) = evaluation.construct_evaluation_graph(
          theta_process_fn=theta_process_fn,
          w_learner_fn=w_learner_fn,
          dataset_fn=dataset_fn,
          meta_objectives=meta_objectives)
  batch = dataset()
  pre_logit, outputs = base_model(batch)

  global_step = tf.train.get_or_create_global_step()
  var_list = list(
      snt.get_variables_in_module(base_model, tf.GraphKeys.TRAINABLE_VARIABLES))

  tf.logging.info("all vars")
  for v in tf.global_variables():
    tf.logging.info("   %s" % str(v))
  global_step = tf.train.get_global_step()
  accumulate_global_step = global_step.assign_add(1)
  reset_global_step = global_step.assign(0)

  train_op = tf.group(
      train_one_step_op, accumulate_global_step, name="train_op")

  summary_op = tf.summary.merge_all()

  file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])
  if checkpoint:
    str_var_list = checkpoint_utils.list_variables(checkpoint)
    name_to_v_map = {v.op.name: v for v in tf.global_variables()}
    var_list = [
        name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
    ]
    saver = tf.train.Saver(var_list)
    missed_variables = [
        v.op.name for v in set(
            snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                       tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_list)
    ]
    assert len(missed_variables) == 0, "Missed a theta variable."

  hooks = []

  with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:

    # The global step is restored from the eval job's checkpoint, or is zero for a fresh run.
    step = sess.run(global_step)

    if step == 0 and checkpoint:
      tf.logging.info("force restore")
      saver.restore(sess, checkpoint)
      tf.logging.info("force restore done")
      sess.run(reset_global_step)
      step = sess.run(global_step)

    while step < num_steps:
      if step % eval_every_n_steps == 0:
        s, _, step = sess.run([summary_op, train_op, global_step])
        file_writer.add_summary(s, step)
      else:
        _, step = sess.run([train_op, global_step])
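The checkpoint branch above restores only the variables present in both the checkpoint and the current graph, matched by op name. A stand-alone sketch of that pattern, assuming a plain TF1 graph and using tf.train.list_variables in place of the project's checkpoint_utils helper:

import tensorflow as tf

def build_partial_saver(checkpoint_path):
  # (name, shape) pairs for everything stored in the checkpoint.
  ckpt_vars = tf.train.list_variables(checkpoint_path)
  # Map op names of the graph's variables to the variable objects.
  name_to_var = {v.op.name: v for v in tf.global_variables()}
  # Keep only variables that exist on both sides.
  restorable = [name_to_var[name] for name, _ in ckpt_vars
                if name in name_to_var]
  return tf.train.Saver(restorable)

Restoring with the returned saver (saver.restore(sess, checkpoint_path)) then tolerates checkpoints written by a graph with extra or missing variables.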