示例#1
0
def _monkey_patch_context(iteration_step_scope, scoped_summary,
                          trainable_vars):
    """Monkey-patches global attributes with subnetwork-specific ones.

    Generator body of a context manager. While active:

    * every `get_global_step` / `get_or_create_global_step` accessor on
      `tf_compat.v1.train`, `tf_v1.train`, `tf.train`, `train` and
      `training_util` is replaced by `iteration_step`, which returns the
      non-trainable, zero-initialized int64 scalar variable "iteration_step"
      created (or reused) under `iteration_step_scope`;
    * the trainable-variables collection is set to `trainable_vars` plus the
      TPUEmbedding dummy table variables; and
    * summaries are redirected via `monkey_patched_summaries(scoped_summary)`.

    On exit -- normal or exceptional, including a failure while the patches
    are still being installed -- the trainable-variables collection is reset
    to its previous contents plus `_new_trainable_variables(trainable_vars)`.
    Each accessor is restored to the value it held on *its own* module before
    entry. Accessors that did not exist on a module beforehand are removed
    again rather than left behind.

    Args:
      iteration_step_scope: Passed to `variable_scope` to locate the
        "iteration_step" variable.
      scoped_summary: Summary object handed to `monkey_patched_summaries`.
      trainable_vars: Variables exposed as trainable inside the context. Must
        support `+` with a list.

    Yields:
      Nothing; control returns to the body of the `with` statement.
    """
    old_trainable_vars = tf_compat.v1.trainable_variables()

    def iteration_step(graph=None):
        """Returns the per-iteration step variable for `graph` (or default)."""
        graph = graph or tf_compat.v1.get_default_graph()
        # Reset to the top-level name scope so the variable's ops do not
        # inherit whatever name scope the caller happened to be in.
        with graph.as_default() as g, g.name_scope(None):
            with tf_compat.v1.variable_scope(iteration_step_scope,
                                             reuse=tf_compat.v1.AUTO_REUSE):
                return tf_compat.v1.get_variable(
                    "iteration_step",
                    shape=[],
                    initializer=tf_compat.v1.zeros_initializer(),
                    trainable=False,
                    dtype=tf.int64)

    # Snapshot every original accessor before touching any of them. Some of
    # these modules may be aliases of one another; a snapshot taken mid-patch
    # would capture `iteration_step` and "restore" the patch.
    patch_targets = [
        (module, attr)
        for module in (tf_compat.v1.train, tf_v1.train, tf.train, train,
                       training_util)
        for attr in ("get_global_step", "get_or_create_global_step")
    ]
    missing = object()
    originals = [(module, attr, getattr(module, attr, missing))
                 for module, attr in patch_targets]

    try:
        # Monkey-patch global attributes inside `try` so that a failure part
        # way through is still reverted by the `finally` below.
        for module, attr in patch_targets:
            setattr(module, attr, iteration_step)
        try:
            # The TPUEmbedding uses dummy variables to coordinate sending and
            # receiving gradients. If no gradients are computed on these dummy
            # variables, the TPUEmbedding will throw an error.
            embedding_variables = tf_compat.v1.get_collection(
                "tpu_embedding_dummy_table_variables")
            _set_trainable_variables(trainable_vars + embedding_variables)
            with monkey_patched_summaries(scoped_summary):
                yield
        finally:
            new_trainable_vars = _new_trainable_variables(trainable_vars)
            _set_trainable_variables(old_trainable_vars + new_trainable_vars)
    finally:
        # Revert monkey-patches. This runs even if restoring the trainable
        # variables above raised.
        for module, attr, original in reversed(originals):
            if original is not missing:
                setattr(module, attr, original)
            elif hasattr(module, attr):
                # The accessor did not exist before; drop the one we added.
                delattr(module, attr)
