def _monkey_patch_context(iteration_step_scope, scoped_summary, trainable_vars):
  """Monkey-patches global attributes with subnetwork-specific ones.

  While active, the `get_global_step` and `get_or_create_global_step`
  accessors on every module in `patched_modules` return this subnetwork's
  int64 `iteration_step` variable. The trainable-variables collection is set
  to `trainable_vars` plus any TPUEmbedding dummy table variables. Summaries
  are routed through `monkey_patched_summaries(scoped_summary)`. Everything
  is reverted on exit, including when setup or teardown raises.

  NOTE(review): this is a single-`yield` generator; it is presumably wrapped
  with `contextlib.contextmanager` on a line outside this view. Confirm.

  Args:
    iteration_step_scope: Variable scope under which `iteration_step` is
      created or reused (AUTO_REUSE).
    scoped_summary: Summary object handed to `monkey_patched_summaries`.
    trainable_vars: List of variables to expose as the trainable-variables
      collection while the context is active.

  Yields:
    Nothing; control passes to the caller's `with` body.
  """
  old_get_global_step_fn = tf_compat.v1.train.get_global_step
  old_get_or_create_global_step_fn = (
      tf_compat.v1.train.get_or_create_global_step)
  old_trainable_vars = tf_compat.v1.trainable_variables()

  def iteration_step(graph=None):
    """Creates or reuses the scalar int64 `iteration_step` variable.

    Args:
      graph: Graph to create the variable in; defaults to the current
        default graph.

    Returns:
      The non-trainable `iteration_step` variable under
      `iteration_step_scope`.
    """
    graph = graph or tf_compat.v1.get_default_graph()
    # Enter the root name scope so ops created here are not nested under
    # whatever name scope the caller happens to be in.
    with graph.as_default() as g, g.name_scope(None):
      with tf_compat.v1.variable_scope(
          iteration_step_scope, reuse=tf_compat.v1.AUTO_REUSE):
        return tf_compat.v1.get_variable(
            "iteration_step",
            shape=[],
            initializer=tf_compat.v1.zeros_initializer(),
            trainable=False,
            dtype=tf.int64)

  # Every module through which callers may reach the global-step accessors.
  patched_modules = (tf_compat.v1.train, tf_v1.train, tf.train, train,
                     training_util)
  # The outer try guarantees the accessors are restored even if the
  # trainable-variable setup below, or its teardown, raises.
  try:
    for module in patched_modules:
      setattr(module, "get_global_step", iteration_step)
      setattr(module, "get_or_create_global_step", iteration_step)

    # The TPUEmbedding uses dummy variables to coordinate sending and
    # receiving gradients. If no gradients are computed on these dummy
    # variables, the TPUEmbedding will throw an error.
    embedding_variables = tf_compat.v1.get_collection(
        "tpu_embedding_dummy_table_variables")
    _set_trainable_variables(trainable_vars + embedding_variables)
    try:
      with monkey_patched_summaries(scoped_summary):
        yield
    finally:
      new_trainable_vars = _new_trainable_variables(trainable_vars)
      _set_trainable_variables(old_trainable_vars + new_trainable_vars)
  finally:
    # Revert monkey-patches.
    for module in reversed(patched_modules):
      setattr(module, "get_or_create_global_step",
              old_get_or_create_global_step_fn)
      setattr(module, "get_global_step", old_get_global_step_fn)
