def testMergeCall(self):
  """Exercise `merge_call` on the replica context seen in the default state.

  The callback handed to `merge_call` asserts that it receives the default
  distribution strategy and that, while it runs, the context getters report
  a cross-replica context for that strategy. The test then checks the value
  returned through `merge_call` and that the default state holds again
  afterwards.
  """
  dsc = distribution_strategy_context
  _assert_in_default_state(self)

  def check_context_and_prefix(strategy, suffix):
    """Assert the cross-replica view of `strategy`; return "foo_" + suffix."""
    default_strategy = dsc._get_default_distribution_strategy()
    self.assertIs(default_strategy, strategy)

    # No replica context is reported while the merge callback runs.
    inner_replica_ctx = dsc.get_replica_context()
    self.assertIs(None, inner_replica_ctx)

    cross_replica_ctx = dsc.get_cross_replica_context()
    self.assertIs(strategy, cross_replica_ctx)
    self.assertTrue(dsc.in_cross_replica_context())

    current_strategy = dsc.get_distribution_strategy()
    self.assertIs(strategy, current_strategy)

    # The strategy in effect is the default one, so none is reported as set.
    self.assertFalse(dsc.has_distribution_strategy())
    return "foo_" + suffix

  ctx = dsc.get_replica_context()
  default_ctx = dsc._get_default_replica_context()
  self.assertIs(default_ctx, ctx)

  merged = ctx.merge_call(check_context_and_prefix, args=("bar",))
  self.assertEqual("foo_bar", merged)

  _assert_in_default_state(self)
def _assert_in_default_state(t):
  """Assert that the distribution-strategy context getters report defaults.

  Args:
    t: Object providing `assertIs` and `assertFalse` (the calling test case).
  """
  dsc = distribution_strategy_context

  # The current replica context must be the default replica context object.
  default_replica_ctx = dsc._get_default_replica_context()
  current_replica_ctx = dsc.get_replica_context()
  t.assertIs(default_replica_ctx, current_replica_ctx)

  # There must be no cross-replica context.
  cross_replica_ctx = dsc.get_cross_replica_context()
  t.assertIs(None, cross_replica_ctx)

  # The current strategy must be the default strategy object.
  default_strategy = dsc._get_default_distribution_strategy()
  current_strategy = dsc.get_distribution_strategy()
  t.assertIs(default_strategy, current_strategy)

  t.assertFalse(dsc.has_distribution_strategy())
def testMergeCall(self):
  """Check `merge_call` behaviour starting from the default state.

  The merge callback verifies that it is given the default distribution
  strategy, that no replica context is reported while it runs, and that the
  cross-replica context and current strategy are that same strategy. The
  callback's return value must come back through `merge_call`, and the
  default state must hold again once the call finishes.
  """
  dsc = distribution_strategy_context
  _assert_in_default_state(self)

  def verify_and_prefix(strategy, text):
    """Assert the context seen inside the merge call; return "foo_" + text."""
    expected_strategy = dsc._get_default_distribution_strategy()
    self.assertIs(expected_strategy, strategy)

    replica_ctx_inside = dsc.get_replica_context()
    self.assertIs(None, replica_ctx_inside)

    cross_ctx_inside = dsc.get_cross_replica_context()
    self.assertIs(strategy, cross_ctx_inside)

    strategy_inside = dsc.get_distribution_strategy()
    self.assertIs(strategy, strategy_inside)

    self.assertFalse(dsc.has_distribution_strategy())
    return "foo_" + text

  outer_ctx = dsc.get_replica_context()
  expected_ctx = dsc._get_default_replica_context()
  self.assertIs(expected_ctx, outer_ctx)

  # "bar" is passed positionally, exactly as in the original call.
  result = outer_ctx.merge_call(verify_and_prefix, "bar")
  self.assertEqual("foo_bar", result)

  _assert_in_default_state(self)