示例#2
0
def _subnetwork_context(iteration_step_scope, scoped_summary):
    """Monkey-patches global attributes with subnetwork-specific ones.

    Generator body of a context manager. While it is active, `get_global_step`
    and `get_or_create_global_step` on both `tf.train` and `training_util`
    return the non-trainable, zero-initialized int64 scalar variable
    "iteration_step" under `iteration_step_scope`. Summaries are redirected
    via `monkey_patched_summaries(scoped_summary)`.

    On exit, the accessors originally found on `tf.train` are reinstalled on
    both modules, even if the body (or the patching itself) raises.

    Args:
      iteration_step_scope: Passed to `tf.variable_scope` to locate the
        "iteration_step" variable.
      scoped_summary: Summary object handed to `monkey_patched_summaries`.

    Yields:
      Nothing; control returns to the body of the `with` statement.
    """

    old_get_global_step_fn = tf.train.get_global_step
    old_get_or_create_global_step_fn = tf.train.get_or_create_global_step

    def iteration_step(graph=None):
        """Returns the per-iteration step variable for `graph` (or default)."""
        # Honor an explicit `graph`, as the accessors being replaced do.
        # Discarding it would create the variable in whatever graph happens
        # to be the default.
        graph = graph or tf.get_default_graph()
        # Reset to the top-level name scope so the variable's ops do not
        # inherit the caller's name scope.
        with graph.as_default() as g, g.name_scope(None):
            with tf.variable_scope(iteration_step_scope, reuse=tf.AUTO_REUSE):
                return tf.get_variable("iteration_step",
                                       shape=[],
                                       initializer=tf.zeros_initializer(),
                                       trainable=False,
                                       dtype=tf.int64)

    try:
        # Monkey-patch global attributes. This is done inside `try` so that a
        # partial patch is still reverted.
        tf.train.get_global_step = iteration_step
        tf.train.get_or_create_global_step = iteration_step
        training_util.get_global_step = iteration_step
        training_util.get_or_create_global_step = iteration_step
        with monkey_patched_summaries(scoped_summary):
            yield
    finally:
        # Revert monkey-patches.
        training_util.get_or_create_global_step = old_get_or_create_global_step_fn
        training_util.get_global_step = old_get_global_step_fn
        tf.train.get_or_create_global_step = old_get_or_create_global_step_fn
        tf.train.get_global_step = old_get_global_step_fn
示例#3
0
    def test_monkey_patched_summaries_kwargs(self, summary_maker):
        """Keyword-style summary calls are routed to the scoped summary.

        Inside the context, every function reported by `_summaries()` must
        differ from its original. Afterwards the originals must be back. The
        four `tf.contrib.summary` calls only count when `tf.contrib` is
        available (TF 2.0 removed it), so the expected op count is 4 or 8.
        """
        scoped = summary_maker()
        originals = _summaries()
        with monkey_patched_summaries(scoped):
            for original_fn, patched_fn in zip(originals, _summaries()):
                self.assertNotEqual(original_fn, patched_fn)

            v1_summary = tf.summary
            v1_summary.scalar(name="scalar",
                              tensor=1,
                              collections=["collection"],
                              family="family")
            v1_summary.image(name="image",
                             tensor=1,
                             max_outputs=3,
                             collections=["collection"],
                             family="family")
            v1_summary.histogram(name="histogram",
                                 values=1,
                                 collections=["collection"],
                                 family="family")
            v1_summary.audio(name="audio",
                             tensor=1,
                             sample_rate=3,
                             max_outputs=3,
                             collections=["collection"],
                             family="family")

            expected_op_count = 4
            try:
                contrib_summary = tf.contrib.summary
                contrib_summary.scalar(name="scalar_v2",
                                       tensor=1,
                                       family="family",
                                       step=10)
                contrib_summary.image(name="image_v2",
                                      tensor=1,
                                      bad_color=True,
                                      max_images=3,
                                      family="family",
                                      step=10)
                contrib_summary.histogram(name="histogram_v2",
                                          tensor=1,
                                          family="family",
                                          step=10)
                contrib_summary.audio(name="audio_v2",
                                      tensor=1,
                                      sample_rate=3,
                                      max_outputs=3,
                                      family="family",
                                      step=10)
                expected_op_count += 4
            except (AttributeError, ImportError):
                # TF 2.0 eliminates tf.contrib.
                pass
        self.assertEqual(originals, _summaries())
        self.assertLen(self._get_summary_ops(scoped), expected_op_count)