def _subnetwork_context(iteration_step_scope, scoped_summary):
  """Monkey-patches global attributes with subnetwork-specifics ones.

  Within the context, `get_global_step` and `get_or_create_global_step` on
  both `tf.train` and `training_util` resolve to a per-subnetwork int64
  `iteration_step` variable. Summaries go through
  `monkey_patched_summaries(scoped_summary)`. Both accessors are put back to
  the functions `tf.train` exposed on entry.

  Args:
    iteration_step_scope: Variable scope used (with AUTO_REUSE) for the
      `iteration_step` variable.
    scoped_summary: Summary object handed to `monkey_patched_summaries`.

  Yields:
    Nothing; control passes to the caller's `with` body.
  """
  saved_get_step = tf.train.get_global_step
  saved_get_or_create_step = tf.train.get_or_create_global_step

  def iteration_step(graph=None):
    """Creates or reuses the scalar int64 `iteration_step` variable."""
    # The graph argument is ignored; the variable is made in the current
    # default graph.
    del graph
    with tf.variable_scope(iteration_step_scope, reuse=tf.AUTO_REUSE):
      return tf.get_variable(
          "iteration_step",
          shape=[],
          initializer=tf.zeros_initializer(),
          trainable=False,
          dtype=tf.int64)

  patched_modules = (tf.train, training_util)
  for module in patched_modules:
    module.get_global_step = iteration_step
    module.get_or_create_global_step = iteration_step
  try:
    with monkey_patched_summaries(scoped_summary):
      yield
  finally:
    # Undo the patches in the reverse order they were applied.
    for module in reversed(patched_modules):
      module.get_or_create_global_step = saved_get_or_create_step
      module.get_global_step = saved_get_step
def test_monkey_patched_summaries_kwargs(self, summary_maker):
  """Keyword-argument summary calls are routed to the scoped summary.

  Checks three things:
  - Inside `monkey_patched_summaries`, every function returned by
    `_summaries()` differs from its pre-context counterpart.
  - The originals are back after the context exits.
  - The number of summary ops recorded for `summary` equals the number of
    summary calls made.

  The `tf.contrib.summary` calls are skipped when `tf.contrib` is
  unavailable (TF 2.0).
  """
  summary = summary_maker()
  before = _summaries()
  with monkey_patched_summaries(summary):
    for want, got in zip(before, _summaries()):
      self.assertNotEqual(want, got)
    tf.summary.scalar(
        name="scalar", tensor=1, collections=["collection"], family="family")
    tf.summary.image(
        name="image", tensor=1, max_outputs=3, collections=["collection"],
        family="family")
    tf.summary.histogram(
        name="histogram", values=1, collections=["collection"],
        family="family")
    tf.summary.audio(
        name="audio", tensor=1, sample_rate=3, max_outputs=3,
        collections=["collection"], family="family")
    want_summary_fn_count = 4
    # TF 2.0 eliminates tf.contrib. Only the lookup is guarded: an
    # AttributeError raised by a patched contrib summary itself is a real
    # failure and must not be mistaken for a missing tf.contrib.
    try:
      contrib_summary = tf.contrib.summary
    except (AttributeError, ImportError):
      contrib_summary = None
    if contrib_summary is not None:
      contrib_summary.scalar(
          name="scalar_v2", tensor=1, family="family", step=10)
      contrib_summary.image(
          name="image_v2", tensor=1, bad_color=True, max_images=3,
          family="family", step=10)
      contrib_summary.histogram(
          name="histogram_v2", tensor=1, family="family", step=10)
      contrib_summary.audio(
          name="audio_v2", tensor=1, sample_rate=3, max_outputs=3,
          family="family", step=10)
      want_summary_fn_count += 4
  self.assertEqual(before, _summaries())
  self.assertLen(self._get_summary_ops(summary), want_summary_fn_count)
def test_monkey_patched_summaries_args(self, summary_maker):
  """Positional-argument summary calls are all collected by the summary.

  Issues four `tf.summary` calls and four `tf.contrib.summary` calls inside
  `monkey_patched_summaries`. Expects `summary.merge_all()` to hold eight
  entries.
  """
  summary = summary_maker()
  with monkey_patched_summaries(summary):
    # Resolve the functions inside the context so whatever
    # `monkey_patched_summaries` installed is what gets called.
    v1_calls = (
        (tf.summary.scalar, ("scalar", 1, ["collection"], "family")),
        (tf.summary.image, ("image", 1, 3, ["collection"], "family")),
        (tf.summary.histogram, ("histogram", 1, ["collection"], "family")),
        (tf.summary.audio, ("audio", 1, 3, 3, ["collection"], "family")),
    )
    for summary_fn, positional in v1_calls:
      summary_fn(*positional)
    contrib_calls = (
        (tf.contrib.summary.scalar, ("scalar_v2", 1, "family", 10)),
        (tf.contrib.summary.image, ("image_v2", 1, True, 3, "family", 10)),
        (tf.contrib.summary.histogram, ("histogram_v2", 1, "family", 10)),
        (tf.contrib.summary.audio, ("audio_v2", 1, 3, 3, "family", 10)),
    )
    for summary_fn, positional in contrib_calls:
      summary_fn(*positional)
  self.assertLen(summary.merge_all(), 8)
