  def testComputeAverageLossInvalidSampleWeights(self):
    with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
                                (r"Incompatible shapes: \[3\] vs. \[2\]|"
                                 "Dimensions must be equal")):
      nn_impl.compute_average_loss([2.5, 6.2, 5.],
                                   sample_weight=[0.2, 0.8],
                                   global_batch_size=10)
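# All snippets on this page are test methods lifted from TensorFlow's
# loss-scaling test suite, so they assume TF-internal imports roughly like the
# following (a sketch; module paths as found in recent TF versions):
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_impl
# Tests that take a `distribution` argument are additionally parameterized with
# a tf.distribute strategy combination.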
  def testComputeAverageLossInCrossReplicaContext(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegex(
          RuntimeError,
          "You are calling `compute_average_loss` in cross replica context"):
        nn_impl.compute_average_loss([2, 3])
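# For contrast, a minimal sketch of the supported pattern using the public
# `tf.nn.compute_average_loss` API: call it from a replica context, i.e. inside
# `strategy.run`, rather than directly in the cross-replica context under
# `strategy.scope()`. The strategy choice below is illustrative.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # mirrors over whatever local devices exist

def replica_step(per_example_loss):
  # Runs once per replica; dividing by the *global* batch size here is what
  # makes a SUM reduction over replicas recover the true mean loss.
  return tf.nn.compute_average_loss(per_example_loss, global_batch_size=4)

per_replica = strategy.run(replica_step, args=([2.0, 3.0],))
loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)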
  def testComputeAverageLossGlobalBatchSize_BatchSizeNegative(self):
    per_example_loss = [1, 2, 3, 4, 5]
    with self.assertRaisesWithPredicateMatch(
        errors_impl.InvalidArgumentError,
        "global_batch_size must be positive"):
      nn_impl.compute_average_loss(per_example_loss,
                                   global_batch_size=-1)
  def testComputeAverageLossInvalidRank(self):
    per_example_loss = constant_op.constant(2)

    # Static rank
    with self.assertRaisesRegex(
        ValueError, "Invalid value passed for `per_example_loss`. "
        "Expected a tensor with at least rank 1."):
      nn_impl.compute_average_loss(per_example_loss)

    with context.graph_mode():
      # Dynamic rank
      per_example_loss = array_ops.placeholder(dtype=dtypes.float32)
      loss = nn_impl.compute_average_loss(per_example_loss)

      with self.cached_session() as sess:
        with self.assertRaisesRegex(
            errors.InvalidArgumentError,
            "Invalid value passed for `per_example_loss`. "
            "Expected a tensor with at least rank 1."):
          sess.run(loss, {per_example_loss: 2})
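# The rank check above exists because `compute_average_loss` expects one loss
# value per example. A scalar loss is accepted once wrapped as a length-1
# batch; a sketch using the public API:
import tensorflow as tf

scalar_loss = tf.constant(2.0)
loss = tf.nn.compute_average_loss(tf.reshape(scalar_loss, [1]))  # rank 1, OK
# With the default global_batch_size (local batch size of 1 here), loss == 2.0.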
  def testComputeAverageLossDefaultGlobalBatchSize(self, distribution):
    # Without strategy - num replicas = 1
    per_example_loss = constant_op.constant([2.5, 6.2, 5.])
    loss = nn_impl.compute_average_loss(per_example_loss)
    self.assertAllClose(self.evaluate(loss), (2.5 + 6.2 + 5.) / 3)

    # With strategy - num replicas = 2
    with distribution.scope():
      per_replica_losses = distribution.run(
          nn_impl.compute_average_loss, args=(per_example_loss,))
      loss = distribution.reduce("SUM", per_replica_losses, axis=None)
      self.assertAllClose(self.evaluate(loss), (2.5 + 6.2 + 5.) / 3)
  def testComputeAverageLossGlobalBatchSize_BatchSizeFloat(self):
    per_example_loss = [1, 2, 3, 4, 5]
    with self.assertRaisesWithPredicateMatch(
        TypeError, "global_batch_size must be an int"):
      nn_impl.compute_average_loss(per_example_loss, global_batch_size=10.0)
  def testComputeAverageLossGlobalBatchSize_BatchSizeNonScalar(self):
    per_example_loss = [1, 2, 3, 4, 5]
    with self.assertRaisesWithPredicateMatch(
        ValueError, "global_batch_size must be scalar"):
      nn_impl.compute_average_loss(per_example_loss, global_batch_size=[10])
  def testComputeAverageLossGlobalBatchSize(self):
    per_example_loss = [1, 2, 3, 4, 5]
    loss = nn_impl.compute_average_loss(per_example_loss, global_batch_size=10)
    self.assertEqual(self.evaluate(loss), 1.5)
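# Where the 1.5 above comes from: `compute_average_loss` sums the per-example
# losses and divides by `global_batch_size`, not by the local length of the
# input. A sketch with the public API:
import tensorflow as tf

per_example_loss = tf.constant([1., 2., 3., 4., 5.])
loss = tf.nn.compute_average_loss(per_example_loss, global_batch_size=10)
# tf.reduce_sum(per_example_loss) / 10 == 15 / 10 == 1.5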
  def testComputeAverageLossInvalidSampleWeights(self):
    with self.assertRaisesIncompatibleShapesError(
        (ValueError, errors_impl.InvalidArgumentError)):
      nn_impl.compute_average_loss([2.5, 6.2, 5.],
                                   sample_weight=[0.2, 0.8],
                                   global_batch_size=10)
  def testComputeAverageLossInvalidSampleWeights(self):
    with self.assertRaisesRegex(ValueError,
                                "weights can not be broadcast to values"):
      nn_impl.compute_average_loss([2.5, 6.2, 5.],
                                   sample_weight=[0.2, 0.8],
                                   global_batch_size=10)
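# All three variants of this test reject weights of shape [2] against losses of
# shape [3]. A weight per example (or any shape that broadcasts) is accepted;
# a sketch with the public API:
import tensorflow as tf

loss = tf.nn.compute_average_loss(
    [2.5, 6.2, 5.],
    sample_weight=[0.2, 0.5, 0.3],  # shape [3] broadcasts against shape [3]
    global_batch_size=10)
# (2.5 * 0.2 + 6.2 * 0.5 + 5.0 * 0.3) / 10 == 5.1 / 10 == 0.51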