def testScaleRegularizationLossInCrossReplicaContext(self, distribution):
    """Calling scale_regularization_loss directly in the strategy scope fails.

    The call happens under `distribution.scope()` but not inside
    `distribution.run`. The expected RuntimeError message identifies this as
    cross-replica context.
    """
    expected_message = ("You are calling `scale_regularization_loss` in "
                        "cross replica context")
    # The scope is entered first and assertRaisesRegex second, matching the
    # nesting of two separate `with` statements.
    with distribution.scope(), self.assertRaisesRegex(
        RuntimeError, expected_message):
      nn_impl.scale_regularization_loss([2, 3])
  def testScaleRegularizationLoss(self, distribution):
    """Scaled regularization losses recover the plain total after reduction.

    Without a strategy, the scaled loss equals the sum of the inputs. Under
    `distribution`, each replica's scaled result is SUM-reduced across
    replicas, and that reduction must equal the same sum. This presumably
    means each replica contributes total / num_replicas.
    """
    reg_losses = constant_op.constant([2.5, 6.2, 5.])
    expected_total = 2.5 + 6.2 + 5.

    # No strategy: effectively one replica, so the result is the plain sum.
    single_replica_loss = nn_impl.scale_regularization_loss(reg_losses)
    self.assertAllClose(self.evaluate(single_replica_loss), expected_total)

    # Under the strategy (two replicas, per the test's parametrization
    # outside this view): summing the per-replica results gets the total back.
    with distribution.scope():
      replica_results = distribution.run(
          nn_impl.scale_regularization_loss, args=(reg_losses,))
      reduced_loss = distribution.reduce("SUM", replica_results, axis=None)
      self.assertAllClose(self.evaluate(reduced_loss), expected_total)