def testCollectionGetVariableInScope(self):
    """A variable placed only in a custom collection is found only there."""
    with tf.variable_scope("prefix") as scope:
        tf.get_variable(
            "a", shape=[1], collections=["test"], trainable=False)

    # Without a collection argument the non-trainable variable is not found.
    self.assertEqual(len(snt.get_variables_in_scope(scope)), 0)

    # (collection name, number of variables expected in it)
    expected_counts = (("test2", 0), ("test", 1))
    for collection_name, expected in expected_counts:
        found = snt.get_variables_in_scope(scope, collection=collection_name)
        self.assertEqual(len(found), expected)
def testScopeQuery(self):
    """Looking up a scope must not match sibling scopes sharing its prefix."""
    # NOTE(review): another `testScopeQuery` (which also queries by scope
    # name) is defined after this one in this source; if both live in the
    # same class, that later definition shadows this one and this test never
    # runs.
    with tf.variable_scope("prefix") as short_scope:
        var_a = tf.get_variable("a", shape=[3, 4])
    with tf.variable_scope("prefix_with_more_stuff") as long_scope:
        var_b = tf.get_variable("b", shape=[5, 6])
        var_c = tf.get_variable("c", shape=[7])

    # "prefix" is a string prefix of "prefix_with_more_stuff", so the lookup
    # must treat the scope as a full path component (trailing "/") rather
    # than a raw string prefix.
    short_result = snt.get_variables_in_scope(short_scope)
    self.assertEqual(short_result, (var_a,))
    long_result = set(snt.get_variables_in_scope(long_scope))
    self.assertEqual(long_result, {var_b, var_c})
def testScopeQuery(self):
    """Scope lookups match only that exact scope, given as object or name."""
    with tf.variable_scope("prefix") as short_scope:
        var_a = tf.get_variable("a", shape=[3, 4])
    with tf.variable_scope("prefix_with_more_stuff") as long_scope:
        var_b = tf.get_variable("b", shape=[5, 6])
        var_c = tf.get_variable("c", shape=[7])

    # "prefix" is a string prefix of "prefix_with_more_stuff", so the lookup
    # must treat the scope as a full path component (trailing "/") rather
    # than a raw string prefix. Both the scope object and its name string are
    # accepted as the query.
    for query in (short_scope, short_scope.name):
        self.assertEqual(snt.get_variables_in_scope(query), (var_a,))
    for query in (long_scope, long_scope.name):
        self.assertEqual(
            set(snt.get_variables_in_scope(query)), {var_b, var_c})
def test_variable_creation(self):
    """Every LEO sub-scope creates variables; together they are the trainables."""
    self._leo(self._problem)

    sub_scopes = (
        "leo/encoder",
        "leo/relation_network",
        "leo/decoder",
        "leo/leo_inner",
        "leo/finetuning",
    )
    scoped_variables = []
    for scope_name in sub_scopes:
        found = snt.get_variables_in_scope(scope_name)
        self.assertNotEmpty(found)
        scoped_variables.extend(found)

    # The union of all sub-scope variables must match the module's
    # trainable variables (order and duplicates ignored).
    self.assertSameElements(scoped_variables, self._leo.trainable_variables)
def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
    """Run the learned-update evaluation loop on TinyMnist.

    Builds the evaluation graph (theta process and w-learner from
    `architectures.more_local_weight_update`, with the linear-regression and
    sklearn logistic-regression meta-objectives). If a checkpoint is given,
    it restores the checkpointed variables on a fresh run. It then runs
    `train_op` until the global step reaches `num_steps`, writing merged
    summaries every `eval_every_n_steps` steps.

    Args:
      train_log_dir: Directory handed to the summary `LoggingFileWriter`.
      checkpoint_dir: Checkpoint to restore from. Either a checkpoint prefix
        or a directory, in which case the latest checkpoint inside it is
        used. A falsy value means start from scratch.
      eval_every_n_steps: Period, in global steps, for writing summaries.
      num_steps: Global step at which the loop stops.

    Raises:
      ValueError: If `checkpoint_dir` is a directory containing no
        checkpoint, or if the checkpoint lacks any variable from the
        "LocalWeightUpdateProcess" scope.
    """
    dataset_fn = datasets.mnist.TinyMnist
    w_learner_fn = (
        architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner)
    theta_process_fn = (
        architectures.more_local_weight_update.MoreLocalWeightUpdateProcess)

    meta_objectives = [
        meta_objective.linear_regression.LinearRegressionMetaObjective,
        meta_objective.sklearn.LogisticRegression,
    ]

    _, train_one_step_op, (base_model, dataset) = (
        evaluation.construct_evaluation_graph(
            theta_process_fn=theta_process_fn,
            w_learner_fn=w_learner_fn,
            dataset_fn=dataset_fn,
            meta_objectives=meta_objectives))
    # Connect the base model to a batch. The outputs are not used below, but
    # the call builds the model's part of the graph.
    batch = dataset()
    _pre_logit, _outputs = base_model(batch)

    global_step = tf.train.get_or_create_global_step()

    tf.logging.info("all vars")
    for v in tf.global_variables():
        tf.logging.info(" %s", v)

    accumulate_global_step = global_step.assign_add(1)
    reset_global_step = global_step.assign(0)

    train_op = tf.group(
        train_one_step_op, accumulate_global_step, name="train_op")

    summary_op = tf.summary.merge_all()

    file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])

    checkpoint_path = None
    saver = None
    if checkpoint_dir:
        # Saver.restore needs a checkpoint prefix; a directory has to be
        # resolved to its latest checkpoint first.
        checkpoint_path = checkpoint_dir
        if tf.gfile.IsDirectory(checkpoint_dir):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            if checkpoint_path is None:
                raise ValueError("No checkpoint found in %s" % checkpoint_dir)

        # Restore only graph variables that actually exist in the checkpoint.
        str_var_list = checkpoint_utils.list_variables(checkpoint_path)
        name_to_v_map = {v.op.name: v for v in tf.global_variables()}
        var_list = [
            name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
        ]
        saver = tf.train.Saver(var_list)

        # Every variable under "LocalWeightUpdateProcess" must be restored.
        # Use an explicit raise (not assert) so the check survives `python -O`.
        theta_variables = snt.get_variables_in_scope(
            "LocalWeightUpdateProcess", tf.GraphKeys.GLOBAL_VARIABLES)
        missed_variables = sorted(
            v.op.name for v in set(theta_variables) - set(var_list))
        if missed_variables:
            raise ValueError(
                "Missed a theta variable: %s" % ", ".join(missed_variables))

    with tf.train.SingularMonitoredSession(master="", hooks=[]) as sess:
        # global step should be restored from the evals job checkpoint or
        # zero for fresh.
        step = sess.run(global_step)

        if step == 0 and saver is not None:
            tf.logging.info("force restore")
            saver.restore(sess, checkpoint_path)
            tf.logging.info("force restore done")
            sess.run(reset_global_step)
            step = sess.run(global_step)

        while step < num_steps:
            if step % eval_every_n_steps == 0:
                summary, _, step = sess.run(
                    [summary_op, train_op, global_step])
                file_writer.add_summary(summary, step)
            else:
                _, step = sess.run([train_op, global_step])