示例#4
0
    def test_monkey_patched_summaries_args(self, summary_maker):
        """Positional-style summary calls are routed to the scoped summary.

        Four `tf.summary` and four `tf.contrib.summary` calls are made, and
        `merge_all()` on the scoped summary must then have length 8. This
        requires a TF version that still provides `tf.contrib`.
        """
        scoped = summary_maker()
        with monkey_patched_summaries(scoped):
            v1_summary = tf.summary
            v1_summary.scalar("scalar", 1, ["collection"], "family")
            v1_summary.image("image", 1, 3, ["collection"], "family")
            v1_summary.histogram("histogram", 1, ["collection"], "family")
            v1_summary.audio("audio", 1, 3, 3, ["collection"], "family")

            contrib_summary = tf.contrib.summary
            contrib_summary.scalar("scalar_v2", 1, "family", 10)
            contrib_summary.image("image_v2", 1, True, 3, "family", 10)
            contrib_summary.histogram("histogram_v2", 1, "family", 10)
            contrib_summary.audio("audio_v2", 1, 3, 3, "family", 10)
        merged = scoped.merge_all()
        self.assertLen(merged, 8)
示例#5
0
    def test_monkey_patched_summaries_args(self):
        """Positional-style calls land in a `_ScopedSummary` at step 10.

        Four `tf.summary` and four `tf.contrib.summary` calls are made, and
        `merge_all()` must then have length 8. This requires a TF version that
        still provides `tf.contrib`.
        """
        scoped = _ScopedSummary(self.test_subdirectory, global_step=10)
        with monkey_patched_summaries(scoped):
            v1_summary = tf.summary
            v1_summary.scalar("scalar", 1, ["collection"], "family")
            v1_summary.image("image", 1, 3, ["collection"], "family")
            v1_summary.histogram("histogram", 1, ["collection"], "family")
            v1_summary.audio("audio", 1, 3, 3, ["collection"], "family")

            contrib_summary = tf.contrib.summary
            contrib_summary.scalar("scalar_v2", 1, "family", 10)
            contrib_summary.image("image_v2", 1, True, 3, "family", 10)
            contrib_summary.histogram("histogram_v2", 1, "family", 10)
            contrib_summary.audio("audio_v2", 1, 3, 3, "family", 10)
        merged = scoped.merge_all()
        self.assertLen(merged, 8)
