def _testCompareWithNN(self, weights, biases, partition_strategy):
    """Check rank-sampled softmax loss against `nn.sampled_softmax_loss`.

    Builds `sampling_ops.rank_sampled_softmax_loss` from the fixture's
    sampled values with a resampling temperature of 1, builds
    `nn.sampled_softmax_loss` from the fixture's resampled values, evaluates
    both in a fresh graph and asserts that the results are close.

    Args:
      weights: Zero-argument callable producing the `weights` argument; it is
        invoked once for each of the two losses.
      biases: Zero-argument callable producing the `biases` argument; it is
        invoked once for each of the two losses.
      partition_strategy: Value forwarded unchanged as `partition_strategy`
        to both loss functions.
    """
    with ops.Graph().as_default():
        # Loss under test: samples `_num_sampled` candidates and resamples
        # `_num_resampled` of them.
        rank_loss = sampling_ops.rank_sampled_softmax_loss(
            weights=weights(),
            biases=biases(),
            labels=self._labels(),
            inputs=self._inputs(),
            num_sampled=self._num_sampled,
            num_resampled=self._num_resampled,
            num_classes=self._num_classes,
            num_true=self._num_true,
            sampled_values=self._sampled_values,
            resampling_temperature=1.,
            remove_accidental_hits=self._remove_accidental_hits,
            partition_strategy=partition_strategy,
        )

        # Reference loss: plain sampled softmax fed the resampled candidates
        # directly, so its `num_sampled` is the fixture's `_num_resampled`.
        reference_loss = nn.sampled_softmax_loss(
            weights=weights(),
            biases=biases(),
            labels=self._labels(),
            inputs=self._inputs(),
            num_sampled=self._num_resampled,
            num_classes=self._num_classes,
            num_true=self._num_true,
            sampled_values=self._resampled_values,
            remove_accidental_hits=self._remove_accidental_hits,
            partition_strategy=partition_strategy,
        )

        with self.cached_session() as session:
            rank_loss_value = session.run(rank_loss)
            reference_loss_value = session.run(reference_loss)

        self.assertAllClose(rank_loss_value, reference_loss_value)
def testMissingPartitionStrategy(self):
    """`partition_strategy=None` must be rejected with a `ValueError`."""
    expected_message = r'unsupported partition_strategy \(None\)'
    with ops.Graph().as_default():
        with self.assertRaisesRegexp(ValueError, expected_message):
            # Every other argument is a valid-looking placeholder; only the
            # missing partition strategy is expected to trigger the error.
            sampling_ops.rank_sampled_softmax_loss(
                weights=self._weights(),
                biases=self._biases(),
                labels=self._labels(),
                inputs=self._inputs(),
                num_sampled=2,
                num_resampled=1,
                num_classes=self._num_classes,
                num_true=self._num_true,
                sampled_values=None,
                resampling_temperature=1.,
                remove_accidental_hits=True,
                partition_strategy=None,
            )
def testInvalidNumSampled1(self):
    """`num_resampled` (3) above `num_sampled` (2) must raise `ValueError`."""
    expected_message = (
        r'num_resampled \(3\) must be less than num_sampled \(2\)')
    with ops.Graph().as_default():
        with self.assertRaisesRegexp(ValueError, expected_message):
            # Asking to resample 3 candidates out of only 2 sampled ones is
            # the invalid combination being exercised.
            sampling_ops.rank_sampled_softmax_loss(
                weights=self._weights(),
                biases=self._biases(),
                labels=self._labels(),
                inputs=self._inputs(),
                num_sampled=2,
                num_resampled=3,
                num_classes=self._num_classes,
                num_true=self._num_true,
                sampled_values=None,
                resampling_temperature=1.,
                remove_accidental_hits=True,
                partition_strategy='div',
            )
def _testCompareWithNNTemperature(self, temperature, resampled):
    """Check the effect of the resampling temperature on a tiny example.

    Builds `sampling_ops.rank_sampled_softmax_loss` over two sampled classes
    with one resampled class, and compares it with `nn.sampled_softmax_loss`
    that is handed `resampled` as its only sampled candidate.

    Args:
      temperature: Value wrapped in a constant and passed as
        `resampling_temperature`.
      resampled: Class id used as the single sampled candidate of the
        reference loss.
    """
    class_weights = [[1., 2.], [3., 4.]]  # two sampled classes
    example_inputs = [[6., -5. / 2.], [-11., 21. / 2.]]
    # Let w0, w1 = weights of sampled classes (biases set to 0 for
    # simplicity) and x0, x1 = inputs. The logits are:
    #   w0.x0 = 1
    #   w0.x1 = 10
    #   w1.x0 = 8
    #   w1.x1 = 9
    # Resampling 1 class with temperature = t will pick the larger of:
    #   exp(1/t) + exp(10/t)  ==> w0, for values of t < 2.12
    #   exp(8/t) + exp(9/t)   ==> w1, for values of t > 2.13
    num_sampled = 2
    num_resampled = 1
    num_classes = 2
    num_true = 1
    sampled_values = ([0, 1], [[1.], [1.]], [1., 1.])
    resampled_values = ([resampled], [[1.], [1.]], [1.])
    remove_accidental_hits = False

    with ops.Graph().as_default():
        weights_t = constant_op.constant(class_weights)
        biases_t = constant_op.constant([0., 0.])
        labels_t = constant_op.constant([[0], [1]], dtype=dtypes.int64)
        inputs_t = constant_op.constant(example_inputs)

        # Loss under test, resampling 1 of the 2 sampled classes.
        temperature_t = constant_op.constant(temperature)
        rank_loss = sampling_ops.rank_sampled_softmax_loss(
            weights=weights_t,
            biases=biases_t,
            labels=labels_t,
            inputs=inputs_t,
            num_sampled=num_sampled,
            num_resampled=num_resampled,
            num_classes=num_classes,
            num_true=num_true,
            sampled_values=sampled_values,
            resampling_temperature=temperature_t,
            remove_accidental_hits=remove_accidental_hits,
            partition_strategy='div',
        )

        # Reference loss, given the expected resampled class directly.
        reference_loss = nn.sampled_softmax_loss(
            weights=weights_t,
            biases=biases_t,
            labels=labels_t,
            inputs=inputs_t,
            num_sampled=num_resampled,
            num_classes=num_classes,
            num_true=num_true,
            sampled_values=resampled_values,
            remove_accidental_hits=remove_accidental_hits,
            partition_strategy='div',
        )

        with self.cached_session() as session:
            rank_loss_value = session.run(rank_loss)
            reference_loss_value = session.run(reference_loss)

        self.assertAllClose(rank_loss_value, reference_loss_value)