def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
    """Run the learned-update evaluation loop on TinyMnist.

    Builds the evaluation graph (theta process and w-learner from
    `architectures.more_local_weight_update`, with the linear-regression and
    sklearn logistic-regression meta-objectives). If a checkpoint is given,
    it restores the checkpointed variables on a fresh run. It then runs
    `train_op` until the global step reaches `num_steps`, writing merged
    summaries every `eval_every_n_steps` steps.

    Args:
      train_log_dir: Directory handed to the summary `LoggingFileWriter`.
      checkpoint: Checkpoint to restore from. Either a checkpoint prefix or
        a directory, in which case the latest checkpoint inside it is used.
        A falsy value means start from scratch.
      eval_every_n_steps: Period, in global steps, for writing summaries.
      num_steps: Global step at which the loop stops.

    Raises:
      ValueError: If `checkpoint` is a directory containing no checkpoint,
        or if the checkpoint lacks any variable from the
        "LocalWeightUpdateProcess" scope.
    """
    dataset_fn = datasets.mnist.TinyMnist
    w_learner_fn = (
        architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner)
    theta_process_fn = (
        architectures.more_local_weight_update.MoreLocalWeightUpdateProcess)

    meta_objectives = [
        meta_objective.linear_regression.LinearRegressionMetaObjective,
        meta_objective.sklearn.LogisticRegression,
    ]

    _, train_one_step_op, (base_model, dataset) = (
        evaluation.construct_evaluation_graph(
            theta_process_fn=theta_process_fn,
            w_learner_fn=w_learner_fn,
            dataset_fn=dataset_fn,
            meta_objectives=meta_objectives))
    # Connect the base model to a batch. The outputs are not used below, but
    # the call builds the model's part of the graph.
    batch = dataset()
    _pre_logit, _outputs = base_model(batch)

    global_step = tf.train.get_or_create_global_step()

    tf.logging.info("all vars")
    for v in tf.global_variables():
        tf.logging.info(" %s", v)

    accumulate_global_step = global_step.assign_add(1)
    reset_global_step = global_step.assign(0)

    train_op = tf.group(
        train_one_step_op, accumulate_global_step, name="train_op")

    summary_op = tf.summary.merge_all()

    file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])

    checkpoint_path = None
    saver = None
    if checkpoint:
        # Saver.restore needs a checkpoint prefix; a directory has to be
        # resolved to its latest checkpoint first. Prefix paths pass through.
        checkpoint_path = checkpoint
        if tf.gfile.IsDirectory(checkpoint):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint)
            if checkpoint_path is None:
                raise ValueError("No checkpoint found in %s" % checkpoint)

        # Restore only graph variables that actually exist in the checkpoint.
        str_var_list = checkpoint_utils.list_variables(checkpoint_path)
        name_to_v_map = {v.op.name: v for v in tf.global_variables()}
        var_list = [
            name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
        ]
        saver = tf.train.Saver(var_list)

        # Every variable under "LocalWeightUpdateProcess" must be restored.
        # Use an explicit raise (not assert) so the check survives `python -O`.
        theta_variables = snt.get_variables_in_scope(
            "LocalWeightUpdateProcess", tf.GraphKeys.GLOBAL_VARIABLES)
        missed_variables = sorted(
            v.op.name for v in set(theta_variables) - set(var_list))
        if missed_variables:
            raise ValueError(
                "Missed a theta variable: %s" % ", ".join(missed_variables))

    with tf.train.SingularMonitoredSession(master="", hooks=[]) as sess:
        # global step should be restored from the evals job checkpoint or
        # zero for fresh.
        step = sess.run(global_step)

        if step == 0 and saver is not None:
            tf.logging.info("force restore")
            saver.restore(sess, checkpoint_path)
            tf.logging.info("force restore done")
            sess.run(reset_global_step)
            step = sess.run(global_step)

        while step < num_steps:
            if step % eval_every_n_steps == 0:
                summary, _, step = sess.run(
                    [summary_op, train_op, global_step])
                file_writer.add_summary(summary, step)
            else:
                _, step = sess.run([train_op, global_step])