示例#6
0
def _monkey_patch_context(iteration_step_scope, scoped_summary,
                          trainable_vars):
    """Monkey-patches global attributes with subnetwork-specific ones.

    Generator body of a context manager. While active:

    * every `get_global_step` / `get_or_create_global_step` accessor on
      `tf_compat.v1.train`, `tf.train`, `train` and `training_util` is
      replaced by `iteration_step`, which returns the non-trainable,
      zero-initialized int64 scalar variable "iteration_step" created (or
      reused) under `iteration_step_scope`;
    * the trainable-variables collection is set to `trainable_vars`; and
    * summaries are redirected via `monkey_patched_summaries(scoped_summary)`.

    On exit -- normal or exceptional, including a failure while the patches
    are still being installed -- the trainable-variables collection is reset
    to its previous contents plus `_new_trainable_variables(trainable_vars)`.
    Each accessor is restored to the value it held on *its own* module before
    entry. Accessors that did not exist on a module beforehand are removed
    again rather than left behind.

    Args:
      iteration_step_scope: Passed to `variable_scope` to locate the
        "iteration_step" variable.
      scoped_summary: Summary object handed to `monkey_patched_summaries`.
      trainable_vars: Variables exposed as trainable inside the context.

    Yields:
      Nothing; control returns to the body of the `with` statement.
    """
    old_trainable_vars = tf_compat.v1.trainable_variables()

    def iteration_step(graph=None):
        """Returns the per-iteration step variable for `graph` (or default)."""
        graph = graph or tf_compat.v1.get_default_graph()
        # Reset to the top-level name scope so the variable's ops do not
        # inherit whatever name scope the caller happened to be in.
        with graph.as_default() as g, g.name_scope(None):
            with tf_compat.v1.variable_scope(iteration_step_scope,
                                             reuse=tf_compat.v1.AUTO_REUSE):
                return tf_compat.v1.get_variable(
                    "iteration_step",
                    shape=[],
                    initializer=tf_compat.v1.zeros_initializer(),
                    trainable=False,
                    dtype=tf.int64)

    # Snapshot every original accessor before touching any of them. Some of
    # these modules may be aliases of one another; a snapshot taken mid-patch
    # would capture `iteration_step` and "restore" the patch.
    patch_targets = [
        (module, attr)
        for module in (tf_compat.v1.train, tf.train, train, training_util)
        for attr in ("get_global_step", "get_or_create_global_step")
    ]
    missing = object()
    originals = [(module, attr, getattr(module, attr, missing))
                 for module, attr in patch_targets]

    try:
        # Monkey-patch global attributes inside `try` so that a failure part
        # way through is still reverted by the `finally` below.
        for module, attr in patch_targets:
            setattr(module, attr, iteration_step)
        try:
            _set_trainable_variables(trainable_vars)
            with monkey_patched_summaries(scoped_summary):
                yield
        finally:
            new_trainable_vars = _new_trainable_variables(trainable_vars)
            _set_trainable_variables(old_trainable_vars + new_trainable_vars)
    finally:
        # Revert monkey-patches. This runs even if restoring the trainable
        # variables above raised.
        for module, attr, original in reversed(originals):
            if original is not missing:
                setattr(module, attr, original)
            elif hasattr(module, attr):
                # The accessor did not exist before; drop the one we added.
                delattr(module, attr)
示例#7
0
    def test_monkey_patched_summaries_args(self, summary_maker):
        """Positional-style calls are captured, and the patch is undone after.

        `_summaries()` must differ inside the context and match its prior
        value afterwards. All eight calls must produce summary ops.
        """
        scoped = summary_maker()
        originals = _summaries()
        with monkey_patched_summaries(scoped):
            self.assertNotEqual(originals, _summaries())
            v1_summary = tf.summary
            v1_summary.scalar("scalar", 1, ["collection"], "family")
            v1_summary.image("image", 1, 3, ["collection"], "family")
            v1_summary.histogram("histogram", 1, ["collection"], "family")
            v1_summary.audio("audio", 1, 3, 3, ["collection"], "family")

            contrib_summary = tf.contrib.summary
            contrib_summary.scalar("scalar_v2", 1, "family", 10)
            contrib_summary.image("image_v2", 1, True, 3, "family", 10)
            contrib_summary.histogram("histogram_v2", 1, "family", 10)
            contrib_summary.audio("audio_v2", 1, 3, 3, "family", 10)
        self.assertEqual(originals, _summaries())
        summary_ops = self._get_summary_ops(scoped)
        self.assertLen(summary_ops, 8)
示例#8
0
    def test_monkey_patched_summaries_kwargs(self, summary_maker):
        """Keyword-style calls are captured, and the patch is undone after.

        `_summaries()` must differ inside the context and match its prior
        value afterwards. All eight calls must produce summary ops.
        """
        scoped = summary_maker()
        originals = _summaries()
        with monkey_patched_summaries(scoped):
            self.assertNotEqual(originals, _summaries())

            v1_summary = tf.summary
            v1_summary.scalar(name="scalar",
                              tensor=1,
                              collections=["collection"],
                              family="family")
            v1_summary.image(name="image",
                             tensor=1,
                             max_outputs=3,
                             collections=["collection"],
                             family="family")
            v1_summary.histogram(name="histogram",
                                 values=1,
                                 collections=["collection"],
                                 family="family")
            v1_summary.audio(name="audio",
                             tensor=1,
                             sample_rate=3,
                             max_outputs=3,
                             collections=["collection"],
                             family="family")

            contrib_summary = tf.contrib.summary
            contrib_summary.scalar(name="scalar_v2",
                                   tensor=1,
                                   family="family",
                                   step=10)
            contrib_summary.image(name="image_v2",
                                  tensor=1,
                                  bad_color=True,
                                  max_images=3,
                                  family="family",
                                  step=10)
            contrib_summary.histogram(name="histogram_v2",
                                      tensor=1,
                                      family="family",
                                      step=10)
            contrib_summary.audio(name="audio_v2",
                                  tensor=1,
                                  sample_rate=3,
                                  max_outputs=3,
                                  family="family",
                                  step=10)
        self.assertEqual(originals, _summaries())
        summary_ops = self._get_summary_ops(scoped)
        self.assertLen(summary_ops, 8)
