Ejemplo n.º 1
0
    def testMergeCall(self):
        """Runs merge_call from the default replica context.

        Inside the merged function the strategy passed in is the default
        strategy, there is no replica context, and the cross-replica context
        is active. The function's return value must come back out of
        merge_call, and the default state is re-checked afterwards.
        """
        _assert_in_default_state(self)
        dsc = distribution_strategy_context

        def cross_replica_fn(strategy, suffix):
            # These checks run while merge_call has switched into
            # cross-replica mode under the default strategy.
            default_strategy = dsc._get_default_distribution_strategy()
            self.assertIs(default_strategy, strategy)
            self.assertIs(None, dsc.get_replica_context())
            self.assertIs(strategy, dsc.get_cross_replica_context())
            self.assertTrue(dsc.in_cross_replica_context())
            self.assertIs(strategy, dsc.get_distribution_strategy())
            # Still False here; presumably only a non-default strategy
            # counts as "having" one. TODO confirm.
            self.assertFalse(dsc.has_distribution_strategy())
            return "foo_" + suffix

        ctx = dsc.get_replica_context()
        self.assertIs(dsc._get_default_replica_context(), ctx)
        merged = ctx.merge_call(cross_replica_fn, args=("bar",))
        self.assertEqual("foo_bar", merged)
        _assert_in_default_state(self)
Ejemplo n.º 2
0
def _assert_in_default_state(t):
  """Asserts that the default replica context and strategy are current.

  Also asserts that no cross-replica context is active and that
  `has_distribution_strategy()` reports False.

  Args:
    t: the running test case; only its `assertIs`/`assertFalse` are used.
  """
  dsc = distribution_strategy_context
  default_replica_ctx = dsc._get_default_replica_context()
  t.assertIs(default_replica_ctx, dsc.get_replica_context())
  t.assertIs(None, dsc.get_cross_replica_context())
  default_strategy = dsc._get_default_distribution_strategy()
  t.assertIs(default_strategy, dsc.get_distribution_strategy())
  t.assertFalse(dsc.has_distribution_strategy())
Ejemplo n.º 3
0
def _assert_in_default_state(t):
  """Fails `t` unless the default contexts are in effect.

  Checks that the current replica context is the default one, that there is
  no cross-replica context, that the current strategy is the default one,
  and that `has_distribution_strategy()` reports False.

  Args:
    t: a test case object exposing `assertIs` and `assertFalse`.
  """
  ctx_api = distribution_strategy_context
  t.assertIs(ctx_api._get_default_replica_context(),
             ctx_api.get_replica_context())
  t.assertIs(None, ctx_api.get_cross_replica_context())
  expected_strategy = ctx_api._get_default_distribution_strategy()
  actual_strategy = ctx_api.get_distribution_strategy()
  t.assertIs(expected_strategy, actual_strategy)
  t.assertFalse(ctx_api.has_distribution_strategy())
Ejemplo n.º 4
0
  def testMergeCall(self):
    """Checks merge_call behavior from the default replica context.

    The merge function receives the default strategy, sees no replica
    context, and sees that strategy as the cross-replica context. Its return
    value is propagated back out of merge_call.
    """
    _assert_in_default_state(self)
    ctx_lib = distribution_strategy_context

    def check_and_merge(strategy, suffix):
      # Called via merge_call; verifies the cross-replica view of the
      # default strategy before building the result.
      self.assertIs(ctx_lib._get_default_distribution_strategy(), strategy)
      self.assertIs(None, ctx_lib.get_replica_context())
      self.assertIs(strategy, ctx_lib.get_cross_replica_context())
      self.assertIs(strategy, ctx_lib.get_distribution_strategy())
      self.assertFalse(ctx_lib.has_distribution_strategy())
      return "foo_" + suffix

    current_ctx = ctx_lib.get_replica_context()
    self.assertIs(ctx_lib._get_default_replica_context(), current_ctx)
    # "bar" is passed positionally and arrives in the merge fn as `suffix`.
    outcome = current_ctx.merge_call(check_and_merge, "bar")
    self.assertEqual("foo_bar", outcome)
    _assert_in_default_state(self)