Example #1
    def xent_loss(self, aff, neg_aff):
        if FLAGS.remove_accidental_hits:
            with tf.variable_scope('remove_accidental_hits'):
                neg_ids = tf.reshape(self.neg_id, [-1])
                input_ids = tf.reshape(self.src_id, [-1, 1])
                neg_aff = tf.reshape(neg_aff, [-1])
                # Find sampled negatives that collide with a true source id;
                # each hit is (row index, position in neg_ids, -FLOAT_MAX).
                acc_hits = candidate_sampling_ops.compute_accidental_hits(
                    input_ids, neg_ids, num_true=1)
                acc_indices, acc_ids, acc_weights = acc_hits
                if neg_aff.dtype != acc_weights.dtype:
                    acc_weights = math_ops.cast(acc_weights, neg_aff.dtype)

                acc_ids = math_ops.cast(acc_ids, dtypes.int32)
                tf.summary.scalar('accidental_hits_num', tf.shape(acc_ids)[0])
                # Scatter the -FLOAT_MAX weights onto the colliding entries of
                # the flattened negative affinities, so their contribution to
                # the sigmoid cross-entropy below vanishes.
                mask_tensor = sparse_ops.sparse_to_dense(
                    acc_ids,
                    tf.shape(neg_aff),
                    acc_weights,
                    default_value=0.0,
                    validate_indices=False)
                neg_aff += mask_tensor

        if FLAGS.truncate_affinity:
            with tf.variable_scope('truncate_affinity'):
                embedding_size = tf.cast(
                    const.NODE_CONFIG[self.dst_type]['embedding_size'] / 2.0,
                    tf.float32)
                # Keep only affinities whose magnitude is below half the
                # embedding size; the summaries record how many were dropped.
                pos_num = tf.shape(aff)[0]
                aff = tf.gather_nd(
                    aff, tf.where(tf.less(tf.abs(aff), embedding_size)))
                tf.summary.scalar('truncate_pos_affinity',
                                  pos_num - tf.shape(aff)[0])

                neg_aff = tf.reshape(neg_aff, [-1])
                neg_num = tf.shape(neg_aff)[0]
                neg_aff = tf.gather_nd(
                    neg_aff, tf.where(tf.less(tf.abs(neg_aff),
                                              embedding_size)))

                tf.summary.scalar('truncate_neg_affinity',
                                  neg_num - tf.shape(neg_aff)[0])

        true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(aff), logits=aff)
        negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(neg_aff), logits=neg_aff)

        loss = (tf.reduce_sum(true_xent) +
                tf.reduce_sum(negative_xent)) / tf.cast(
                    FLAGS.batch_size, tf.float32)

        tf.summary.scalar('loss_' + self.src_type + '2' + self.dst_type, loss)
        return loss
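The accidental-hit weights used above are -FLOAT_MAX, so adding them to a logit effectively removes that negative from the loss. A minimal standalone sketch of the effect (assuming TensorFlow 2.x eager execution; the values are made up):

import tensorflow as tf

# Hypothetical negative logits; suppose index 1 is an accidental hit.
neg_aff = tf.constant([0.3, 2.1, -0.5])
mask = tf.constant([0.0, -3.4e38, 0.0])  # -FLOAT_MAX at the hit position

masked = neg_aff + mask
loss = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.zeros_like(masked), logits=masked)
# loss ~= [0.85, 0.0, 0.47]: the colliding negative contributes nothing.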
Example #2
  def testAccidentalHits(self):
    with self.cached_session() as sess:
      true_classes = constant_op.constant(
          [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
      sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
          true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
      accidental_hits = candidate_sampling_ops.compute_accidental_hits(
          true_classes, sampled_candidates, self.NUM_TRUE)
      indices, ids, weights = self.evaluate(accidental_hits)

    self.assertEqual(1, accidental_hits[0].get_shape().ndims)
    self.assertEqual(1, accidental_hits[1].get_shape().ndims)
    self.assertEqual(1, accidental_hits[2].get_shape().ndims)
    for index, id_, weight in zip(indices, ids, weights):
      self.assertTrue(id_ in self.TRUE_LABELS[index])
      self.assertLess(weight, -1.0e37)
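For reference, a standalone sketch of what the op under test returns (assuming TensorFlow 2.x and the public tf.nn.compute_accidental_hits endpoint; the toy values are illustrative):

import tensorflow as tf

true_classes = tf.constant([[1, 2], [0, 4], [3, 3]], dtype=tf.int64)
sampled = tf.constant([7, 4, 2, 0], dtype=tf.int64)

# indices: row of the colliding example; ids: position of the colliding
# candidate inside sampled; weights: -FLOAT_MAX, meant to be scattered
# onto the corresponding sampled logits.
indices, ids, weights = tf.nn.compute_accidental_hits(
    true_classes, sampled, num_true=2)
# Here row 0 collides with sampled[2] (class 2), row 1 with sampled[1]
# (class 4) and sampled[3] (class 0); row 2 has no collision.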
Example #3
def _compute_sampled_logits(outfile, weights, biases, inputs, labels,
                            num_sampled, num_classes, num_true=1,
                            sampled_values=None, subtract_log_q=True,
                            remove_accidental_hits=False,
                            partition_strategy="mod", name=None):
  if not isinstance(weights, list):
    weights = [weights]
  with ops.name_scope(name, "compute_sampled_logits",
                      weights + [biases, inputs, labels]):
    if labels.dtype != dtypes.int64:
      labels = math_ops.cast(labels, dtypes.int64)
    labels_flat = array_ops.reshape(labels, [-1])
    if sampled_values is None:
      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes)
    sampled, true_expected_count, sampled_expected_count = sampled_values
    all_ids = array_ops.concat(0, [labels_flat, sampled])
    all_w = embedding_ops.embedding_lookup(
        outfile, weights, all_ids, partition_strategy=partition_strategy)
    all_b = embedding_ops.embedding_lookup(outfile, biases, all_ids)
    true_w = array_ops.slice(
        all_w, [0, 0], array_ops.pack([array_ops.shape(labels_flat)[0], -1]))
    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
    dim = array_ops.shape(true_w)[1:2]
    new_true_w_shape = array_ops.concat(0, [[-1, num_true], dim])
    row_wise_dots = math_ops.mul(
        array_ops.expand_dims(inputs, 1),
        array_ops.reshape(true_w, new_true_w_shape))
    dots_as_matrix = array_ops.reshape(row_wise_dots,
                                       array_ops.concat(0, [[-1], dim]))
    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
    true_b = array_ops.reshape(true_b, [-1, num_true])
    true_logits += true_b
    sampled_w = array_ops.slice(
        all_w, array_ops.pack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
    sampled_logits = math_ops.matmul(
        inputs, sampled_w, transpose_b=True) + sampled_b
    if remove_accidental_hits:
      acc_hits = candidate_sampling_ops.compute_accidental_hits(
          labels, sampled, num_true=num_true)
      acc_indices, acc_ids, acc_weights = acc_hits
      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
      acc_ids_2d_int32 = array_ops.reshape(
          math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
      sparse_indices = array_ops.concat(
          1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices")
      sampled_logits_shape = array_ops.concat(
          0,
          [array_ops.shape(labels)[:1], array_ops.expand_dims(num_sampled, 0)])
      if sampled_logits.dtype != acc_weights.dtype:
        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
      sampled_logits += sparse_ops.sparse_to_dense(
          sparse_indices, sampled_logits_shape, acc_weights,
          default_value=0.0, validate_indices=False)
    if subtract_log_q:
      true_logits -= math_ops.log(true_expected_count)
      sampled_logits -= math_ops.log(sampled_expected_count)
    out_logits = array_ops.concat(1, [true_logits, sampled_logits])
    out_labels = array_ops.concat(
        1, [array_ops.ones_like(true_logits) / num_true,
            array_ops.zeros_like(sampled_logits)])
  return out_logits, out_labels
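Note that this example (like several others on this page) is written against the pre-1.0 TensorFlow API, where concat took the axis first and pack/mul were the names of today's stack/multiply. A quick sketch of the modern equivalents, for readers porting the code:

import tensorflow as tf

a = tf.constant([1, 2])
b = tf.constant([3, 4])

tf.concat([a, b], axis=0)   # was array_ops.concat(0, [a, b])
tf.stack([a, b])            # was array_ops.pack([a, b])
tf.multiply(a, b)           # was math_ops.mul(a, b)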
Example #4
def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
                            num_classes, num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            remove_accidental_hits=False,
                            partition_strategy="mod",
                            name=None):
  """Helper function for nce_loss and sampled_softmax_loss functions.

  Computes sampled output training logits and labels suitable for implementing
  e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
  sampled_softmax_loss).

  Note: In the case where num_true > 1, we assign to each target class
  the target probability 1 / num_true so that the target probabilities
  sum to 1 per-example.

  Args:
    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
        objects whose concatenation along dimension 0 has shape
        `[num_classes, dim]`.  The (possibly-partitioned) class embeddings.
    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
        activations of the input network.
    labels: A `Tensor` of type `int64` and shape `[batch_size,
        num_true]`. The target classes.  Note that this format differs from
        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
    num_sampled: An `int`.  The number of classes to randomly sample per batch.
    num_classes: An `int`. The number of possible classes.
    num_true: An `int`.  The number of target classes per training example.
    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
        (if None, we default to `log_uniform_candidate_sampler`)
    subtract_log_q: A `bool`.  whether to subtract the log expected count of
        the labels in the sample to get the logits of the true labels.
        Default is True.  Turn off for Negative Sampling.
    remove_accidental_hits:  A `bool`.  whether to remove "accidental hits"
        where a sampled class equals one of the target classes.  Default is
        False.
    partition_strategy: A string specifying the partitioning strategy, relevant
        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: A name for the operation (optional).
  Returns:
    out_logits, out_labels: `Tensor` objects each with shape
        `[batch_size, num_true + num_sampled]`, for passing to either
        `nn.sigmoid_cross_entropy_with_logits` (NCE) or
        `nn.softmax_cross_entropy_with_logits` (sampled softmax).
  """

  if not isinstance(weights, list):
    weights = [weights]

  with ops.op_scope(
      weights + [biases, inputs, labels], name, "compute_sampled_logits"):
    if labels.dtype != dtypes.int64:
      labels = math_ops.cast(labels, dtypes.int64)
    labels_flat = array_ops.reshape(labels, [-1])

    # Sample the negative labels.
    #   sampled shape: [num_sampled] tensor
    #   true_expected_count shape = [batch_size, 1] tensor
    #   sampled_expected_count shape = [num_sampled] tensor
    if sampled_values is None:
      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes)
    # NOTE: pylint cannot tell that 'sampled_values' is a sequence
    # pylint: disable=unpacking-non-sequence
    sampled, true_expected_count, sampled_expected_count = sampled_values
    # pylint: enable=unpacking-non-sequence

    # labels_flat is a [batch_size * num_true] tensor
    # sampled is a [num_sampled] int tensor
    all_ids = array_ops.concat(0, [labels_flat, sampled])

    # weights shape is [num_classes, dim]
    all_w = embedding_ops.embedding_lookup(
        weights, all_ids, partition_strategy=partition_strategy)
    all_b = embedding_ops.embedding_lookup(biases, all_ids)
    # true_w shape is [batch_size * num_true, dim]
    # true_b is a [batch_size * num_true] tensor
    true_w = array_ops.slice(
        all_w, [0, 0], array_ops.pack([array_ops.shape(labels_flat)[0], -1]))
    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))

    # inputs shape is [batch_size, dim]
    # true_w shape is [batch_size * num_true, dim]
    # row_wise_dots is [batch_size, num_true, dim]
    dim = array_ops.shape(true_w)[1:2]
    new_true_w_shape = array_ops.concat(0, [[-1, num_true], dim])
    row_wise_dots = math_ops.mul(
        array_ops.expand_dims(inputs, 1),
        array_ops.reshape(true_w, new_true_w_shape))
    # We want the row-wise dot plus biases which yields a
    # [batch_size, num_true] tensor of true_logits.
    dots_as_matrix = array_ops.reshape(row_wise_dots,
                                       array_ops.concat(0, [[-1], dim]))
    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
    true_b = array_ops.reshape(true_b, [-1, num_true])
    true_logits += true_b

    # Lookup weights and biases for sampled labels.
    #   sampled_w shape is [num_sampled, dim]
    #   sampled_b is a [num_sampled] float tensor
    sampled_w = array_ops.slice(
        all_w, array_ops.pack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])

    # inputs has shape [batch_size, dim]
    # sampled_w has shape [num_sampled, dim]
    # sampled_b has shape [num_sampled]
    # Apply X*W'+B, which yields [batch_size, num_sampled]
    sampled_logits = math_ops.matmul(inputs,
                                     sampled_w,
                                     transpose_b=True) + sampled_b

    if remove_accidental_hits:
      acc_hits = candidate_sampling_ops.compute_accidental_hits(
          labels, sampled, num_true=num_true)
      acc_indices, acc_ids, acc_weights = acc_hits

      # This is how SparseToDense expects the indices.
      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
      acc_ids_2d_int32 = array_ops.reshape(math_ops.cast(
          acc_ids, dtypes.int32), [-1, 1])
      sparse_indices = array_ops.concat(
          1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices")
      # Create sampled_logits_shape = [batch_size, num_sampled]
      sampled_logits_shape = array_ops.concat(
          0,
          [array_ops.shape(labels)[:1], array_ops.expand_dims(num_sampled, 0)])
      if sampled_logits.dtype != acc_weights.dtype:
        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
      sampled_logits += sparse_ops.sparse_to_dense(
          sparse_indices, sampled_logits_shape, acc_weights,
          default_value=0.0, validate_indices=False)

    if subtract_log_q:
      # Subtract log of Q(l), prior probability that l appears in sampled.
      true_logits -= math_ops.log(true_expected_count)
      sampled_logits -= math_ops.log(sampled_expected_count)

    # Construct output logits and labels. The true labels/logits start at col 0.
    out_logits = array_ops.concat(1, [true_logits, sampled_logits])
    # true_logits is a float tensor, ones_like(true_logits) is a float tensor
    # of ones. We then divide by num_true to ensure the per-example labels sum
    # to 1.0, i.e. form a proper probability distribution.
    out_labels = array_ops.concat(
        1, [array_ops.ones_like(true_logits) / num_true,
            array_ops.zeros_like(sampled_logits)])

  return out_logits, out_labels
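The helper above is private to TensorFlow; in current releases the same machinery is reached through tf.nn.nce_loss and tf.nn.sampled_softmax_loss, which call it internally. A minimal usage sketch (assuming TF 2.x; the shapes and random data are purely illustrative):

import tensorflow as tf

num_classes, dim, batch, num_sampled = 10000, 128, 32, 64
weights = tf.Variable(tf.random.normal([num_classes, dim]))
biases = tf.Variable(tf.zeros([num_classes]))
inputs = tf.random.normal([batch, dim])
labels = tf.random.uniform([batch, 1], maxval=num_classes, dtype=tf.int64)

# NCE: sigmoid cross-entropy over the [true, sampled] logits.
nce = tf.nn.nce_loss(weights=weights, biases=biases, labels=labels,
                     inputs=inputs, num_sampled=num_sampled,
                     num_classes=num_classes, remove_accidental_hits=True)

# Sampled softmax: softmax cross-entropy over the same logits
# (accidental hits are removed by default here).
ssm = tf.nn.sampled_softmax_loss(weights=weights, biases=biases,
                                 labels=labels, inputs=inputs,
                                 num_sampled=num_sampled,
                                 num_classes=num_classes)

Both calls return a [batch_size] vector of per-example losses.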
Example #5
def _compute_sampled_logits(weights,
                            biases,
                            inputs,
                            labels,
                            num_sampled,
                            num_classes,
                            num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            remove_accidental_hits=False,
                            name=None):
    """Helper function for nce_loss and sampled_softmax_loss functions.

  Computes sampled output training logits and labels suitable for implementing
  e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
  sampled_softmax_loss).

  Note: In the case where num_true > 1, we assign to each target class
  the target probability 1 / num_true so that the target probabilities
  sum to 1 per-example.

  Args:
    weights: tensor of label embeddings with shape = [num_classes, dim]
    biases: tensor of num_classes label biases
    inputs: tensor with shape = [batch_size, dim] corresponding to forward
        activations of the input network
    labels: int tensor with shape [batch_size, num_true]
    num_sampled: number of label classes to sample per batch
    num_classes: number of possible label classes in the data (e.g. vocab size)
    num_true: number of target classes per example (default: 1)
    sampled_values: a tuple of (sampled_candidates, true_expected_count,
        sampled_expected_count) returned by a *CandidateSampler function to use
        (if None, we default to LogUniformCandidateSampler)
    subtract_log_q: subtract the log expected count of the labels in the sample
        to get the logits of the true labels (default: True)
        Turn off for Negative Sampling.
    remove_accidental_hits: whether to remove "accidental hits" where a sampled
        label equals the true labels (bool, default: False)
    name: name for this op

  Returns:
    out_logits, out_labels: tensors with shape [batch_size, num_true +
        num_sampled] for passing to either SigmoidCrossEntropyWithLogits (NCE)
        or SoftmaxCrossEntropyWithLogits (sampled softmax).

  """

    with ops.op_scope([weights, biases, inputs, labels], name,
                      "compute_sampled_logits"):
        if labels.dtype != types.int64:
            labels = math_ops.cast(labels, types.int64)
        labels_flat = array_ops.reshape(labels, [-1])

        # Sample the negative labels.
        #   sampled shape: num_sampled vector
        #   true_expected_count shape = [batch_size, 1]
        #   sampled_expected_count shape = num_sampled vector
        if sampled_values is None:
            sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
                true_classes=labels,
                num_true=num_true,
                num_sampled=num_sampled,
                unique=True,
                range_max=num_classes)
        # NOTE: pylint cannot tell that 'sampled_values' is a sequence
        # pylint: disable=unpacking-non-sequence
        sampled, true_expected_count, sampled_expected_count = sampled_values
        # pylint: enable=unpacking-non-sequence

        # weights shape is [num_classes, dim]
        # labels_flat is a [batch_size * num_true] vector
        # true_w shape is [batch_size * num_true, dim]
        # true_b is a [batch_size * num_true] vector
        true_w = embedding_ops.embedding_lookup(weights, labels_flat)
        true_b = embedding_ops.embedding_lookup(biases, labels_flat)

        # inputs shape is [batch_size, dim]
        # true_w shape is [batch_size * num_true, dim]
        # row_wise_dots is [batch_size, num_true, dim]
        dim = array_ops.shape(true_w)[1:2]
        new_true_w_shape = array_ops.concat(0, [[-1, num_true], dim])
        row_wise_dots = math_ops.mul(
            array_ops.expand_dims(inputs, 1),
            array_ops.reshape(true_w, new_true_w_shape))
        # We want the row-wise dot plus biases which yields a
        # [batch_size, num_true] tensor of true_logits.
        dots_as_matrix = array_ops.reshape(row_wise_dots,
                                           array_ops.concat(0, [[-1], dim]))
        true_logits = array_ops.reshape(_sum_rows(dots_as_matrix),
                                        [-1, num_true])
        true_b = array_ops.reshape(true_b, [-1, num_true])
        true_logits += true_b

        # Lookup weights and biases for sampled labels.
        #   sampled is a num_sampled int vector
        #   sampled_w shape is [num_sampled, dim]
        #   sampled_b is a num_sampled float vector
        sampled_w = embedding_ops.embedding_lookup(weights, sampled)
        sampled_b = embedding_ops.embedding_lookup(biases, sampled)

        # inputs has shape [batch_size, dim]
        # sampled_w has shape [num_sampled, dim]
        # sampled_b has shape [num_sampled]
        # Apply X*W'+B, which yields [batch_size, num_sampled]
        sampled_logits = math_ops.matmul(inputs, sampled_w,
                                         transpose_b=True) + sampled_b

        if remove_accidental_hits:
            acc_hits = candidate_sampling_ops.compute_accidental_hits(
                labels, sampled, num_true=num_true)
            acc_indices, acc_ids, acc_weights = acc_hits

            # This is how SparseToDense expects the indices.
            acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
            acc_ids_2d_int32 = array_ops.reshape(
                math_ops.cast(acc_ids, types.int32), [-1, 1])
            sparse_indices = array_ops.concat(
                1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices")
            # Create sampled_logits_shape = [batch_size, num_sampled]
            sampled_logits_shape = array_ops.concat(0, [
                array_ops.shape(labels)[:1],
                array_ops.expand_dims(num_sampled, 0)
            ])
            sampled_logits += sparse_ops.sparse_to_dense(
                sparse_indices, sampled_logits_shape, acc_weights, 0.0)

        if subtract_log_q:
            # Subtract log of Q(l), prior probability that l appears in sampled.
            true_logits -= math_ops.log(true_expected_count)
            sampled_logits -= math_ops.log(sampled_expected_count)

        # Construct output logits and labels. The true labels/logits start at col 0.
        out_logits = array_ops.concat(1, [true_logits, sampled_logits])
        # true_logits is a float tensor, ones_like(true_logits) is a float tensor
        # of ones. We then divide by num_true to ensure the per-example labels sum
        # to 1.0, i.e. form a proper probability distribution.
        out_labels = array_ops.concat(1, [
            array_ops.ones_like(true_logits) / num_true,
            array_ops.zeros_like(sampled_logits)
        ])

    return out_logits, out_labels
Example #6
    def _compute_sampled_logits(self,
                                weights,
                                biases,
                                labels,
                                inputs,
                                num_sampled,
                                num_classes,
                                transmissibility,
                                num_true=1,
                                sampled_values=None,
                                subtract_log_q=True,
                                remove_accidental_hits=False,
                                partition_strategy="mod",
                                name=None,
                                seed=None):
        if isinstance(weights, variables.PartitionedVariable):
            weights = list(weights)
        if not isinstance(weights, list):
            weights = [weights]

        with ops.name_scope(name, "compute_sampled_logits",
                            weights + [biases, inputs, labels]):
            if labels.dtype != dtypes.int64:
                labels = math_ops.cast(labels, dtypes.int64)
            if labels.shape.ndims == 1:
                labels = array_ops.expand_dims(labels, -1)
            labels_flat = array_ops.reshape(labels, [-1])

            # Sample the negative labels.
            #   sampled shape: [num_sampled] tensor
            #   true_expected_count shape = [batch_size, 1] tensor
            #   sampled_expected_count shape = [num_sampled] tensor
            #   num_sampled: dictionary size
            if sampled_values is None:
                sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
                    true_classes=labels,
                    num_true=num_true,
                    num_sampled=num_sampled,
                    unique=True,
                    range_max=num_classes,
                    seed=seed)
            # NOTE: pylint cannot tell that 'sampled_values' is a sequence
            # pylint: disable=unpacking-non-sequence
            sampled, true_expected_count, sampled_expected_count = (
                array_ops.stop_gradient(s) for s in sampled_values)
            # pylint: enable=unpacking-non-sequence
            sampled = math_ops.cast(sampled, dtypes.int64)

            # labels_flat is a [batch_size * num_true] tensor
            # sampled is a [num_sampled] int tensor
            all_ids = array_ops.concat([labels_flat, sampled], 0)

            # Retrieve the true weights and the logits of the sampled weights.

            # weights shape is [num_classes, dim]
            # 128 similar node pairs and 5 dissimilar nodes (i.e. 128*5 dissimilar node pairs)
            all_w = embedding_ops.embedding_lookup(
                weights, all_ids, partition_strategy=partition_strategy)

            # true_w shape is [batch_size * num_true, dim] -> 128 * 100
            true_w = array_ops.slice(
                all_w, [0, 0],
                array_ops.stack([array_ops.shape(labels_flat)[0], -1]))
            # sampled_w shape is [num_sampled, dim] -> 5 * 100
            sampled_w = array_ops.slice(
                all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]),
                [-1, -1])

            # inputs has shape [batch_size, dim]
            # sampled_w has shape [num_sampled, dim]
            # Apply X*W', which yields [batch_size, num_sampled]
            # Compare each of the 128 input nodes against these 5 dissimilar
            # nodes, giving a 128 * 5 matrix of node-a / node-b similarities.
            sampled_logits = math_ops.matmul(inputs,
                                             sampled_w,
                                             transpose_b=True)

            # Retrieve the true and sampled biases, compute the true logits, and
            # add the biases to the true and sampled logits.
            all_b = embedding_ops.embedding_lookup(
                biases, all_ids, partition_strategy=partition_strategy)
            # true_b is a [batch_size * num_true] tensor
            # sampled_b is a [num_sampled] float tensor
            true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
            sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat),
                                        [-1])

            # inputs shape is [batch_size, dim]
            # true_w shape is [batch_size * num_true, dim]
            # row_wise_dots is [batch_size, num_true, dim]
            dim = array_ops.shape(true_w)[1:2]
            new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
            row_wise_dots = math_ops.multiply(
                array_ops.expand_dims(inputs, 1),
                array_ops.reshape(true_w, new_true_w_shape))
            # We want the row-wise dot plus biases which yields a
            # [batch_size, num_true] tensor of true_logits.
            dots_as_matrix = array_ops.reshape(
                row_wise_dots, array_ops.concat([[-1], dim], 0))
            true_logits = array_ops.reshape(self._sum_rows(dots_as_matrix),
                                            [-1, num_true])
            true_b = array_ops.reshape(true_b, [-1, num_true])
            # Similar node pairs yield a 128*1 comparison result; dissimilar
            # node pairs yield 128*5.
            true_logits += true_b
            sampled_logits += sampled_b

            if remove_accidental_hits:
                acc_hits = candidate_sampling_ops.compute_accidental_hits(
                    labels, sampled, num_true=num_true)
                acc_indices, acc_ids, acc_weights = acc_hits

                # This is how SparseToDense expects the indices.
                acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
                acc_ids_2d_int32 = array_ops.reshape(
                    math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
                sparse_indices = array_ops.concat(
                    [acc_indices_2d, acc_ids_2d_int32], 1, "sparse_indices")
                # Create sampled_logits_shape = [batch_size, num_sampled]
                sampled_logits_shape = array_ops.concat([
                    array_ops.shape(labels)[:1],
                    array_ops.expand_dims(num_sampled, 0)
                ], 0)
                if sampled_logits.dtype != acc_weights.dtype:
                    acc_weights = math_ops.cast(acc_weights,
                                                sampled_logits.dtype)
                sampled_logits += sparse_ops.sparse_to_dense(
                    sparse_indices,
                    sampled_logits_shape,
                    acc_weights,
                    default_value=0.0,
                    validate_indices=False)

            if subtract_log_q:
                # Subtract log of Q(l), prior probability that l appears in sampled.
                true_logits -= math_ops.log(true_expected_count)
                sampled_logits -= math_ops.log(sampled_expected_count)

            # Construct output logits and labels. The true labels/logits start at col 0.
            out_logits = array_ops.concat([true_logits, sampled_logits], 1)

            # true_logits is a float tensor, ones_like(true_logits) is a float
            # tensor of ones. We then divide by num_true to ensure the per-example
            # labels sum to 1.0, i.e. form a proper probability distribution.
            out_labels = array_ops.concat(
                [
                    transmissibility,  # array_ops.ones_like(true_logits) / num_true,  #
                    array_ops.zeros_like(sampled_logits)
                ],
                1)

            return out_logits, out_labels
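The only behavioural change here relative to the stock helper is the label block: the caller-supplied transmissibility tensor replaces ones_like(true_logits) / num_true for the true columns. For comparison, a tiny sketch of the layout the stock code produces (TF 2.x syntax, illustrative shapes):

import tensorflow as tf

true_logits = tf.zeros([2, 2])     # batch_size=2, num_true=2
sampled_logits = tf.zeros([2, 3])  # num_sampled=3

out_labels = tf.concat(
    [tf.ones_like(true_logits) / 2,   # each true class gets 1 / num_true
     tf.zeros_like(sampled_logits)],
    axis=1)
# -> [[0.5, 0.5, 0.0, 0.0, 0.0],
#     [0.5, 0.5, 0.0, 0.0, 0.0]]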
Example #7
def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
                            num_classes, num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            remove_accidental_hits=False,
                            name=None):
  """Helper function for nce_loss and sampled_softmax_loss functions.

  Computes sampled output training logits and labels suitable for implementing
  e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
  sampled_softmax_loss).

  Note: In the case where num_true > 1, we assign to each target class
  the target probability 1 / num_true so that the target probabilities
  sum to 1 per-example.

  Args:
    weights: tensor of label embeddings with shape = [num_classes, dim]
    biases: tensor of num_classes label biases
    inputs: tensor with shape = [batch_size, dim] corresponding to forward
        activations of the input network
    labels: int tensor with shape [batch_size, num_true]
    num_sampled: number of label classes to sample per batch
    num_classes: number of possible label classes in the data (e.g. vocab size)
    num_true: number of target classes per example (default: 1)
    sampled_values: a tuple of (sampled_candidates, true_expected_count,
        sampled_expected_count) returned by a *CandidateSampler function to use
        (if None, we default to LogUniformCandidateSampler)
    subtract_log_q: subtract the log expected count of the labels in the sample
        to get the logits of the true labels (default: True)
        Turn off for Negative Sampling.
    remove_accidental_hits: whether to remove "accidental hits" where a sampled
        label equals the true labels (bool, default: False)
    name: name for this op

  Returns:
    out_logits, out_labels: tensors with shape [batch_size, num_true +
        num_sampled] for passing to either SigmoidCrossEntropyWithLogits (NCE)
        or SoftmaxCrossEntropyWithLogits (sampled softmax).

  """

  with ops.op_scope(
      [weights, biases, inputs, labels], name, "compute_sampled_logits"):
    if labels.dtype != types.int64:
      labels = math_ops.cast(labels, types.int64)
    labels_flat = array_ops.reshape(labels, [-1])

    # Sample the negative labels.
    #   sampled shape: num_sampled vector
    #   true_expected_count shape = [batch_size, 1]
    #   sampled_expected_count shape = num_sampled vector
    if sampled_values is None:
      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes)
    # NOTE: pylint cannot tell that 'sampled_values' is a sequence
    # pylint: disable=unpacking-non-sequence
    sampled, true_expected_count, sampled_expected_count = sampled_values
    # pylint: enable=unpacking-non-sequence

    # weights shape is [num_classes, dim]
    # labels_flat is a [batch_size * num_true] vector
    # true_w shape is [batch_size * num_true, dim]
    # true_b is a [batch_size * num_true] vector
    true_w = embedding_ops.embedding_lookup(weights, labels_flat)
    true_b = embedding_ops.embedding_lookup(biases, labels_flat)

    # inputs shape is [batch_size, dim]
    # true_w shape is [batch_size * num_true, dim]
    # row_wise_dots is [batch_size, num_true, dim]
    dim = array_ops.shape(true_w)[1:2]
    new_true_w_shape = array_ops.concat(0, [[-1, num_true], dim])
    row_wise_dots = math_ops.mul(
        array_ops.expand_dims(inputs, 1),
        array_ops.reshape(true_w, new_true_w_shape))
    # We want the row-wise dot plus biases which yields a
    # [batch_size, num_true] tensor of true_logits.
    dots_as_matrix = array_ops.reshape(row_wise_dots,
                                       array_ops.concat(0, [[-1], dim]))
    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
    true_b = array_ops.reshape(true_b, [-1, num_true])
    true_logits += true_b

    # Lookup weights and biases for sampled labels.
    #   sampled is a num_sampled int vector
    #   sampled_w shape is [num_sampled, dim]
    #   sampled_b is a num_sampled float vector
    sampled_w = embedding_ops.embedding_lookup(weights, sampled)
    sampled_b = embedding_ops.embedding_lookup(biases, sampled)

    # inputs has shape [batch_size, dim]
    # sampled_w has shape [num_sampled, dim]
    # sampled_b has shape [num_sampled]
    # Apply X*W'+B, which yields [batch_size, num_sampled]
    sampled_logits = math_ops.matmul(inputs,
                                     sampled_w,
                                     transpose_b=True) + sampled_b

    if remove_accidental_hits:
      acc_hits = candidate_sampling_ops.compute_accidental_hits(
          labels, sampled, num_true=num_true)
      acc_indices, acc_ids, acc_weights = acc_hits

      # This is how SparseToDense expects the indices.
      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
      acc_ids_2d_int32 = array_ops.reshape(math_ops.cast(
          acc_ids, types.int32), [-1, 1])
      sparse_indices = array_ops.concat(
          1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices")
      # Create sampled_logits_shape = [batch_size, num_sampled]
      sampled_logits_shape = array_ops.concat(
          0,
          [array_ops.shape(labels)[:1], array_ops.expand_dims(num_sampled, 0)])
      sampled_logits += sparse_ops.sparse_to_dense(
          sparse_indices, sampled_logits_shape, acc_weights, 0.0)

    if subtract_log_q:
      # Subtract log of Q(l), prior probability that l appears in sampled.
      true_logits -= math_ops.log(true_expected_count)
      sampled_logits -= math_ops.log(sampled_expected_count)

    # Construct output logits and labels. The true labels/logits start at col 0.
    out_logits = array_ops.concat(1, [true_logits, sampled_logits])
    # true_logits is a float tensor, ones_like(true_logits) is a float tensor
    # of ones. We then divide by num_true to ensure the per-example labels sum
    # to 1.0, i.e. form a proper probability distribution.
    out_labels = array_ops.concat(
        1, [array_ops.ones_like(true_logits) / num_true,
            array_ops.zeros_like(sampled_logits)])

  return out_logits, out_labels
Example #8
def _compute_sampled_logits(ri_tensors,
                            weights,
                            bias,
                            labels,
                            partition_const,
                            inputs,
                            num_sampled,
                            num_classes,
                            num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            remove_accidental_hits=False,
                            partition_strategy="mod",
                            name=None,
                            seed=None):
    if isinstance(weights, variables.PartitionedVariable):
        weights = list(weights)
    if not isinstance(weights, list):
        weights = [weights]

    with ops.name_scope(name, "compute_sampled_logits",
                        weights + [inputs, labels]):
        if labels.dtype != dtypes.int64:
            labels = math_ops.cast(labels, dtypes.int64)
        labels_flat = array_ops.reshape(labels, [-1])

        # Sample the negative labels.
        #   sampled shape: [num_sampled] tensor
        #   true_expected_count shape = [batch_size, 1] tensor
        #   sampled_expected_count shape = [num_sampled] tensor
        if sampled_values is None:
            sampled_values = candidate_sampling_ops.uniform_candidate_sampler(
                true_classes=labels,
                num_true=num_true,
                num_sampled=num_sampled,
                unique=True,
                range_max=num_classes,
                seed=seed)
        # NOTE: pylint cannot tell that 'sampled_values' is a sequence
        # pylint: disable=unpacking-non-sequence
        sampled, true_expected_count, sampled_expected_count = (
            array_ops.stop_gradient(s) for s in sampled_values)
        # pylint: enable=unpacking-non-sequence
        sampled = math_ops.cast(sampled, dtypes.int64)

        # labels_flat is a [batch_size * num_true] tensor
        # sampled is a [num_sampled] int tensor
        all_ids = array_ops.concat([labels_flat, sampled], 0)

        # Gather the sparse RI rows for the true and the sampled ids.
        true_ris = tx.gather_sparse(sp_tensor=ri_tensors, ids=labels_flat)
        sampled_ris = tx.gather_sparse(sp_tensor=ri_tensors, ids=sampled)

        true_w = embedding_lookup_sparse(params=weights,
                                         sp_ids=tx.sparse_indices(true_ris),
                                         sp_weights=true_ris,
                                         combiner="sum",
                                         partition_strategy=partition_strategy)

        noise_w = embedding_lookup_sparse(params=weights,
                                          sp_ids=tx.sparse_indices(sampled_ris),
                                          sp_weights=sampled_ris,
                                          combiner="sum",
                                          partition_strategy=partition_strategy)

        if bias is not None:
            sampled_b = embedding_lookup_sparse(
                params=bias,
                sp_ids=tx.sparse_indices(sampled_ris),
                sp_weights=sampled_ris,
                combiner="sum",
                partition_strategy=partition_strategy)

            true_b = embedding_lookup_sparse(
                params=bias,
                sp_ids=tx.sparse_indices(true_ris),
                sp_weights=true_ris,
                combiner="sum",
                partition_strategy=partition_strategy)

        noise_logits = math_ops.matmul(inputs, noise_w, transpose_b=True)

        dim = array_ops.shape(true_w)[1:2]
        new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
        true_w_e = array_ops.reshape(true_w, new_true_w_shape)

        row_wise_dots = math_ops.multiply(array_ops.expand_dims(inputs, 1),
                                          true_w_e)
        # We want the row-wise dot plus biases which yields a
        # [batch_size, num_true] tensor of true_logits.
        dots_as_matrix = array_ops.reshape(row_wise_dots,
                                           array_ops.concat([[-1], dim], 0))
        true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])

        if bias is not None:
            true_b = array_ops.reshape(true_b, [-1, num_true])
            true_logits += true_b
            noise_logits += sampled_b

        # TODO  need to review how to do this Z
        # true_logits = true_logits * math_ops.exp(partition_const)

        if remove_accidental_hits:
            acc_hits = candidate_sampling_ops.compute_accidental_hits(
                labels, sampled, num_true=num_true)
            acc_indices, acc_ids, acc_weights = acc_hits

            # This is how SparseToDense expects the indices.
            acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
            acc_ids_2d_int32 = array_ops.reshape(
                math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
            sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,
                                              "sparse_indices")
            # Create sampled_logits_shape = [batch_size, num_sampled]
            sampled_logits_shape = array_ops.concat(
                [array_ops.shape(labels)[:1],
                 array_ops.expand_dims(num_sampled, 0)], 0)
            if noise_logits.dtype != acc_weights.dtype:
                acc_weights = math_ops.cast(acc_weights, noise_logits.dtype)
            noise_logits += sparse_ops.sparse_to_dense(
                sparse_indices,
                sampled_logits_shape,
                acc_weights,
                default_value=0.0,
                validate_indices=False)

        if subtract_log_q:
            # Subtract log of Q(l), prior probability that l appears in sampled.
            true_logits -= math_ops.log(true_expected_count)
            noise_logits -= math_ops.log(sampled_expected_count)

        # Construct output logits and labels. The true labels/logits start at col 0.
        out_logits = array_ops.concat([true_logits, noise_logits], 1)

        # true_logits is a float tensor, ones_like(true_logits) is a float
        # tensor of ones. We then divide by num_true to ensure the per-example
        # labels sum to 1.0, i.e. form a proper probability distribution.
        out_labels = array_ops.concat([
            array_ops.ones_like(true_logits) / num_true,
            array_ops.zeros_like(noise_logits)
        ], 1)

        # out_logits = math_ops.div(out_logits,math_ops.exp(partition_const))
        # out_logits = out_logits / (partition_const + 1)
        return out_logits, out_labels
Example #9
def _compute_sampled_logits(weights,
                            biases,
                            inputs,
                            labels,
                            num_sampled,
                            num_classes,
                            num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            remove_accidental_hits=False,
                            partition_strategy="mod",
                            name=None):

    if not isinstance(weights, list):
        weights = [weights]

    with ops.name_scope(name, "compute_sampled_logits",
                        weights + [biases, inputs, labels]):
        if labels.dtype != dtypes.int64:
            labels = math_ops.cast(labels, dtypes.int64)
        labels_flat = array_ops.reshape(labels, [-1])
        # Sample the negative labels.
        #   sampled shape: [num_sampled] tensor
        #   true_expected_count shape = [batch_size, 1] tensor
        #   sampled_expected_count shape = [num_sampled] tensor
        if sampled_values is None:
            sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
                true_classes=labels,
                num_true=num_true,
                num_sampled=num_sampled,
                unique=True,
                range_max=num_classes)
        # NOTE: pylint cannot tell that 'sampled_values' is a sequence
        # pylint: disable=unpacking-non-sequence
        sampled, true_expected_count, sampled_expected_count = sampled_values
        # pylint: enable=unpacking-non-sequence

        # labels_flat is a [batch_size * num_true] tensor
        # sampled is a [num_sampled] int tensor
        all_ids = array_ops.concat(0, [labels_flat, sampled])

        # weights shape is [num_classes, dim]
        all_w = embedding_ops.embedding_lookup(
            weights, all_ids, partition_strategy=partition_strategy)
        all_b = embedding_ops.embedding_lookup(biases, all_ids)
        # true_w shape is [batch_size * num_true, dim]
        # true_b is a [batch_size * num_true] tensor
        true_w = array_ops.slice(all_w, [0, 0],
                                 array_ops.pack(
                                     [array_ops.shape(labels_flat)[0],
                                      -1]))  # 128*128
        true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))

        # inputs shape is [batch_size, dim]
        # true_w shape is [batch_size * num_true, dim]
        # row_wise_dots is [batch_size, num_true, dim]
        dim = array_ops.shape(true_w)[1:2]
        new_true_w_shape = array_ops.concat(0, [[-1, num_true], dim])
        row_wise_dots = math_ops.mul(
            array_ops.expand_dims(inputs, 1),  # 128*1*128
            array_ops.reshape(true_w, new_true_w_shape))  # 128*1*128
        # We want the row-wise dot plus biases which yields a
        # [batch_size, num_true] tensor of true_logits.
        dots_as_matrix = array_ops.reshape(row_wise_dots,
                                           array_ops.concat(0, [[-1], dim]))
        true_logits = array_ops.reshape(_sum_rows(dots_as_matrix),
                                        [-1, num_true])
        true_b = array_ops.reshape(true_b, [-1, num_true])
        true_logits += true_b

        # Lookup weights and biases for sampled labels.
        #   sampled_w shape is [num_sampled, dim]
        #   sampled_b is a [num_sampled] float tensor
        sampled_w = array_ops.slice(
            all_w, array_ops.pack([array_ops.shape(labels_flat)[0], 0]),
            [-1, -1])
        sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])

        # inputs has shape [batch_size, dim]
        # sampled_w has shape [num_sampled, dim]
        # sampled_b has shape [num_sampled]
        # Apply X*W'+B, which yields [batch_size, num_sampled]
        sampled_logits = math_ops.matmul(inputs, sampled_w,
                                         transpose_b=True) + sampled_b
        if remove_accidental_hits:
            acc_hits = candidate_sampling_ops.compute_accidental_hits(
                labels, sampled, num_true=num_true)
            acc_indices, acc_ids, acc_weights = acc_hits

            # This is how SparseToDense expects the indices.
            acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
            acc_ids_2d_int32 = array_ops.reshape(
                math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
            sparse_indices = array_ops.concat(
                1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices")
            # Create sampled_logits_shape = [batch_size, num_sampled]
            sampled_logits_shape = array_ops.concat(0, [
                array_ops.shape(labels)[:1],
                array_ops.expand_dims(num_sampled, 0)
            ])
            if sampled_logits.dtype != acc_weights.dtype:
                acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
            sampled_logits += sparse_ops.sparse_to_dense(
                sparse_indices,
                sampled_logits_shape,
                acc_weights,
                default_value=0.0,
                validate_indices=False)

        if subtract_log_q:
            # Subtract log of Q(l), prior probability that l appears in sampled.
            true_logits -= math_ops.log(true_expected_count)
            sampled_logits -= math_ops.log(sampled_expected_count)

        # Construct output logits and labels. The true labels/logits start at col 0.
        out_logits = array_ops.concat(1, [true_logits, sampled_logits])
        # true_logits is a float tensor, ones_like(true_logits) is a float tensor
        # of ones. We then divide by num_true to ensure the per-example labels sum
        # to 1.0, i.e. form a proper probability distribution.
        out_labels = array_ops.concat(1, [
            array_ops.ones_like(true_logits) / num_true,
            array_ops.zeros_like(sampled_logits)
        ])
    return out_logits, out_labels