示例#9
0
    def test_monkey_patched_summaries_kwargs(self):
        """Keyword-style calls land in a `_ScopedSummary` at step 10.

        Four `tf.summary` and four `tf.contrib.summary` calls are made, and
        `merge_all()` must then have length 8.
        """
        scoped = _ScopedSummary(self.test_subdirectory, global_step=10)
        with monkey_patched_summaries(scoped):
            v1_summary = tf.summary
            v1_summary.scalar(name="scalar",
                              tensor=1,
                              collections=["collection"],
                              family="family")
            v1_summary.image(name="image",
                             tensor=1,
                             max_outputs=3,
                             collections=["collection"],
                             family="family")
            v1_summary.histogram(name="histogram",
                                 values=1,
                                 collections=["collection"],
                                 family="family")
            v1_summary.audio(name="audio",
                             tensor=1,
                             sample_rate=3,
                             max_outputs=3,
                             collections=["collection"],
                             family="family")

            contrib_summary = tf.contrib.summary
            contrib_summary.scalar(name="scalar_v2",
                                   tensor=1,
                                   family="family",
                                   step=10)
            contrib_summary.image(name="image_v2",
                                  tensor=1,
                                  bad_color=True,
                                  max_images=3,
                                  family="family",
                                  step=10)
            contrib_summary.histogram(name="histogram_v2",
                                      tensor=1,
                                      family="family",
                                      step=10)
            contrib_summary.audio(name="audio_v2",
                                  tensor=1,
                                  sample_rate=3,
                                  max_outputs=3,
                                  family="family",
                                  step=10)
        merged = scoped.merge_all()
        self.assertLen(merged, 8)
  def test_monkey_patched_summaries_args(self, summary_maker):
    """Positional-style summary calls are routed to the scoped summary.

    Inside the context, every function reported by `_summaries()` must differ
    from its original. Afterwards the originals must be back. The four
    `tf.contrib.summary` calls only count when `tf.contrib` is available
    (TF 2.0 removed it), so the expected op count is 4 or 8.
    """
    scoped = summary_maker()
    originals = _summaries()
    with monkey_patched_summaries(scoped):
      for original_fn, patched_fn in zip(originals, _summaries()):
        self.assertNotEqual(original_fn, patched_fn)
      v1_summary = tf.summary
      v1_summary.scalar("scalar", 1, ["collection"], "family")
      v1_summary.image("image", 1, 3, ["collection"], "family")
      v1_summary.histogram("histogram", 1, ["collection"], "family")
      v1_summary.audio("audio", 1, 3, 3, ["collection"], "family")

      expected_op_count = 4
      try:
        contrib_summary = tf.contrib.summary
        contrib_summary.scalar("scalar_v2", 1, "family", 10)
        contrib_summary.image("image_v2", 1, True, 3, "family", 10)
        contrib_summary.histogram("histogram_v2", 1, "family", 10)
        contrib_summary.audio("audio_v2", 1, 3, 3, "family", 10)
        expected_op_count += 4
      except (AttributeError, ImportError):
        # TF 2.0 eliminates tf.contrib.
        pass
    self.assertEqual(originals, _summaries())
    self.assertLen(self._get_summary_ops(scoped), expected_op_count)