コード例 #1
0
  def _testCompareWithNN(self, weights, biases, partition_strategy):
    """Checks rank_sampled_softmax_loss against nn.sampled_softmax_loss.

    The ranked loss is given the original candidates (`self._sampled_values`,
    `self._num_sampled`) and asked to resample `self._num_resampled` of them.
    The reference loss is handed `self._resampled_values` directly. The two
    evaluated losses must be numerically close.

    Args:
      weights: Zero-argument callable producing the class weights; called once
        per loss.
      biases: Zero-argument callable producing the class biases; called once
        per loss.
      partition_strategy: Forwarded unchanged to both loss functions.
    """

    def make_common_kwargs():
      # Build fresh arguments on every call, so each loss gets its own
      # weights()/biases()/labels/inputs, exactly as two separate calls would.
      return dict(
          weights=weights(),
          biases=biases(),
          labels=self._labels(),
          inputs=self._inputs(),
          num_classes=self._num_classes,
          num_true=self._num_true,
          remove_accidental_hits=self._remove_accidental_hits,
          partition_strategy=partition_strategy)

    with ops.Graph().as_default():
      ranked_loss = sampling_ops.rank_sampled_softmax_loss(
          num_sampled=self._num_sampled,
          num_resampled=self._num_resampled,
          sampled_values=self._sampled_values,
          resampling_temperature=1.,
          **make_common_kwargs())
      # The reference samples exactly the resampled set.
      reference_loss = nn.sampled_softmax_loss(
          num_sampled=self._num_resampled,
          sampled_values=self._resampled_values,
          **make_common_kwargs())
      with self.cached_session() as sess:
        ranked_value = sess.run(ranked_loss)
        reference_value = sess.run(reference_loss)

    self.assertAllClose(ranked_value, reference_value)
コード例 #2
0
 def testMissingPartitionStrategy(self):
   """partition_strategy=None must raise ValueError naming the bad value.

   The expected message must match r'unsupported partition_strategy \\(None\\)'.
   """
   with ops.Graph().as_default():
     with self.assertRaisesRegex(ValueError,
                                 r'unsupported partition_strategy \(None\)'):
       sampling_ops.rank_sampled_softmax_loss(
           weights=self._weights(),
           biases=self._biases(),
           labels=self._labels(),
           inputs=self._inputs(),
           num_sampled=2,
           num_resampled=1,
           num_classes=self._num_classes,
           num_true=self._num_true,
           sampled_values=None,
           resampling_temperature=1.,
           remove_accidental_hits=True,
           partition_strategy=None)
コード例 #3
0
 def testInvalidNumSampled1(self):
   """num_resampled (3) greater than num_sampled (2) must raise ValueError.

   The expected message must match
   r'num_resampled \\(3\\) must be less than num_sampled \\(2\\)'.
   """
   with ops.Graph().as_default():
     with self.assertRaisesRegex(
         ValueError,
         r'num_resampled \(3\) must be less than num_sampled \(2\)'):
       sampling_ops.rank_sampled_softmax_loss(
           weights=self._weights(),
           biases=self._biases(),
           labels=self._labels(),
           inputs=self._inputs(),
           num_sampled=2,
           num_resampled=3,
           num_classes=self._num_classes,
           num_true=self._num_true,
           sampled_values=None,
           resampling_temperature=1.,
           remove_accidental_hits=True,
           partition_strategy='div')
コード例 #4
0
  def _testCompareWithNNTemperature(self, temperature, resampled):
    """Checks which class resampling keeps at a given temperature.

    Both classes (0 and 1) are offered to rank_sampled_softmax_loss, which is
    asked to keep one of them at `temperature`. Its loss must match
    nn.sampled_softmax_loss when that is given only class `resampled` as its
    candidate.

    Args:
      temperature: Resampling temperature; wrapped in a constant tensor before
        being passed to rank_sampled_softmax_loss.
      resampled: Class index the caller expects resampling to keep.
    """
    # Rows are the weight vectors w0, w1 of the two sampled classes. Biases
    # are zero below, so logits are plain dot products with the inputs x0, x1.
    weight_rows = [[1., 2.], [3., 4.]]
    input_rows = [[6., -5. / 2.], [-11., 21. / 2.]]
    # logits:
    #   w0.x0 = 1    w0.x1 = 10
    #   w1.x0 = 8    w1.x1 = 9
    # Keeping 1 class at temperature t picks whichever sum is larger:
    #   exp(1/t) + exp(10/t)  ==> w0, for values of t < 2.12
    #   exp(8/t) + exp(9/t)   ==> w1, for values of t > 2.13
    candidate_count, kept_count = 2, 1
    # Presumably the (sampled_candidates, true_expected_count,
    # sampled_expected_count) triple that nn.sampled_softmax_loss accepts.
    all_candidates = ([0, 1], [[1.], [1.]], [1., 1.])
    kept_candidate = ([resampled], [[1.], [1.]], [1.])

    with ops.Graph().as_default():
      # The same tensors and settings feed both losses.
      shared_args = dict(
          weights=constant_op.constant(weight_rows),
          biases=constant_op.constant([0., 0.]),
          labels=constant_op.constant([[0], [1]], dtype=dtypes.int64),
          inputs=constant_op.constant(input_rows),
          num_classes=2,
          num_true=1,
          remove_accidental_hits=False,
          partition_strategy='div')
      ranked_loss = sampling_ops.rank_sampled_softmax_loss(
          num_sampled=candidate_count,
          num_resampled=kept_count,
          sampled_values=all_candidates,
          resampling_temperature=constant_op.constant(temperature),
          **shared_args)
      reference_loss = nn.sampled_softmax_loss(
          num_sampled=kept_count,
          sampled_values=kept_candidate,
          **shared_args)
      with self.cached_session() as sess:
        ranked_value = sess.run(ranked_loss)
        reference_value = sess.run(reference_loss)

    self.assertAllClose(ranked_value, reference_value)