def test_monkey_patched_summaries_args(self):
  """Positional-argument summary calls land in a `_ScopedSummary`.

  Builds a `_ScopedSummary` at global step 10 and makes four `tf.summary`
  and four `tf.contrib.summary` calls inside `monkey_patched_summaries`.
  Expects `merge_all()` to return eight entries.
  """
  summary = _ScopedSummary(self.test_subdirectory, global_step=10)
  with monkey_patched_summaries(summary):
    # Resolve the functions inside the context so whatever
    # `monkey_patched_summaries` installed is what gets called.
    v1_calls = (
        (tf.summary.scalar, ("scalar", 1, ["collection"], "family")),
        (tf.summary.image, ("image", 1, 3, ["collection"], "family")),
        (tf.summary.histogram, ("histogram", 1, ["collection"], "family")),
        (tf.summary.audio, ("audio", 1, 3, 3, ["collection"], "family")),
    )
    for summary_fn, positional in v1_calls:
      summary_fn(*positional)
    contrib_calls = (
        (tf.contrib.summary.scalar, ("scalar_v2", 1, "family", 10)),
        (tf.contrib.summary.image, ("image_v2", 1, True, 3, "family", 10)),
        (tf.contrib.summary.histogram, ("histogram_v2", 1, "family", 10)),
        (tf.contrib.summary.audio, ("audio_v2", 1, 3, 3, "family", 10)),
    )
    for summary_fn, positional in contrib_calls:
      summary_fn(*positional)
  self.assertLen(summary.merge_all(), 8)
def _monkey_patch_context(iteration_step_scope, scoped_summary, trainable_vars):
  """Monkey-patches global attributes with subnetwork-specific ones.

  While active, the `get_global_step` and `get_or_create_global_step`
  accessors on every module in `patched_modules` return this subnetwork's
  int64 `iteration_step` variable. The trainable-variables collection is set
  to `trainable_vars`. Summaries are routed through
  `monkey_patched_summaries(scoped_summary)`. Everything is reverted on
  exit, including when setup or teardown raises.

  NOTE(review): this is a single-`yield` generator; it is presumably wrapped
  with `contextlib.contextmanager` on a line outside this view. Confirm.

  Args:
    iteration_step_scope: Variable scope under which `iteration_step` is
      created or reused (AUTO_REUSE).
    scoped_summary: Summary object handed to `monkey_patched_summaries`.
    trainable_vars: List of variables to expose as the trainable-variables
      collection while the context is active.

  Yields:
    Nothing; control passes to the caller's `with` body.
  """
  old_get_global_step_fn = tf_compat.v1.train.get_global_step
  old_get_or_create_global_step_fn = (
      tf_compat.v1.train.get_or_create_global_step)
  old_trainable_vars = tf_compat.v1.trainable_variables()

  def iteration_step(graph=None):
    """Creates or reuses the scalar int64 `iteration_step` variable.

    Args:
      graph: Graph to create the variable in; defaults to the current
        default graph.

    Returns:
      The non-trainable `iteration_step` variable under
      `iteration_step_scope`.
    """
    graph = graph or tf_compat.v1.get_default_graph()
    # Enter the root name scope so ops created here are not nested under
    # whatever name scope the caller happens to be in.
    with graph.as_default() as g, g.name_scope(None):
      with tf_compat.v1.variable_scope(
          iteration_step_scope, reuse=tf_compat.v1.AUTO_REUSE):
        return tf_compat.v1.get_variable(
            "iteration_step",
            shape=[],
            initializer=tf_compat.v1.zeros_initializer(),
            trainable=False,
            dtype=tf.int64)

  # Every module through which callers may reach the global-step accessors.
  patched_modules = (tf_compat.v1.train, tf.train, train, training_util)
  # The outer try guarantees the accessors are restored even if the
  # trainable-variable setup below, or its teardown, raises.
  try:
    for module in patched_modules:
      setattr(module, "get_global_step", iteration_step)
      setattr(module, "get_or_create_global_step", iteration_step)

    _set_trainable_variables(trainable_vars)
    try:
      with monkey_patched_summaries(scoped_summary):
        yield
    finally:
      new_trainable_vars = _new_trainable_variables(trainable_vars)
      _set_trainable_variables(old_trainable_vars + new_trainable_vars)
  finally:
    # Revert monkey-patches.
    for module in reversed(patched_modules):
      setattr(module, "get_or_create_global_step",
              old_get_or_create_global_step_fn)
      setattr(module, "get_global_step", old_get_global_step_fn)
