def testComputeAverageLossInvalidSampleWeights(self):
    """Three losses paired with two sample weights must raise a shape error."""
    accepted_errors = (ValueError, errors_impl.InvalidArgumentError)
    # The pattern accepts either of two alternative message wordings.
    shape_mismatch_pattern = (r"Incompatible shapes: \[3\] vs. \[2\]|"
                              "Dimensions must be equal")

    with self.assertRaisesRegex(accepted_errors, shape_mismatch_pattern):
        nn_impl.compute_average_loss(
            [2.5, 6.2, 5.],
            sample_weight=[0.2, 0.8],
            global_batch_size=10,
        )
def testComputeAverageLossInCrossReplicaContext(self, distribution):
    """Calling `compute_average_loss` directly under the scope must fail.

    Args:
      distribution: strategy object whose `scope()` the call is made under.
    """
    expected_message = (
        "You are calling `compute_average_loss` in cross replica context")

    # The call is made straight inside `distribution.scope()`, not through
    # `distribution.run`, and is expected to raise a RuntimeError.
    with distribution.scope(), self.assertRaisesRegex(
            RuntimeError, expected_message):
        nn_impl.compute_average_loss([2, 3])
def testComputeAverageLossGlobalBatchSize_BatchSizeNegative(self):
    """A negative `global_batch_size` must raise InvalidArgumentError."""
    example_losses = [1, 2, 3, 4, 5]
    expected_error = errors_impl.InvalidArgumentError

    with self.assertRaisesWithPredicateMatch(
            expected_error, "global_batch_size must be positive"):
        nn_impl.compute_average_loss(example_losses, global_batch_size=-1)
def testComputeAverageLossInvalidRank(self):
    """A scalar `per_example_loss` is rejected, statically and at run time."""
    # Both checks below expect exactly the same message.
    rank_error_pattern = ("Invalid value passed for `per_example_loss`. "
                          "Expected a tensor with at least rank 1.")

    # Static rank: the constant is a scalar, so the call itself raises.
    scalar_loss = constant_op.constant(2)
    with self.assertRaisesRegex(ValueError, rank_error_pattern):
        nn_impl.compute_average_loss(scalar_loss)

    # Dynamic rank: the placeholder has no declared shape, so building the
    # op succeeds and the error only appears when a scalar is fed in.
    with context.graph_mode():
        loss_placeholder = array_ops.placeholder(dtype=dtypes.float32)
        averaged = nn_impl.compute_average_loss(loss_placeholder)

        with self.cached_session() as sess, self.assertRaisesRegex(
                errors.InvalidArgumentError, rank_error_pattern):
            sess.run(averaged, {loss_placeholder: 2})
def testComputeAverageLossDefaultGlobalBatchSize(self, distribution):
    """With no `global_batch_size`, the result equals the mean of the losses.

    Args:
      distribution: strategy object used for the replicated half of the test.
    """
    example_losses = constant_op.constant([2.5, 6.2, 5.])
    expected_mean = (2.5 + 6.2 + 5.) / 3

    # Without strategy - num replicas = 1
    plain_loss = nn_impl.compute_average_loss(example_losses)
    self.assertAllClose(self.evaluate(plain_loss), expected_mean)

    # With strategy - num replicas = 2
    with distribution.scope():
        replica_losses = distribution.run(
            nn_impl.compute_average_loss, args=(example_losses,))
        # Summing the per-replica results should give the same mean.
        summed_loss = distribution.reduce("SUM", replica_losses, axis=None)
        self.assertAllClose(self.evaluate(summed_loss), expected_mean)
def testComputeAverageLossGlobalBatchSize_BatchSizeFloat(self):
    """A float `global_batch_size` must raise TypeError."""
    example_losses = [1, 2, 3, 4, 5]
    expected_message = "global_batch_size must be an int"

    with self.assertRaisesWithPredicateMatch(TypeError, expected_message):
        nn_impl.compute_average_loss(example_losses, global_batch_size=10.0)
def testComputeAverageLossGlobalBatchSize_BatchSizeNonScalar(self):
    """A list-valued `global_batch_size` must raise ValueError."""
    example_losses = [1, 2, 3, 4, 5]
    expected_message = "global_batch_size must be scalar"

    with self.assertRaisesWithPredicateMatch(ValueError, expected_message):
        nn_impl.compute_average_loss(example_losses, global_batch_size=[10])
def testComputeAverageLossGlobalBatchSize(self):
    """Losses summing to 15 with `global_batch_size=10` must yield 1.5."""
    example_losses = [1, 2, 3, 4, 5]
    averaged = nn_impl.compute_average_loss(
        example_losses, global_batch_size=10)

    observed = self.evaluate(averaged)
    self.assertEqual(observed, 1.5)
def testComputeAverageLossInvalidSampleWeights(self):
    """Three losses paired with two sample weights must raise a shape error."""
    accepted_errors = (ValueError, errors_impl.InvalidArgumentError)

    with self.assertRaisesIncompatibleShapesError(accepted_errors):
        nn_impl.compute_average_loss(
            [2.5, 6.2, 5.],
            sample_weight=[0.2, 0.8],
            global_batch_size=10,
        )
def testComputeAverageLossInvalidSampleWeights(self):
    """Two sample weights for three losses must raise a broadcast ValueError."""
    expected_message = "weights can not be broadcast to values"

    with self.assertRaisesRegex(ValueError, expected_message):
        nn_impl.compute_average_loss(
            [2.5, 6.2, 5.],
            sample_weight=[0.2, 0.8],
            global_batch_size=10,
        )