def test_monkey_patched_summaries_args(self, summary_maker):
  """Positional-argument summary calls are patched, scoped and restored.

  Checks three things:
  - `_summaries()` changes inside `monkey_patched_summaries`.
  - It is equal to its original value after the context exits.
  - Eight summary ops were recorded for `summary`.
  """
  summary = summary_maker()
  original_fns = _summaries()
  with monkey_patched_summaries(summary):
    self.assertNotEqual(original_fns, _summaries())
    # Resolve the functions inside the context so whatever
    # `monkey_patched_summaries` installed is what gets called.
    v1_calls = (
        (tf.summary.scalar, ("scalar", 1, ["collection"], "family")),
        (tf.summary.image, ("image", 1, 3, ["collection"], "family")),
        (tf.summary.histogram, ("histogram", 1, ["collection"], "family")),
        (tf.summary.audio, ("audio", 1, 3, 3, ["collection"], "family")),
    )
    for summary_fn, positional in v1_calls:
      summary_fn(*positional)
    contrib_calls = (
        (tf.contrib.summary.scalar, ("scalar_v2", 1, "family", 10)),
        (tf.contrib.summary.image, ("image_v2", 1, True, 3, "family", 10)),
        (tf.contrib.summary.histogram, ("histogram_v2", 1, "family", 10)),
        (tf.contrib.summary.audio, ("audio_v2", 1, 3, 3, "family", 10)),
    )
    for summary_fn, positional in contrib_calls:
      summary_fn(*positional)
  self.assertEqual(original_fns, _summaries())
  self.assertLen(self._get_summary_ops(summary), 8)
def test_monkey_patched_summaries_kwargs(self, summary_maker):
  """Keyword-argument summary calls are patched, scoped and restored.

  Checks three things:
  - `_summaries()` changes inside `monkey_patched_summaries`.
  - It is equal to its original value after the context exits.
  - Eight summary ops were recorded for `summary`.
  """
  summary = summary_maker()
  original_fns = _summaries()
  with monkey_patched_summaries(summary):
    self.assertNotEqual(original_fns, _summaries())
    # Resolve the functions inside the context so whatever
    # `monkey_patched_summaries` installed is what gets called.
    v1_calls = (
        (tf.summary.scalar,
         dict(name="scalar", tensor=1, collections=["collection"],
              family="family")),
        (tf.summary.image,
         dict(name="image", tensor=1, max_outputs=3,
              collections=["collection"], family="family")),
        (tf.summary.histogram,
         dict(name="histogram", values=1, collections=["collection"],
              family="family")),
        (tf.summary.audio,
         dict(name="audio", tensor=1, sample_rate=3, max_outputs=3,
              collections=["collection"], family="family")),
    )
    for summary_fn, keywords in v1_calls:
      summary_fn(**keywords)
    contrib_calls = (
        (tf.contrib.summary.scalar,
         dict(name="scalar_v2", tensor=1, family="family", step=10)),
        (tf.contrib.summary.image,
         dict(name="image_v2", tensor=1, bad_color=True, max_images=3,
              family="family", step=10)),
        (tf.contrib.summary.histogram,
         dict(name="histogram_v2", tensor=1, family="family", step=10)),
        (tf.contrib.summary.audio,
         dict(name="audio_v2", tensor=1, sample_rate=3, max_outputs=3,
              family="family", step=10)),
    )
    for summary_fn, keywords in contrib_calls:
      summary_fn(**keywords)
  self.assertEqual(original_fns, _summaries())
  self.assertLen(self._get_summary_ops(summary), 8)
def test_monkey_patched_summaries_kwargs(self):
  """Keyword-argument summary calls land in a `_ScopedSummary`.

  Builds a `_ScopedSummary` at global step 10 and makes four `tf.summary`
  and four `tf.contrib.summary` keyword calls inside
  `monkey_patched_summaries`. Expects `merge_all()` to return eight entries.
  """
  summary = _ScopedSummary(self.test_subdirectory, global_step=10)
  with monkey_patched_summaries(summary):
    # Resolve the functions inside the context so whatever
    # `monkey_patched_summaries` installed is what gets called.
    v1_calls = (
        (tf.summary.scalar,
         dict(name="scalar", tensor=1, collections=["collection"],
              family="family")),
        (tf.summary.image,
         dict(name="image", tensor=1, max_outputs=3,
              collections=["collection"], family="family")),
        (tf.summary.histogram,
         dict(name="histogram", values=1, collections=["collection"],
              family="family")),
        (tf.summary.audio,
         dict(name="audio", tensor=1, sample_rate=3, max_outputs=3,
              collections=["collection"], family="family")),
    )
    for summary_fn, keywords in v1_calls:
      summary_fn(**keywords)
    contrib_calls = (
        (tf.contrib.summary.scalar,
         dict(name="scalar_v2", tensor=1, family="family", step=10)),
        (tf.contrib.summary.image,
         dict(name="image_v2", tensor=1, bad_color=True, max_images=3,
              family="family", step=10)),
        (tf.contrib.summary.histogram,
         dict(name="histogram_v2", tensor=1, family="family", step=10)),
        (tf.contrib.summary.audio,
         dict(name="audio_v2", tensor=1, sample_rate=3, max_outputs=3,
              family="family", step=10)),
    )
    for summary_fn, keywords in contrib_calls:
      summary_fn(**keywords)
  self.assertLen(summary.merge_all(), 8)
def test_monkey_patched_summaries_args(self, summary_maker):
  """Positional-argument summary calls are routed to the scoped summary.

  Checks three things:
  - Inside `monkey_patched_summaries`, every function returned by
    `_summaries()` differs from its pre-context counterpart.
  - The originals are back after the context exits.
  - The number of summary ops recorded for `summary` equals the number of
    summary calls made.

  The `tf.contrib.summary` calls are skipped when `tf.contrib` is
  unavailable (TF 2.0).
  """
  summary = summary_maker()
  before = _summaries()
  with monkey_patched_summaries(summary):
    for want, got in zip(before, _summaries()):
      self.assertNotEqual(want, got)
    tf.summary.scalar("scalar", 1, ["collection"], "family")
    tf.summary.image("image", 1, 3, ["collection"], "family")
    tf.summary.histogram("histogram", 1, ["collection"], "family")
    tf.summary.audio("audio", 1, 3, 3, ["collection"], "family")
    want_summary_fn_count = 4
    # TF 2.0 eliminates tf.contrib. Only the lookup is guarded: an
    # AttributeError raised by a patched contrib summary itself is a real
    # failure and must not be mistaken for a missing tf.contrib.
    try:
      contrib_summary = tf.contrib.summary
    except (AttributeError, ImportError):
      contrib_summary = None
    if contrib_summary is not None:
      contrib_summary.scalar("scalar_v2", 1, "family", 10)
      contrib_summary.image("image_v2", 1, True, 3, "family", 10)
      contrib_summary.histogram("histogram_v2", 1, "family", 10)
      contrib_summary.audio("audio_v2", 1, 3, 3, "family", 10)
      want_summary_fn_count += 4
  self.assertEqual(before, _summaries())
  self.assertLen(self._get_summary_ops(summary), want_summary_fn_count)