Example #1
0
 def _apply_sparse(self, g_t, x_tm1, prepare):
   """Builds the step for a sparse gradient (optional momentum and scaling).

   Duplicate entries in `g_t.indices` are merged by summing their values.
   - When `self._mu > 0`, an 'm' accumulator is updated via
     `self._sparse_moving_average`, and its rows at the unique indices are
     blended with the merged gradient using `self._gamma`.
   - When `self._ups > 0`, a 'v' accumulator of the squared merged gradient is
     updated the same way, and the step is divided by `sqrt(v + self._eps)`.

   Args:
     g_t: gradient exposing `.indices` and `.values` (presumably an
       `IndexedSlices`).
     x_tm1: the variable being updated; passed through in the result.
     prepare: unused here.

   Returns:
     `[[step, x_tm1, unique_indices, g_t]]` followed by every element returned
     by the moving-average helper(s) that ran.
   """
   uniq_idx, seg_ids = array_ops.unique(g_t.indices)
   grad_rows = math_ops.unsorted_segment_sum(
       g_t.values, seg_ids, array_ops.size(uniq_idx))
   extra_updates = []

   numerator = grad_rows
   if self._mu > 0:
     # First element of the helper's result is gathered at the unique indices
     # (presumably the updated accumulator — verify against the helper).
     m_pair = self._sparse_moving_average(x_tm1, uniq_idx, grad_rows, 'm', self._mu)
     m_rows = array_ops.gather(m_pair[0], uniq_idx)
     gamma = ops.convert_to_tensor(self._gamma)
     numerator = (1 - gamma) * m_rows + gamma * grad_rows
     extra_updates.extend(m_pair)

   denominator = 1.
   if self._ups > 0:
     v_pair = self._sparse_moving_average(x_tm1, uniq_idx, grad_rows**2, 'v', self._ups)
     v_rows = array_ops.gather(v_pair[0], uniq_idx)
     eps = ops.convert_to_tensor(self._eps)
     denominator = math_ops.sqrt(v_rows + eps)
     extra_updates.extend(v_pair)

   lr = ops.convert_to_tensor(self._lr)
   step = lr * numerator / denominator
   return [[step, x_tm1, uniq_idx, g_t]] + extra_updates
Example #2
0
 def approximate_hessian(self, grads_and_vars, name=None, gate_gradients=None,
                         aggregation_method=None,
                         colocate_gradients_with_ops=False):
   """Builds a random Hessian-vector product for every variable.

   For each variable with a gradient g, a standard-normal probe z of g's shape
   is drawn and sum(g * z) is added to a running scalar. That scalar is then
   differentiated with respect to every variable, giving H z (z held fixed)
   restricted to each variable. Sparse gradients are first de-duplicated by
   summing values that share an index.

   Caveat kept from the original author: this has not been tested, is probably
   slow, and nothing else has been modified to consume its output.

   Args:
     grads_and_vars: iterable of `(gradient, variable)` pairs. Gradients may be
       None, dense `ops.Tensor`s, or objects with `.indices`/`.values`
       (presumably `IndexedSlices`). It is materialized once internally.
     name: unused; kept for interface compatibility.
     gate_gradients: gating mode. When None, `Optimizer.GATE_OP` is used.
       Gating is enabled in `gradients.gradients` iff this equals
       `Optimizer.GATE_OP`.
     aggregation_method: forwarded to `gradients.gradients`.
     colocate_gradients_with_ops: forwarded to `gradients.gradients`.

   Returns:
     A list of `(gradient, variable, hessian_vector_product)` triples in the
     order of `grads_and_vars`. The product is whatever `gradients.gradients`
     yields for that variable (None when no path exists).
   """
   # These three options were previously read as undefined globals, which
   # raised NameError on every call; they are now explicit keyword arguments.
   if gate_gradients is None:
     gate_gradients = Optimizer.GATE_OP
   # The pairs are traversed again to build the result; a one-shot iterator
   # would otherwise silently produce an empty result.
   grads_and_vars = list(grads_and_vars)

   gv = 0
   var_refs = []
   for g_t, x_tm1 in grads_and_vars:
     var_refs.append(x_tm1.ref())
     if g_t is None:
       continue
     with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
       if isinstance(g_t, ops.Tensor):
         g_dense = g_t
       else:
         # Sum duplicate indices so each touched row is probed exactly once.
         idxs, idxs_ = array_ops.unique(g_t.indices)
         g_dense = math_ops.unsorted_segment_sum(g_t.values, idxs_,
                                                 array_ops.size(idxs))
       # Use the run-time shape (works for partially-known static shapes) and
       # the gradient's dtype (so non-float32 gradients can be multiplied).
       probe = random_ops.random_normal(array_ops.shape(g_dense),
                                        dtype=g_dense.dtype.base_dtype)
       gv += math_ops.reduce_sum(g_dense * probe)
   hesses = gradients.gradients(
       gv, var_refs,
       gate_gradients=(gate_gradients == Optimizer.GATE_OP),
       aggregation_method=aggregation_method,
       colocate_gradients_with_ops=colocate_gradients_with_ops)
   return [(g_t, x_tm1, hess)
           for (g_t, x_tm1), hess in zip(grads_and_vars, hesses)]
 def _apply_sparse(self, grad, var):
   """Applies a sparse momentum step with bias correction to `var`.

   Duplicate indices in `grad` are merged by summing their values. Only the
   touched rows of the "m" slot and of `var` are modified.
   - m[idx] = mu * m[idx] + (1 - mu) * g
   - var[idx] -= lr * m_bar, where m_bar combines the bias-corrected m with
     the bias-corrected scaled gradient.

   NOTE(review): `_mu_t`, `_mu2_t`, `_mu_power`, `_lr_t` and `_use_locking`
   are set elsewhere. `_mu_power` presumably holds a running product of mu
   used for bias correction — confirm against `_prepare`/`_finish`.

   Args:
     grad: sparse gradient with `.indices` and `.values` (presumably an
       `IndexedSlices`).
     var: the variable to update; must have an "m" slot.

   Returns:
     An op grouping the scatter-sub on `var` and the updated "m" slot.
   """
   if len(grad.indices.get_shape()) == 1:
     grad_indices = grad.indices
     grad_values = grad.values
   else:
     # Multi-dimensional indices: flatten them and flatten values to rows.
     # Assumes values carry exactly one trailing feature dimension beyond the
     # indices' shape — TODO confirm for other layouts.
     grad_indices = array_ops.reshape(grad.indices, [-1])
     grad_values = array_ops.reshape(grad.values, [-1, grad.values.get_shape()[-1].value])
   # Merge duplicate indices: gidxs are the unique rows, gvals their summed
   # gradient rows.
   gidxs, metagidxs = array_ops.unique(grad_indices)
   sizegidxs = array_ops.size(gidxs)
   gvals = math_ops.unsorted_segment_sum(grad_values, metagidxs, sizegidxs)
   # m_t = mu * m + (1 - mu) * g_t
   m = self.get_slot(var, "m")
   m_scaled_g_values = gvals * (1 - self._mu_t)
   # First decay the touched rows in place, then add the scaled gradient.
   m_t = state_ops.scatter_update(m, gidxs,
                                  array_ops.gather(m, gidxs) * self._mu_t,
                                  use_locking=self._use_locking)
   m_t = state_ops.scatter_add(m_t, gidxs, m_scaled_g_values,
                               use_locking=self._use_locking)
   # Bias-corrected first moment for the touched rows.
   m_t_ = array_ops.gather(m_t, gidxs) / (1 - self._mu2_t * self._mu_power)
   # m_bar = mu2 * m_hat + (1 - mu) * g_t / (1 - mu_power)
   m_bar = self._mu2_t * m_t_ + m_scaled_g_values / (1 - self._mu_power)
   var_update = state_ops.scatter_sub(var, gidxs,
                                    self._lr_t * m_bar,
                                    use_locking=self._use_locking)
   return control_flow_ops.group(*[var_update, m_t])
Example #4
0
 def _prepare(self, grads_and_vars):
   """Sets `self._lr` from step/gradient-change statistics when it is None.

   For every variable that has a gradient, three sums are accumulated from its
   's' and 'g' slots, with y = g_t - g_slot:
     sTy = sum(s * y),  sTs = sum(s**2),  yTy = sum(y**2)
   For sparse gradients, duplicate indices are summed first and only the
   touched slot rows are gathered. The learning rate is then set to
   sTs / (|sTy| + self._eps). yTy is built but not used in this block.

   Args:
     grads_and_vars: iterable of `(gradient, variable)` pairs; pairs with a
       None gradient are skipped.
   """
   if self._lr is None:
     s_dot_y = s_dot_s = y_dot_y = 0
     for grad, var in grads_and_vars:
       if grad is None:
         continue
       with ops.name_scope('update_' + var.op.name), ops.device(var.device):
         g_slot = self.get_slot(var, 'g')
         s_slot = self.get_slot(var, 's')
         if isinstance(grad, ops.Tensor):
           cur_g, prev_g, prev_s = grad, g_slot, s_slot
         else:
           rows, row_pos = array_ops.unique(grad.indices)
           cur_g = math_ops.unsorted_segment_sum(grad.values, row_pos,
                                                 array_ops.size(rows))
           prev_g = array_ops.gather(g_slot, rows)
           prev_s = array_ops.gather(s_slot, rows)
         diff = cur_g - prev_g
         s_dot_y += math_ops.reduce_sum(prev_s * diff)
         s_dot_s += math_ops.reduce_sum(prev_s ** 2)
         y_dot_y += math_ops.reduce_sum(diff ** 2)
     self._lr = s_dot_s / (math_ops.abs(s_dot_y) + self._eps)
  def quantiles_ready():
    """The subgraph for when the quantiles are ready.

    Buckets the sparse column values, gathers the per-example gradient and
    hessian stats for every sparse entry, and prepends one bias entry per
    partition that aggregates the stats of all examples in that partition.

    Returns:
      A tuple `(partition_ids, bucket_ids, gradients, hessians)`. The leading
      rows are the per-partition bias entries (bucket id `_BIAS_FEATURE_ID`);
      the remaining rows follow the sparse entries.
    """
    quantized_feature = quantile_ops.quantiles([sparse_column_values], [],
                                               [quantile_buckets], [])
    quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64)
    quantized_feature = array_ops.reshape(quantized_feature, [-1])
    # Keep only the first column of the sparse indices (the example index).
    example_indices, _ = array_ops.split(
        sparse_column_indices, num_or_size_splits=2, axis=1)
    example_indices = array_ops.squeeze(example_indices, [1])
    filtered_gradients = array_ops.gather(gradients, example_indices)
    filtered_hessians = array_ops.gather(hessians, example_indices)
    filtered_partition_ids = array_ops.gather(example_partition_ids,
                                              example_indices)
    unique_partitions, mapped_partitions = array_ops.unique(
        example_partition_ids)

    # Compute aggregate stats for each partition.
    # Since unsorted_segment_sum can be numerically unstable, use 64bit
    # operation, then cast back to the per-example stats' dtype so the concat
    # below always type-checks.
    num_partitions = array_ops.size(unique_partitions)
    per_partition_gradients = math_ops.unsorted_segment_sum(
        math_ops.cast(gradients, dtypes.float64), mapped_partitions,
        num_partitions)
    per_partition_hessians = math_ops.unsorted_segment_sum(
        math_ops.cast(hessians, dtypes.float64), mapped_partitions,
        num_partitions)
    per_partition_gradients = math_ops.cast(per_partition_gradients,
                                            filtered_gradients.dtype)
    per_partition_hessians = math_ops.cast(per_partition_hessians,
                                           filtered_hessians.dtype)

    # Prepend a bias feature per partition that accumulates the stats for all
    # examples in that partition.
    bias_feature_ids = array_ops.fill(
        array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
    bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
    partition_ids = array_ops.concat(
        [unique_partitions, filtered_partition_ids], 0)
    filtered_gradients = array_ops.concat(
        [per_partition_gradients, filtered_gradients], 0)
    filtered_hessians = array_ops.concat(
        [per_partition_hessians, filtered_hessians], 0)
    bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0)
    return partition_ids, bucket_ids, filtered_gradients, filtered_hessians
def embedding_lookup_unique(params, ids, name=None):
  """Version of embedding_lookup that avoids duplicate lookups.

  This can save communication in the case of repeated ids.
  Same interface as embedding_lookup, except that `ids` may have any rank,
  so inputs/outputs do not need to be reshaped to fit gather.

  Args:
    params: A list of tensors with the same shape and type, or a
      `PartitionedVariable`. Shape `[index, d1, d2, ...]`.
    ids: A `Tensor` of any rank with type `int32` or `int64` containing the
      ids to be looked up in `params`. Shape `[ids1, ids2, ...]`.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params` and dimension of
    `[ids1, ids2, d1, d2, ...]`.

  Raises:
    ValueError: If `params` is empty.
  """
  with ops.name_scope(name, "EmbeddingLookupUnique", [params, ids]):
    ids = ops.convert_to_tensor(ids)
    shape = array_ops.shape(ids)
    # Flatten to 1-D (a scalar becomes shape [1]). This replaces
    # reshape(ids, reduce_prod(shape, keep_dims=True)), which relied on the
    # deprecated `keep_dims` argument.
    ids_flat = array_ops.reshape(ids, [-1])
    # Look up each distinct id once, then scatter results back to positions.
    unique_ids, idx = array_ops.unique(ids_flat)
    unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids)
    embeds_flat = array_ops.gather(unique_embeddings, idx)
    # Restore the original id shape followed by the embedding dimensions.
    embed_shape = array_ops.concat(
        [shape, array_ops.shape(unique_embeddings)[1:]], 0)
    embeds = array_ops.reshape(embeds_flat, embed_shape)
    embeds.set_shape(ids.get_shape().concatenate(
        unique_embeddings.get_shape()[1:]))
    return embeds
Example #7
0
 def _apply_sparse(self, g_t, x_tm1, prepare):
   """Computes a plain gradient-descent step for a sparse gradient.

   Duplicate entries in `g_t.indices` are merged by summing their values; the
   step is `self._lr` times the merged rows.

   Args:
     g_t: gradient with `.indices` and `.values` (presumably `IndexedSlices`).
     x_tm1: the variable being updated; passed through in the result.
     prepare: unused.

   Returns:
     `[[step, x_tm1, unique_indices, merged_gradient]]`.
   """
   rows, positions = array_ops.unique(g_t.indices)
   merged = math_ops.unsorted_segment_sum(g_t.values, positions,
                                          array_ops.size(rows))
   return [[self._lr * merged, x_tm1, rows, merged]]
  def testInt32OutIdxInt64(self):
    # Cast explicitly: np.random.randint's default dtype is platform dependent
    # (int64 on most 64-bit Linux builds), so int32 input would not otherwise
    # be exercised despite the test's name.
    x = np.random.randint(2, high=10, size=7000).astype(np.int32)
    with self.cached_session():
      y, idx = array_ops.unique(x, out_idx=dtypes.int64)
      tf_y, tf_idx = self.evaluate([y, idx])

    # The requested index dtype must actually be honoured.
    self.assertEqual(tf_idx.dtype, np.int64)
    self.assertEqual(len(x), len(tf_idx))
    self.assertEqual(len(tf_y), len(np.unique(x)))
    # Every input value must be recoverable through its index.
    for value, position in zip(x, tf_idx):
      self.assertEqual(value, tf_y[position])
Example #9
0
  def testInt32(self):
    # Unique of random small ints: the outputs must round-trip to the input.
    values = np.random.randint(2, high=10, size=7000)
    with self.cached_session():
      uniq, positions = array_ops.unique(values)
      tf_uniq, tf_positions = self.evaluate([uniq, positions])

    self.assertEqual(len(values), len(tf_positions))
    self.assertEqual(len(tf_uniq), len(np.unique(values)))
    for original, pos in zip(values, tf_positions):
      self.assertEqual(original, tf_uniq[pos])
    def active_inputs():
      """The normal flow when the handler is active.

      Returns:
        A tuple `(partition_ids, feature_ids_and_dimensions, gradients,
        hessians)`. The leading rows hold one bias entry per unique
        partition; the remaining rows hold one entry per value of the sparse
        int column, with dimension id 0.
      """
      # Remove the second column of example indices matrix since it is not
      # useful.
      example_indices, _ = array_ops.split(
          self._sparse_int_column.indices, num_or_size_splits=2, axis=1)
      example_indices = array_ops.squeeze(example_indices, [1])

      filtered_gradients = array_ops.gather(gradients, example_indices)
      filtered_hessians = array_ops.gather(hessians, example_indices)
      filtered_partition_ids = array_ops.gather(example_partition_ids,
                                                example_indices)
      unique_partitions, mapped_partitions = array_ops.unique(
          example_partition_ids)

      # Compute aggregate stats for each partition.
      # The bias is computed on gradients and hessians (and not
      # filtered_gradients) which have exactly one value per example, so we
      # don't double count a gradient in multivalent columns.
      # Since unsorted_segment_sum can be numerically unstable, use 64bit
      # operation.
      gradients64 = math_ops.cast(gradients, dtypes.float64)
      hessians64 = math_ops.cast(hessians, dtypes.float64)
      per_partition_gradients = math_ops.unsorted_segment_sum(
          gradients64, mapped_partitions, array_ops.size(unique_partitions))
      per_partition_hessians = math_ops.unsorted_segment_sum(
          hessians64, mapped_partitions, array_ops.size(unique_partitions))
      # Cast back to the per-example stats' own dtype instead of a hard-coded
      # float32, so the concat with the filtered stats below always
      # type-checks (identical behaviour for float32 inputs).
      per_partition_gradients = math_ops.cast(per_partition_gradients,
                                              filtered_gradients.dtype)
      per_partition_hessians = math_ops.cast(per_partition_hessians,
                                             filtered_hessians.dtype)
      # Prepend a bias feature per partition that accumulates the stats for all
      # examples in that partition.
      # Bias is added to the stats even if there are no examples with values in
      # the current sparse column. The reason is that the other example batches
      # might have values in these partitions so we have to keep the bias
      # updated.
      bias_feature_ids = array_ops.fill(
          array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
      bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
      partition_ids = array_ops.concat(
          [unique_partitions, filtered_partition_ids], 0)
      filtered_gradients = array_ops.concat(
          [per_partition_gradients, filtered_gradients], 0)
      filtered_hessians = array_ops.concat(
          [per_partition_hessians, filtered_hessians], 0)
      feature_ids = array_ops.concat(
          [bias_feature_ids, self._sparse_int_column.values], 0)
      # Dimension is always zero for sparse int features.
      dimension_ids = array_ops.zeros_like(feature_ids, dtype=dtypes.int64)
      feature_ids_and_dimensions = array_ops.stack(
          [feature_ids, dimension_ids], axis=1)
      return (partition_ids, feature_ids_and_dimensions, filtered_gradients,
              filtered_hessians)
  def testString(self):
    # Random ASCII characters in ['A', 'z'); unique must round-trip them.
    indx = np.random.randint(65, high=122, size=7000)
    x = [chr(i) for i in indx]
    # cached_session/evaluate replace the deprecated test_session/sess.run,
    # matching the other unique tests.
    with self.cached_session():
      y, idx = array_ops.unique(x)
      tf_y, tf_idx = self.evaluate([y, idx])

    self.assertEqual(len(x), len(tf_idx))
    self.assertEqual(len(tf_y), len(np.unique(x)))
    # String outputs come back as bytes, hence the decode.
    for char, position in zip(x, tf_idx):
      self.assertEqual(char, tf_y[position].decode('ascii'))
  def _aggregate_sparse_grad(self, grad, var, train_ops):
    """Aggregate sparse gradients.

    Each call enqueues this replica's (value, index) rows into a queue whose
    shared name is derived from `var.name`. It then enqueues a token once
    those rows are in. The returned gradient is built only after
    `self._replicas_to_aggregate` tokens have been dequeued: it drains the
    rows currently in the gradient queue and sums rows that share an index.

    Args:
      grad: The sparse gradient to aggregate (uses `.values`, `.indices` and
        `.dense_shape`, presumably an `IndexedSlices`).
      var: The variable to apply this gradient to.
      train_ops: The train_ops for the worker to run. The token-enqueue op is
        appended to this list in place.

    Returns:
      aggregated_grad: Aggregated grad, an `ops.IndexedSlices` with unique
        indices and `grad.dense_shape`.
    """
    # Sparse gradients have to be inserted as one pair of (value,
    # index) as an element instead of the whole IndexedSlices because
    # their shapes are not deterministic.
    # Capacity -1: negative capacity means no bound on queued rows. Each
    # element is one row of shape var.shape[1:] plus its scalar index.
    sparse_grad_queue = (data_flow_ops.FIFOQueue(
        -1,
        (grad.values.dtype, grad.indices.dtype),
        shapes=(var.get_shape().as_list()[1:], ()),
        shared_name="sparse_grad_q_%s" % var.name))
    # Recorded together with the variable's device; consumed elsewhere in the
    # class (not visible here).
    self._sparse_grad_queues_and_devs.append((sparse_grad_queue, var.device))

    # Sparse token is inserted after the "enqueue_many" finishes. This
    # is needed to make sure enough sparse gradients have been enqueued
    # before applying them to the variables.
    sparse_token_queue = (data_flow_ops.FIFOQueue(
        self._replicas_to_aggregate * 2,
        types_pb2.DT_INT32,
        shapes=(),
        shared_name="sparse_token_q_%s" % var.name))
    self._one_element_queue_list.append((sparse_token_queue, var.device))

    # Every row of this replica's gradient becomes one queue element.
    enqueue_spares_op = sparse_grad_queue.enqueue_many([grad.values,
                                                        grad.indices])
    # The token enqueue runs only after the rows are enqueued; the op is
    # handed back to the caller through train_ops.
    with ops.control_dependencies([enqueue_spares_op]):
      train_ops.append(sparse_token_queue.enqueue((1,)))

    # Wait for one token per replica to aggregate, then drain the gradient
    # queue. NOTE(review): size() is read at run time, so rows enqueued after
    # that read stay in the queue for a later step.
    with ops.control_dependencies([sparse_token_queue.dequeue_many(
        self._replicas_to_aggregate)]):
      values, indices = sparse_grad_queue.dequeue_many(sparse_grad_queue.size())
      concat_grad = ops.IndexedSlices(values, indices, grad.dense_shape)

      # Sum the gradients of the same variables in the sparse layers so
      # that each variable is only updated once. Note that with 2
      # gradients g1 and g2 from 2 replicas for the same variable,
      # apply(g1+g2) is different from apply(g1) and then apply(g2) when
      # the optimizer is complex like Momentum or Adagrad.
      values = concat_grad.values
      indices = concat_grad.indices
      new_indices, indx = array_ops.unique(indices)
      num_indices = array_ops.shape(new_indices)[0]
      sum_values = math_ops.unsorted_segment_sum(values, indx, num_indices)
      return ops.IndexedSlices(sum_values, new_indices, concat_grad.dense_shape)
  def quantiles_ready():
    """The subgraph for when the quantiles are ready.

    Buckets the sparse column values, gathers the per-example gradient and
    hessian stats for every sparse entry, and prepends one bias entry per
    partition.

    Returns:
      A tuple `(partition_ids, bucket_ids, gradients, hessians)`. The leading
      rows are per-partition bias entries whose bucket id is
      `[_BIAS_FEATURE_ID, 0]`; the remaining rows follow the sparse entries.
    """
    quantized_feature = quantile_ops.quantiles([], [sparse_column_values], [],
                                               [quantile_buckets],
                                               [sparse_column_indices])

    quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64)
    quantized_feature = array_ops.squeeze(quantized_feature, axis=0)

    # Keep only the first column of the sparse indices (the example index).
    example_indices, _ = array_ops.split(
        sparse_column_indices, num_or_size_splits=2, axis=1)
    example_indices = array_ops.squeeze(example_indices, [1])
    filtered_gradients = array_ops.gather(gradients, example_indices)
    filtered_hessians = array_ops.gather(hessians, example_indices)
    filtered_partition_ids = array_ops.gather(example_partition_ids,
                                              example_indices)
    unique_partitions, mapped_partitions = array_ops.unique(
        example_partition_ids)

    # Compute aggregate stats for each partition.
    # Since unsorted_segment_sum can be numerically unstable, use 64bit
    # operation.
    gradients64 = math_ops.cast(gradients, dtypes.float64)
    hessians64 = math_ops.cast(hessians, dtypes.float64)
    per_partition_gradients = math_ops.unsorted_segment_sum(
        gradients64, mapped_partitions, array_ops.size(unique_partitions))
    per_partition_hessians = math_ops.unsorted_segment_sum(
        hessians64, mapped_partitions, array_ops.size(unique_partitions))
    # Cast back to the per-example stats' own dtype instead of a hard-coded
    # float32, so the concat below always type-checks (identical behaviour
    # for float32 inputs).
    per_partition_gradients = math_ops.cast(per_partition_gradients,
                                            filtered_gradients.dtype)
    per_partition_hessians = math_ops.cast(per_partition_hessians,
                                           filtered_hessians.dtype)
    # Prepend a bias feature per partition that accumulates the stats for all
    # examples in that partition.
    bias_feature_ids = array_ops.fill(
        array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
    bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
    # Bias bucket ids are pairs [_BIAS_FEATURE_ID, 0].
    zeros = array_ops.zeros_like(bias_feature_ids)
    bias_feature_ids = array_ops.stack([bias_feature_ids, zeros], axis=1)

    partition_ids = array_ops.concat(
        [unique_partitions, filtered_partition_ids], 0)
    filtered_gradients = array_ops.concat(
        [per_partition_gradients, filtered_gradients], 0)
    filtered_hessians = array_ops.concat(
        [per_partition_hessians, filtered_hessians], 0)

    bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0)

    return partition_ids, bucket_ids, filtered_gradients, filtered_hessians
def update_all_medoids(pairwise_distances, predictions, labels, chosen_ids,
                       margin_multiplier, margin_type):
  """Updates all cluster medoids a cluster at a time.

  A `while_loop` visits cluster indices 0, 1, ... up to the number of distinct
  labels. For cluster k, it:
  - selects the points whose prediction equals k,
  - extracts the pairwise-distance sub-matrix among those points,
  - lets `update_medoid_per_cluster` produce the new `chosen_ids`.

  Args:
    pairwise_distances: 2-D Tensor of pairwise distances.
    predictions: 1-D Tensor of predicted cluster assignment.
    labels: 1-D Tensor of ground truth cluster assignment.
    chosen_ids: 1-D Tensor of cluster centroid indices.
    margin_multiplier: multiplication constant.
    margin_type: Type of structured margin to use. Default is nmi.

  Returns:
    chosen_ids: Updated 1-D Tensor of cluster centroid indices.
  """
  num_classes = array_ops.size(array_ops.unique(labels)[0])

  def has_next_cluster(cluster, unused_ids):
    return cluster < num_classes

  def refine_cluster(cluster, current_ids):
    """Re-selects the medoid of one cluster via update_medoid_per_cluster."""
    members = array_ops.where(
        math_ops.equal(math_ops.cast(predictions, dtypes.int64),
                       math_ops.cast(cluster, dtypes.int64)))
    member_rows = array_ops.gather(pairwise_distances, members)
    within_cluster = array_ops.transpose(
        array_ops.gather(array_ops.transpose(member_rows), members))
    new_ids = update_medoid_per_cluster(pairwise_distances, within_cluster,
                                        labels, current_ids, members, cluster,
                                        margin_multiplier, margin_type)
    return cluster + 1, new_ids

  _, chosen_ids = control_flow_ops.while_loop(
      has_next_cluster, refine_cluster, [array_ops.constant(0), chosen_ids])
  return chosen_ids
Example #15
0
def _deduplicate_indexed_slices(values, indices):
  """Collapses repeated `indices`, summing the `values` slices they select.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor` indexing the first dimension
      of `values` (as in an IndexedSlices object).

  Returns:
    `(summed_values, unique_indices)`: `unique_indices` holds each distinct
    index once, and row i of `summed_values` is the sum of all `values` slices
    whose index equals `unique_indices[i]`.
  """
  uniq, segment_of = array_ops.unique(indices)
  num_segments = array_ops.shape(uniq)[0]
  totals = math_ops.unsorted_segment_sum(values, segment_of, num_segments)
  return totals, uniq
def compute_augmented_facility_locations(pairwise_distances, labels, all_ids,
                                         margin_multiplier, margin_type):
  """Greedily picks one centroid per distinct label.

  Starting from an empty selection, each loop step takes the ids not chosen
  yet, asks `_find_loss_augmented_facility_idx` for the best one, and appends
  it. The number of steps is the number of distinct labels, which is only
  known at run time.

  Args:
    pairwise_distances: 2-D Tensor of pairwise distances.
    labels: 1-D Tensor of ground truth cluster assignment.
    all_ids: 1-D Tensor of all data indices.
    margin_multiplier: multiplication constant.
    margin_type: Type of structured margin to use. Default is nmi.

  Returns:
    chosen_ids: 1-D Tensor of chosen centroid indices.
  """
  num_classes = array_ops.size(array_ops.unique(labels)[0])
  no_selection = array_ops.constant(0, dtype=dtypes.int32, shape=[0])
  start_step = array_ops.constant(0)

  def should_continue(step, unused_selected):
    return step < num_classes

  def add_facility(step, selected):
    remaining = array_ops.setdiff1d(all_ids, selected)[0]
    best_idx = _find_loss_augmented_facility_idx(pairwise_distances, labels,
                                                 selected, remaining,
                                                 margin_multiplier,
                                                 margin_type)
    return step + 1, array_ops.concat([selected, [best_idx]], 0)

  # The selection grows every step, so its length is left unconstrained.
  _, chosen_ids = control_flow_ops.while_loop(
      should_continue,
      add_facility, [start_step, no_selection],
      shape_invariants=[start_step.get_shape(),
                        tensor_shape.TensorShape([None])])
  return chosen_ids
def compute_gt_cluster_score(pairwise_distances, labels):
  """Compute ground truth facility location score.

  For each distinct label, the pairwise-distance sub-matrix among the points
  carrying that label is extracted and summed along axis 0. The smallest of
  those sums is negated and added to the running score.

  Args:
    pairwise_distances: 2-D Tensor of pairwise distances.
    labels: 1-D Tensor of ground truth cluster assignment.

  Returns:
    gt_cluster_score: dtypes.float32 score.
  """
  class_ids = array_ops.unique(labels)[0]
  n_classes = array_ops.size(class_ids)

  def has_more(idx, unused_score):
    return idx < n_classes

  def accumulate(idx, score):
    """Adds the (negated) best in-cluster travel distance of one class."""
    members = array_ops.where(math_ops.equal(labels, class_ids[idx]))
    member_rows = array_ops.gather(pairwise_distances, members)
    block = array_ops.transpose(
        array_ops.gather(array_ops.transpose(member_rows), members))
    class_score = -1.0 * math_ops.reduce_min(
        math_ops.reduce_sum(block, axis=0))
    return idx + 1, score + class_score

  _, total = control_flow_ops.while_loop(
      has_more, accumulate,
      [array_ops.constant(0),
       array_ops.constant(0.0, dtype=dtypes.float32)])
  return total
 def fn(x):
   """Returns the distinct values of `x`, discarding the index output."""
   unique_result = array_ops.unique(x)
   return unique_result.y
Example #19
0
def hashed_embedding_lookup_sparse(params,
                                   sparse_values,
                                   dimension,
                                   combiner=None,
                                   default_value=None,
                                   name=None,
                                   hash_key=None):
    """Looks up embeddings of a sparse feature using parameter hashing.

    See `tf.contrib.layers.hashed_embedding_lookup` for embedding with hashing.

    Args:
      params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
        Each tensor must be of rank 1 with fully-defined shape.
      sparse_values: A 2-D `SparseTensor` containing the values to be embedded.
        Some rows may be empty.
      dimension: Embedding dimension
      combiner: A string specifying how to combine embedding results for each
        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
        the default.
      default_value: The value to use for an entry with no features.
      name: An optional name for this op.
      hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
        function to combine the crosses fingerprints on SparseFeatureCrossOp
        (optional).

    Returns:
      Dense tensor with shape [N, dimension] with N the number of rows in
      sparse_values.

    Raises:
      TypeError: If sparse_values is not a SparseTensor.
      ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
    """
    if combiner is None:
        # `warn` is a deprecated alias; `warning` is the supported name.
        logging.warning("The default value of combiner will change from \"mean\" "
                        "to \"sqrtn\" after 2016/11/01.")
        combiner = "mean"
    if isinstance(params, variables.PartitionedVariable):
        params = list(params)
    if not isinstance(params, list):
        params = [params]
    if not isinstance(sparse_values, ops.SparseTensor):
        raise TypeError("sparse_values must be SparseTensor")
    # Reject a bad combiner before any lookup ops are built (previously this
    # was only detected after the whole lookup subgraph had been created).
    if combiner not in ("mean", "sqrtn", "sum"):
        raise ValueError(
            "Combiner must be one of 'mean', 'sqrtn' or 'sum'.")

    with ops.name_scope(name, "hashed_sparse_embedding_lookup",
                        params + [sparse_values]) as scope:
        # Fill in the empty rows.
        if default_value is None:
            # Random default values to reduce the risk of collision.
            if sparse_values.dtype == dtypes.string:
                default_value = "6ZxWzWOHxZ"
            else:
                default_value = 1288896567
        sparse_values, _ = sparse_ops.sparse_fill_empty_rows(
            sparse_values, default_value)

        # Row of each sparse entry; the segment ops below need int32 ids.
        segment_ids = sparse_values.indices[:, 0]
        if segment_ids.dtype != dtypes.int32:
            segment_ids = math_ops.cast(segment_ids, dtypes.int32)

        # Embed each distinct value once; `idx` maps entries back to them.
        values, idx = array_ops.unique(sparse_values.values)

        embeddings = hashed_embedding_lookup(params,
                                             values,
                                             dimension,
                                             hash_key=hash_key)

        if combiner == "sum":
            segment_fn = math_ops.sparse_segment_sum
        elif combiner == "mean":
            segment_fn = math_ops.sparse_segment_mean
        else:  # "sqrtn" (validated above)
            segment_fn = math_ops.sparse_segment_sqrt_n
        return segment_fn(embeddings, idx, segment_ids, name=scope)
Example #20
0
def connected_components(images):
  """Labels the connected components in a batch of images.

  A component is a set of pixels in a single input image, which are all adjacent
  and all have the same non-zero value. Components use a squared connectivity
  of one (all True entries are joined with their neighbors above, below, left,
  and right). Components across all images have consecutive ids 1 through n.
  Components are labeled according to the first pixel of the component
  appearing in row-major order (lexicographic order by image_index_in_batch,
  row, col). Zero entries all have an output id of 0.

  This op is equivalent with `scipy.ndimage.measurements.label` on a 2D array
  with the default structuring element (which is the connectivity used here).

  Args:
    images: A 2D (H, W) or 3D (N, H, W) Tensor of boolean image(s).

  Returns:
    Components with the same shape as `images`. False entries in `images` have
    value 0, and all True entries map to a component id > 0.

  Raises:
    TypeError: if `images` is not 2D or 3D, including when its static rank is
      unknown.
  """
  with ops.name_scope("connected_components"):
    image_or_images = ops.convert_to_tensor(images, name="images")
    # `ndims` is None when the static rank is unknown, so such inputs reach the
    # documented TypeError below (calling `len()` on an unknown-rank shape
    # would raise ValueError instead).
    input_rank = image_or_images.get_shape().ndims
    if input_rank == 2:
      # Add a leading batch dimension so the op always sees (N, H, W).
      images = image_or_images[None, :, :]
    elif input_rank == 3:
      images = image_or_images
    else:
      raise TypeError(
          "images should have rank 2 (HW) or 3 (NHW). Static shape is %s" %
          image_or_images.get_shape())
    components = gen_image_ops.image_connected_components(images)

    # TODO(ringwalt): Component id renaming should be done in the op, to avoid
    # constructing multiple additional large tensors.
    # Renumber the op's raw ids: `id_index` maps every flattened pixel to its
    # entry in `unique_ids` (tf.unique keeps first-occurrence order).
    components_flat = array_ops.reshape(components, [-1])
    unique_ids, id_index = array_ops.unique(components_flat)
    id_is_zero = array_ops.where(math_ops.equal(unique_ids, 0))[:, 0]
    # Map each nonzero id to consecutive values.
    nonzero_consecutive_ids = math_ops.range(
        array_ops.shape(unique_ids)[0] - array_ops.shape(id_is_zero)[0]) + 1

    def no_zero():
      # No need to insert a zero into the ids.
      return nonzero_consecutive_ids

    def has_zero():
      # Insert a zero in the consecutive ids where zero appears in unique_ids.
      # id_is_zero has length 1, since unique_ids contains no duplicates.
      zero_id_ind = math_ops.to_int32(id_is_zero[0])
      ids_before = nonzero_consecutive_ids[:zero_id_ind]
      ids_after = nonzero_consecutive_ids[zero_id_ind:]
      return array_ops.concat([ids_before, [0], ids_after], axis=0)

    new_ids = control_flow_ops.cond(
        math_ops.equal(array_ops.shape(id_is_zero)[0], 0), no_zero, has_zero)
    components = array_ops.reshape(
        array_ops.gather(new_ids, id_index), array_ops.shape(components))
    if input_rank == 2:
      # Drop the batch dimension added above.
      return components[0, :, :]
    return components
Exemple #21
0
  def minimize(self, global_step=None, name=None):
    """Add operations to train a linear model by minimizing the loss function.

    Gathers the current sparse weights touched by this batch (handling
    partitioned variables), runs the SDCA optimizer op, then applies the
    returned example-state update and weight deltas.

    Args:
      global_step: Optional `Variable` to increment by one after the
        variables have been updated. Pass None to skip the increment.
      name: Optional name for the returned operation.

    Returns:
      An Operation that updates the variables passed in the constructor.
    """
    # Technically, the op depends on a lot more than the variables,
    # but we'll keep the list short.
    with name_scope(name, 'sdca/minimize'):
      sparse_example_indices = []
      sparse_feature_indices = []
      sparse_features_values = []
      for sf in self._examples['sparse_features']:
        sparse_example_indices.append(sf.example_indices)
        sparse_feature_indices.append(sf.feature_indices)
        # If feature values are missing, sdca assumes a value of 1.0f.
        if sf.feature_values is not None:
          sparse_features_values.append(sf.feature_values)

      # pylint: disable=protected-access
      example_ids_hashed = gen_sdca_ops.sdca_fprint(
          internal_convert_to_tensor(self._examples['example_ids']))
      # pylint: enable=protected-access
      example_state_data = self._hashtable.lookup(example_ids_hashed)
      # Solver returns example_state_update, new delta sparse_feature_weights
      # and delta dense_feature_weights.

      sparse_weights = []
      sparse_indices = []
      # If we have partitioned variables, keep a few dictionaries of Tensors
      # around that we need for the assign_add after the op call to
      # gen_sdca_ops.sdca_optimizer().  These are keyed because we may have a
      # mix of partitioned and un-partitioned variables.
      num_partitions_by_var = {}
      p_assignments_by_var = {}
      gather_ids_by_var = {}
      for v_num, (w, i) in enumerate(
          zip(self._slots['unshrinked_sparse_features_weights'],
              sparse_feature_indices)):
        # Append the sparse_indices (in full-variable space).
        # De-duplicate the feature indices, then cast back to int64.
        # NOTE(review): the int32 round-trip assumes feature indices fit in
        # int32 — confirm for very large feature spaces.
        sparse_idx = math_ops.cast(
            array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
            dtypes.int64)
        sparse_indices.append(sparse_idx)
        if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable):
          num_partitions = len(w)
          flat_ids = array_ops.reshape(sparse_idx, [-1])
          # We use div partitioning, which is easiest to support downstream.
          # Compute num_total_ids as the sum of dim-0 of w, then assign
          # to partitions based on a constant number of ids per partition.
          # Optimize if we already know the full shape statically.
          dim_0_size = self._get_first_dimension_size_statically(
              w, num_partitions)

          if tensor_shape.dimension_value(dim_0_size):
            num_total_ids = constant_op.constant(
                tensor_shape.dimension_value(dim_0_size),
                flat_ids.dtype)
          else:
            dim_0_sizes = []
            for p in range(num_partitions):
              if tensor_shape.dimension_value(w[p].shape[0]) is not None:
                dim_0_sizes.append(tensor_shape.dimension_value(w[p].shape[0]))
              else:
                with ops.colocate_with(w[p]):
                  dim_0_sizes.append(array_ops.shape(w[p])[0])
            num_total_ids = math_ops.reduce_sum(
                math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
          # The first `extras` partitions hold ids_per_partition + 1 ids each;
          # the rest hold ids_per_partition ids.
          ids_per_partition = num_total_ids // num_partitions
          extras = num_total_ids % num_partitions

          p_assignments = math_ops.maximum(
              flat_ids // (ids_per_partition + 1),
              (flat_ids - extras) // ids_per_partition)

          # Emulate a conditional using a boolean indicator tensor
          # (new_ids are the ids local to the assigned partition).
          new_ids = array_ops.where(p_assignments < extras,
                                    flat_ids % (ids_per_partition + 1),
                                    (flat_ids - extras) % ids_per_partition)

          # Cast partition assignments to int32 for use in dynamic_partition.
          # There really should not be more than 2^32 partitions.
          p_assignments = math_ops.cast(p_assignments, dtypes.int32)
          # Partition list of ids based on assignments into num_partitions
          # separate lists.
          gather_ids = data_flow_ops.dynamic_partition(new_ids,
                                                       p_assignments,
                                                       num_partitions)
          # Add these into the dictionaries for use in the later update.
          num_partitions_by_var[v_num] = num_partitions
          p_assignments_by_var[v_num] = p_assignments
          gather_ids_by_var[v_num] = gather_ids

          # Gather the weights from each partition.
          partition_gathered_weights = []
          for p in range(num_partitions):
            with ops.colocate_with(w[p]):
              partition_gathered_weights.append(
                  array_ops.gather(w[p], gather_ids[p]))

          # Stitch the weights back together in the same order they were before
          # we dynamic_partitioned them.
          condition_indices = data_flow_ops.dynamic_partition(
              math_ops.range(array_ops.shape(new_ids)[0]),
              p_assignments, num_partitions)
          batch_gathered_weights = data_flow_ops.dynamic_stitch(
              condition_indices, partition_gathered_weights)
        else:
          w_as_tensor = internal_convert_to_tensor(w)
          with ops.device(w_as_tensor.device):
            batch_gathered_weights = array_ops.gather(
                w_as_tensor, sparse_idx)
        sparse_weights.append(batch_gathered_weights)

      # pylint: disable=protected-access
      if compat.forward_compatible(year=2018, month=10, day=30):
        esu, sfw, dfw = gen_sdca_ops.sdca_optimizer_v2(
            sparse_example_indices,
            sparse_feature_indices,
            sparse_features_values,
            self._convert_n_to_tensor(self._examples['dense_features']),
            internal_convert_to_tensor(self._examples['example_weights']),
            internal_convert_to_tensor(self._examples['example_labels']),
            sparse_indices,
            sparse_weights,
            self._convert_n_to_tensor(self._slots[
                'unshrinked_dense_features_weights']),
            example_state_data,
            loss_type=self._options['loss_type'],
            l1=self._options['symmetric_l1_regularization'],
            l2=self._symmetric_l2_regularization(),
            num_loss_partitions=self._num_loss_partitions(),
            num_inner_iterations=1,
            adaptive=self._adaptive())
      else:
        # The older op spells this attribute `adaptative`.
        esu, sfw, dfw = gen_sdca_ops.sdca_optimizer(
            sparse_example_indices,
            sparse_feature_indices,
            sparse_features_values,
            self._convert_n_to_tensor(self._examples['dense_features']),
            internal_convert_to_tensor(self._examples['example_weights']),
            internal_convert_to_tensor(self._examples['example_labels']),
            sparse_indices,
            sparse_weights,
            self._convert_n_to_tensor(self._slots[
                'unshrinked_dense_features_weights']),
            example_state_data,
            loss_type=self._options['loss_type'],
            l1=self._options['symmetric_l1_regularization'],
            l2=self._symmetric_l2_regularization(),
            num_loss_partitions=self._num_loss_partitions(),
            num_inner_iterations=1,
            adaptative=self._adaptive())
      # pylint: enable=protected-access

      with ops.control_dependencies([esu]):
        update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
        # Update the weights before the proximal step.
        for v_num, (w, i, u) in enumerate(
            zip(self._slots['unshrinked_sparse_features_weights'],
                sparse_indices, sfw)):
          if (isinstance(w, var_ops.PartitionedVariable) or
              isinstance(w, list)):
            # NOTE(review): `p_assignments` / `num_partitions` are the values
            # left by the last partitioned variable in the gather loop; the
            # per-variable values travel in the *_by_var dicts, which the
            # helper presumably uses — confirm.
            update_ops += self._get_partitioned_update_ops(
                v_num, num_partitions_by_var, p_assignments_by_var,
                gather_ids_by_var, w, u, p_assignments, num_partitions)
          else:
            update_ops.append(state_ops.scatter_add(w, i, u))
        for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw):
          if (isinstance(w, var_ops.PartitionedVariable) or
              isinstance(w, list)):
            # Split the dense delta along dim 0 to match each partition.
            split_updates = array_ops.split(
                u, num_or_size_splits=[v.shape.as_list()[0] for v in w])
            for v, split_update in zip(w, split_updates):
              update_ops.append(state_ops.assign_add(v, split_update))
          else:
            update_ops.append(state_ops.assign_add(w, u))
      # Compare against None explicitly: truth-testing a TensorFlow variable
      # does not test for presence (it can raise, or read the current value,
      # so a step of 0 would wrongly skip the increment).
      if global_step is None:
        return control_flow_ops.group(*update_ops)
      with ops.control_dependencies(update_ops):
        return state_ops.assign_add(global_step, 1, name=name).op
 def fn(x):
   """Returns the distinct values of `x` (the first output of unique)."""
   distinct_values, _ = array_ops.unique(x)
   return distinct_values
def embedding_lookup_sparse(params, sp_ids, sp_weights,
                            name=None,
                            combiner="mean"):
  """Looks up and combines embeddings for sparse ids, optionally weighted.

  Every row of the dense tensor represented by `sp_ids` is expected to contain
  at least one id (no empty rows), and the indices of `sp_ids` are expected to
  be in canonical row-major order. All id values are expected to lie in
  [0, p0), where p0 is the total size of `params` along dimension 0.

  Args:
    params: Either one tensor holding the full embedding table, or a list of P
      shards that agree on every dimension except the first. With shards, ids
      are partitioned by id % P, looked up in params[p] for 0 <= p < P, and the
      results are stitched back into one tensor. Shard sizes may differ, since
      the vocabulary size need not be a multiple of P.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId);
      N is typically the batch size and M is arbitrary.
    sp_weights: A SparseTensor of float / double weights with exactly the same
      shape and indices as `sp_ids`, or None to weight every id by 1.
    name: Optional name for the op.
    combiner: How each row's embeddings are reduced: "mean" or "sum".
      "sum" is the weighted sum of the row's embeddings.
      "mean" is that weighted sum divided by the row's total weight.

  Returns:
    A dense tensor holding one combined embedding per row of `sp_ids`: every
    id in the row is looked up, scaled by its weight and reduced with
    `combiner`.

    If
      shape(combined params) = [p0, p1, ..., pm]
    and
      shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]
    then
      shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].

    For example, with params a 10x20 matrix and sp_ids / sp_weights

      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0

    and combiner="mean", the output is a 3x20 matrix with
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = params[0, :] * 1.0
      output[2, :] = params[1, :] * 3.0

  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sum"}.
  """
  if combiner not in ("mean", "sum"):
    raise ValueError("combiner must be one of 'mean' or 'sum'")
  params = params if isinstance(params, list) else [params]
  if not isinstance(sp_ids, ops.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  weighted = sp_weights is not None
  if weighted and not isinstance(sp_weights, ops.SparseTensor):
    raise TypeError("sp_weights must be either None or SparseTensor")

  with ops.op_scope(params + [sp_ids], name, "embedding_lookup_sparse") as name:
    # The first index column gives the output row of each id.
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != types.int32:
      segment_ids = math_ops.cast(segment_ids, types.int32)

    if not weighted:
      # Look up each distinct id once; `idx` maps every original entry back
      # to its position in the de-duplicated lookup result.
      unique_ids, idx = array_ops.unique(sp_ids.values)
      lookups = embedding_lookup(params, unique_ids)
      if combiner == "sum":
        return math_ops.sparse_segment_sum(lookups, idx, segment_ids,
                                           name=name)
      assert combiner == "mean", "Unrecognized combiner"
      return math_ops.sparse_segment_mean(lookups, idx, segment_ids,
                                          name=name)

    # Weighted path: every id is looked up (duplicates included) so that each
    # lookup lines up with its weight.
    lookups = embedding_lookup(params, sp_ids.values)
    weights = sp_weights.values
    if weights.dtype != lookups.dtype:
      weights = math_ops.cast(weights, lookups.dtype)

    # Reshape the weights to [num_ids, 1, ..., 1] so they broadcast over the
    # trailing embedding dimensions.
    trailing_ones = array_ops.fill(
        array_ops.expand_dims(array_ops.rank(lookups) - 1, 0), 1)
    bcast_weights_shape = array_ops.concat(
        0, [array_ops.shape(weights), trailing_ones])
    weights = array_ops.reshape(weights, bcast_weights_shape)
    weighted_lookups = lookups * weights

    if combiner == "sum":
      return math_ops.segment_sum(weighted_lookups, segment_ids, name=name)
    assert combiner == "mean", "Unrecognized combiner"
    row_sums = math_ops.segment_sum(weighted_lookups, segment_ids)
    weight_sums = math_ops.segment_sum(weights, segment_ids)
    return math_ops.div(row_sums, weight_sums, name=name)
def embedding_lookup_sparse(params, sp_ids, sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Combines the embeddings of sparse ids per row, with optional weights.

  Each row of the dense tensor represented by `sp_ids` is assumed to hold at
  least one id (no empty rows), and the indices of `sp_ids` are assumed to be
  in canonical row-major order. Every id value is assumed to lie in [0, p0),
  where p0 is the total size of `params` along dimension 0.

  Args:
    params: The complete embedding table as a single tensor; a list of P
      tensors (sharded embeddings) sharing every dimension but the first; or a
      `PartitionedVariable` partitioned along dimension 0.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId);
      N is typically the batch size and M is arbitrary.
    sp_weights: A SparseTensor of float / double weights with exactly the same
      shape and indices as `sp_ids`, or None to give every id weight 1.
    partition_strategy: Partitioning strategy used when `len(params) > 1`:
      `"div"` or `"mod"` (the default). See `tf.nn.embedding_lookup`.
    name: Optional name for the op.
    combiner: Per-row reduction: "mean", "sqrtn" or "sum". None logs a
      warning and falls back to "mean".
      "sum" is the weighted sum of the row's embeddings.
      "mean" divides that weighted sum by the total weight.
      "sqrtn" divides it by the square root of the sum of squared weights.
    max_norm: If not None, each embedding is normalized to have l2 norm equal
      to max_norm before combining.

  Returns:
    A dense tensor holding one combined embedding per row of `sp_ids`: every
    id in the row is looked up, scaled by its weight and reduced with
    `combiner`.

    If

      shape(combined params) = [p0, p1, ..., pm]

    and

      shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]

    then

      shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].

    For example, with params a 10x20 matrix and sp_ids / sp_weights

      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0

    and `combiner`="mean", the output is a 3x20 matrix with

      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = params[0, :] * 1.0
      output[2, :] = params[1, :] * 3.0

  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    # Iterating a PartitionedVariable yields its underlying Variables.
    params = list(params)
  params = params if isinstance(params, list) else [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  weighted = sp_weights is not None
  if weighted:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    # Only the static shapes of the components are compared here.
    for ids_part, weights_part in ((sp_ids.values, sp_weights.values),
                                   (sp_ids.indices, sp_weights.indices),
                                   (sp_ids.shape, sp_weights.shape)):
      ids_part.get_shape().assert_is_compatible_with(weights_part.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    # The first index column gives the output row of each id.
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    if not weighted:
      # Look up each distinct id once; `idx` maps every original entry back
      # to its position in the de-duplicated lookup result.
      unique_ids, idx = array_ops.unique(sp_ids.values)
      lookups = embedding_lookup(
          params, unique_ids, partition_strategy=partition_strategy,
          max_norm=max_norm)
      sparse_reducers = {
          "sum": math_ops.sparse_segment_sum,
          "mean": math_ops.sparse_segment_mean,
          "sqrtn": math_ops.sparse_segment_sqrt_n,
      }
      return sparse_reducers[combiner](lookups, idx, segment_ids, name=name)

    # Weighted path: every id is looked up (duplicates included) so that each
    # lookup lines up with its weight.
    lookups = embedding_lookup(
        params, sp_ids.values, partition_strategy=partition_strategy,
        max_norm=max_norm)
    weights = sp_weights.values
    if weights.dtype != lookups.dtype:
      weights = math_ops.cast(weights, lookups.dtype)

    # Reshape the weights to [num_ids, 1, ..., 1] so they broadcast over the
    # trailing embedding dimensions.
    trailing_ones = array_ops.fill(
        array_ops.expand_dims(array_ops.rank(lookups) - 1, 0), 1)
    bcast_weights_shape = array_ops.concat(
        0, [array_ops.shape(weights), trailing_ones])
    static_weights_shape = weights.get_shape()
    weights = array_ops.reshape(weights, bcast_weights_shape)

    # The reshape target is a runtime tensor, so the static shape is lost;
    # restore it when the embedding rank is statically known.
    lookup_rank = lookups.get_shape().ndims
    if lookup_rank is not None:
      weights.set_shape(
          static_weights_shape.concatenate([1] * (lookup_rank - 1)))

    weighted_lookups = lookups * weights

    if combiner == "sum":
      return math_ops.segment_sum(weighted_lookups, segment_ids, name=name)
    row_sums = math_ops.segment_sum(weighted_lookups, segment_ids)
    if combiner == "mean":
      weight_sums = math_ops.segment_sum(weights, segment_ids)
      return math_ops.div(row_sums, weight_sums, name=name)
    assert combiner == "sqrtn", "Unrecognized combiner"
    squared_weights = math_ops.pow(weights, 2)
    squared_weight_sums = math_ops.segment_sum(squared_weights, segment_ids)
    return math_ops.div(row_sums, math_ops.sqrt(squared_weight_sums),
                        name=name)
Exemple #25
0
def hashed_embedding_lookup_sparse(params,
                                   sparse_values,
                                   dimension,
                                   combiner="mean",
                                   default_value=None,
                                   name=None):
  """Looks up embeddings of a sparse feature using parameter hashing.

  See `tf.contrib.layers.hashed_embedding_lookup` for embedding with hashing.

  Empty rows of `sparse_values` are first filled with `default_value`, so every
  row produces an embedding.

  Args:
    params: A `Tensor` or `list` of `Tensors`.
      Each tensor must be of rank 1 with fully-defined shape.
    sparse_values: A 2-D `SparseTensor` containing the values to be embedded.
      Some rows may be empty.
    dimension: Embedding dimension
    combiner: A string specifying how to combine embedding results for each
        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
        the default.
    default_value: The value to use for an entry with no features. If None, a
        fixed arbitrary string or integer (depending on the dtype of
        `sparse_values`) is used.
    name: An optional name for this op.

  Returns:
     Dense tensor with shape [N, dimension] with N the number of rows in
       sparse_values.

  Raises:
    TypeError: If sparse_values is not a SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}. This is
      checked before any ops are added to the graph.
  """

  if not isinstance(params, list):
    params = [params]
  if not isinstance(sparse_values, ops.SparseTensor):
    raise TypeError("sparse_values must be SparseTensor")
  # Validate the combiner before building any ops, so an invalid value does
  # not leave orphaned fill/unique/lookup ops behind in the graph.
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")

  with ops.name_scope(name, "hashed_sparse_embedding_lookup",
                      params + [sparse_values]) as scope:
    # Fill in the empty rows.
    if default_value is None:
      # Random default values to reduce the risk of collision.
      if sparse_values.dtype == dtypes.string:
        default_value = "6ZxWzWOHxZ"
      else:
        default_value = 1288896567
    sparse_values, _ = sparse_ops.sparse_fill_empty_rows(
        sparse_values, default_value)

    # The first index column gives the output row of each value.
    segment_ids = sparse_values.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    # Embed each distinct value once; `idx` maps every entry back to it.
    values, idx = array_ops.unique(sparse_values.values)

    embeddings = hashed_embedding_lookup(params, values, dimension)

    segment_combiners = {
        "sum": math_ops.sparse_segment_sum,
        "mean": math_ops.sparse_segment_mean,
        "sqrtn": math_ops.sparse_segment_sqrt_n,
    }
    return segment_combiners[combiner](embeddings, idx, segment_ids,
                                       name=scope)
Exemple #26
0
 def _unique(x):
     """Returns `[padded_unique_values, idx]` for the unique op applied to `x`.

     The distinct values are right-padded with zeros to the length of `idx`
     (as measured by `_get_dim`) and cast to int64. `idx` is returned as-is.
     """
     unique_vals, positions = array_ops.unique(x)
     pad_amount = _get_dim(positions, 0) - _get_dim(unique_vals, 0)
     padded = array_ops.pad(unique_vals, [[0, pad_amount]])
     return [math_ops.cast(padded, dtypes.int64), positions]
Exemple #27
0
 def unique_sum(xs):
     """Sum over the unique values, for testing.

     Returns a `(sum_of_distinct_values, idx)` pair, where `idx` is the index
     output of the unique op for `xs`.
     """
     unique_result = array_ops.unique(xs)
     total = math_ops.reduce_sum(unique_result.y)
     return total, unique_result.idx
def _embedding_lookup_with_distributed_aggregation(params,
                                                   ids,
                                                   partition_strategy="mod",
                                                   name=None,
                                                   max_norm=None,
                                                   weights=None,
                                                   idx=None,
                                                   segment_ids=None):
  """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation.

  Gathers the rows of `params` for `ids`, optionally clips them by
  `max_norm`, optionally scales them by `weights`, and sums them per segment.
  With more than one shard, the gather and a partial per-segment sum are done
  on each shard's device. A final `unsorted_segment_sum` then combines the
  per-shard partial sums.

  Args:
    params: A tensor, a list of tensors (shards), or a `PartitionedVariable`.
    ids: Ids to look up. With multiple shards they are flattened and routed to
      shards according to `partition_strategy`.
    partition_strategy: "mod" or "div". Only consulted when there is more
      than one shard.
    name: Optional name scope for the created ops.
    max_norm: If not None, each gathered embedding is clipped to this L2 norm,
      computed over all of its non-leading axes.
    weights: Optional weights, one per entry of `ids` (they are partitioned
      with the same indices as `ids`). When given, each gathered row is scaled
      by its weight and rows are reduced with `segment_sum` over
      `segment_ids`.
    idx: Used only when `weights` is None. For each element of `segment_ids`,
      it gives the row of the gathered embeddings to add into that segment
      (the `indices` argument of `sparse_segment_sum`). Presumably this is the
      `idx` output of `tf.unique` over the raw ids, with `ids` holding the
      unique values. TODO confirm against the caller.
    segment_ids: Segment id of each summed element. The single-shard path
      feeds it to `segment_sum`/`sparse_segment_sum`, which expect sorted ids.

  Returns:
    A tensor whose first dimension indexes segments. It holds the (weighted)
    sum of the embeddings assigned to each segment.

  Raises:
    ValueError: If `params` is None or empty, or if `partition_strategy` is
      not "mod"/"div" when there are multiple shards.
  """
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  def maybe_normalize(x):
    """Clips `x` by `max_norm` over all non-leading axes, if requested."""
    if max_norm is not None:
      if x.get_shape().ndims is not None:
        ndims = x.get_shape().ndims
        axes = list(range(1, ndims))
      else:
        # The rank is only known at run time. Build the axes as a tensor,
        # because Python's `range` cannot consume a symbolic `Tensor`.
        ndims = array_ops.size(array_ops.shape(x))
        axes = math_ops.range(1, ndims)
      return clip_ops.clip_by_norm(x, max_norm, axes=axes)
    return x

  with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
                      params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.colocate_with(params[0]):
        ret = maybe_normalize(_do_gather(params[0], ids))
        ignore_weights = weights is None
        if not ignore_weights:
          if weights.dtype != ret.dtype:
            weights = math_ops.cast(weights, ret.dtype)
          # Reshape to allow broadcast: append (rank(ret) - 1) trailing 1s.
          ones = array_ops.fill(
              array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
          bcast_weights_shape = array_ops.concat(
              [array_ops.shape(weights), ones], 0)
          orig_weights_shape = weights.get_shape()
          weights = array_ops.reshape(weights, bcast_weights_shape)
          # Set weights shape after reshape
          if ret.get_shape().ndims is not None:
            weights.set_shape(
                orig_weights_shape.concatenate(
                    [1 for _ in range(ret.get_shape().ndims - 1)]))
          ret *= weights
          return math_ops.segment_sum(ret, segment_ids, name=name)
        else:
          return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        # The first `extras` partitions hold (ids_per_partition + 1) ids each;
        # the rest hold ids_per_partition ids each.
        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1), (
            flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
                                                      flat_ids.dtype)
        new_ids = (is_in_first_extras_partitions * (flat_ids %
                                                    (ids_per_partition + 1)) +
                   (1 - is_in_first_extras_partitions) * (
                       (flat_ids - extras) % ids_per_partition))
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result.append(_do_gather(params[p], gather_ids[p]))

      ignore_weights = weights is None
      if not ignore_weights:
        # Partition weights according to pindices.
        partitioned_weight = []
        for p in xrange(np):
          partitioned_weight.append(array_ops.gather(weights, pindices[p]))
      # Reshape each partition result.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      if element_shape.is_fully_defined():
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(
                partitioned_result[p],
                array_ops.concat([array_ops.shape(pindices[p]), element_shape],
                                 0))
      else:
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(
                partitioned_result[p],
                array_ops.concat([
                    array_ops.shape(pindices[p]), array_ops.slice(
                        params_shape, [1], [-1])
                ], 0))
      # Normalize each partition result.
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = maybe_normalize(partitioned_result[p])
      if not ignore_weights:
        # Multiply each partition result with partition weights.
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            if partitioned_weight[p].dtype != partitioned_result[p].dtype:
              partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
                                                    partitioned_result[p].dtype)
            # Reshape partition weights.
            ones = array_ops.fill(
                array_ops.expand_dims(
                    array_ops.rank(partitioned_result[p]) - 1, 0), 1)
            bcast_weights_shape = array_ops.concat(
                [array_ops.shape(partitioned_weight[p]), ones], 0)
            orig_weights_shape = partitioned_weight[p].get_shape()
            partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
                                                      bcast_weights_shape)
            if partitioned_result[p].get_shape().ndims is not None:
              partitioned_weight[p].set_shape(
                  orig_weights_shape.concatenate([
                      1
                      for _ in range(partitioned_result[p].get_shape().ndims -
                                     1)
                  ]))
            partitioned_result[p] *= partitioned_weight[p]
      partitioned_segment_ids = []
      for p in xrange(np):
        if not ignore_weights:
          # Partition segment_ids according to pindices.
          p_segment_ids = array_ops.gather(segment_ids, pindices[p])
          # Number the p_segment_ids to meet segment_sum's requirements. Note
          # that unique_p_segment_ids contains unique segment ids of this
          # partition and these ids' order is unchanged.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          partitioned_segment_ids.append(unique_p_segment_ids)
          # segment_sum this partition's result.
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.segment_sum(
                partitioned_result[p], unique_p_segment_idx)
        else:
          # When weights are ignored, find the positions of `idx` (and hence
          # of `segment_ids`) whose value refers to an id in this partition.
          _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
          all_idx = math_ops.range(array_ops.shape(idx)[0])
          _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
          # Gather segment_ids and idx at those positions.
          p_segment_ids = array_ops.gather(segment_ids, include_idx)
          p_idx = array_ops.gather(idx, include_idx)
          # Number the p_segment_ids, same as ignore_weights case above.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          # Renumber p_idx densely so it indexes rows of this partition's
          # gathered result.
          _, unique_p_idx_idx = array_ops.unique(p_idx)
          partitioned_segment_ids.append(unique_p_segment_ids)
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.sparse_segment_sum(
                partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
      # Concat each partition's segment_ids and result for final segment_sum.
      concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
      concat_partitioned_result = array_ops.concat(partitioned_result, 0)
      return math_ops.unsorted_segment_sum(
          concat_partitioned_result,
          concat_segment_ids,
          math_ops.reduce_max(concat_segment_ids) + 1,
          name=name)
 def log_2plus_unique_x(x):
     """Return sin(log(2 + unique(x))), instrumenting all ops except `sin`.

     `instrument.callback` (`instrument` comes from an enclosing scope not
     shown here) is registered as an op callback while the unique, add and
     log ops are created. It is removed before `sin` is built, so `sin` is not
     instrumented.
     """
     op_callbacks.add_op_callback(instrument.callback)
     try:
         unique_values, _ = array_ops.unique(x)
         y = math_ops.log(2.0 + unique_values)
     finally:
         # Always unregister, so a failing op cannot leave the callback
         # installed for unrelated code that runs afterwards.
         op_callbacks.remove_op_callback(instrument.callback)
     return math_ops.sin(y)
  def testAllowsWatchingUnconnectedOutputTensor(self):
    """The debugger can watch an output slot that feeds no edges at all.

    Slot 1 of a Unique op is left without consumers (not even control edges).
    The test checks that both the consumed slot 0 and the dangling slot 1 are
    dumped when the whole graph is watched.
    """
    with session.Session() as sess:
      initial_value = constant_op.constant([2, 2, 3, 5, 5])
      x = variables.Variable(initial_value, name="unconnected/x")

      # The UniqueOp has two output slots. Only slot 0 feeds the rest of the
      # graph; slot 1 is left dangling for the debugger to watch.
      unique_x, _ = array_ops.unique(x, name="unconnected/unique_x")
      y = math_ops.add(unique_x, [0, 1, 2], name="unconnected/y")

      x.initializer.run()

      def _consumers_of(tensor_name):
        # Names of ops that take `tensor_name` as an input, one per use.
        return [op.name
                for op in sess.graph.get_operations()
                for inp in op.inputs
                if inp.name == tensor_name]

      self.assertEqual(["unconnected/y"],
                       _consumers_of("unconnected/unique_x:0"))
      self.assertEqual([], _consumers_of("unconnected/unique_x:1"))

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity"],
          debug_urls=self._debug_urls())

      run_metadata = config_pb2.RunMetadata()
      self.assertAllClose(
          [2, 4, 7],
          sess.run(y, options=run_options, run_metadata=run_metadata))

      dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=run_metadata.partition_graphs)

      # (watch key, output slot, expected dumped value): slot 0 is the
      # connected output, slot 1 the unconnected one.
      expectations = (
          ("unconnected/unique_x:0:DebugIdentity", 0, [2, 3, 5]),
          ("unconnected/unique_x:1:DebugIdentity", 1, [0, 0, 1, 2, 2]),
      )
      for watch_key, slot, expected_value in expectations:
        slot_dumps = dump.watch_key_to_data(watch_key)
        self.assertEqual(1, len(slot_dumps))
        self.assertEqual("unconnected/unique_x", slot_dumps[0].node_name)
        self.assertEqual(slot, slot_dumps[0].output_slot)
        self.assertAllClose(expected_value, slot_dumps[0].get_tensor())
Exemple #31
0
  def minimize(self, global_step=None, name=None):
    """Add operations to train a linear model by minimizing the loss function.

    Builds one `sdca_optimizer` step over the examples given to the
    constructor. The step:
      * writes the returned example state back into the hash table,
      * scatter-adds the returned sparse weight deltas into the
        'unshrinked_sparse_features_weights' slots,
      * adds the dense weight deltas to the 'unshrinked_dense_features_weights'
        slots.

    Args:
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.

    Returns:
      An Operation that updates the variables passed in the constructor.
    """
    with name_scope(name, 'sdca/minimize'):
      sparse_example_indices = []
      sparse_feature_indices = []
      sparse_features_values = []
      for sf in self._examples['sparse_features']:
        sparse_example_indices.append(sf.example_indices)
        sparse_feature_indices.append(sf.feature_indices)
        # If feature values are missing, sdca assumes a value of 1.0f.
        if sf.feature_values is not None:
          sparse_features_values.append(sf.feature_values)

      example_ids_hashed = sdca_fprint(
          convert_to_tensor(self._examples['example_ids']))
      example_state_data = self._hashtable.lookup(example_ids_hashed)
      # Solver returns example_state_update, new delta sparse_feature_weights
      # and delta dense_feature_weights.

      weights_tensor = self._convert_n_to_tensor(self._slots[
          'unshrinked_sparse_features_weights'])
      sparse_weights = []
      sparse_indices = []
      for w, i in zip(weights_tensor, sparse_feature_indices):
        # Find the feature ids to lookup in the variables: the distinct
        # feature indices, gathered on the weight's own device.
        with ops.device(w.device):
          sparse_indices.append(
              math_ops.cast(
                  array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
                  dtypes.int64))
          sparse_weights.append(array_ops.gather(w, sparse_indices[-1]))

      esu, sfw, dfw = sdca_optimizer(
          sparse_example_indices,
          sparse_feature_indices,
          sparse_features_values,
          self._convert_n_to_tensor(self._examples['dense_features']),
          convert_to_tensor(self._examples['example_weights']),
          convert_to_tensor(self._examples['example_labels']),
          sparse_indices,
          sparse_weights,
          self._convert_n_to_tensor(self._slots[
              'unshrinked_dense_features_weights']),
          example_state_data,
          loss_type=self._options['loss_type'],
          l1=self._options['symmetric_l1_regularization'],
          l2=self._symmetric_l2_regularization(),
          num_loss_partitions=self._num_loss_partitions(),
          num_inner_iterations=1)

      with ops.control_dependencies([esu]):
        update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
        # Update the weights before the proximal step.
        for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'],
                           sparse_indices, sfw):
          update_ops.append(state_ops.scatter_add(w, i, u))
        for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw):
          update_ops.append(w.assign_add(u))

      # Compare against None explicitly. Truth-testing a variable is not
      # meaningful: for resource variables in graph mode, `bool(var)` raises.
      if global_step is None:
        return control_flow_ops.group(*update_ops)
      with ops.control_dependencies(update_ops):
        return state_ops.assign_add(global_step, 1, name=name).op
Exemple #32
0
    def _process_input_helper(self,
                              update_row_factors,
                              sp_input=None,
                              transpose_input=False,
                              row_weights=None):
        """Creates the graph for processing a sparse slice of input.

        Solves the normal equations for the rows (or columns) touched by
        `sp_input` and builds an op that scatters the new factor values into
        the model. It also builds the normalized minibatch loss terms.

        Args:
          update_row_factors: if True, update or project the row_factors, else
            update or project the column factors.
          sp_input: Please refer to comments for update_row_factors,
            update_col_factors, project_row_factors, and project_col_factors
            for restrictions.
          transpose_input: If True, the input is logically transposed and then
            the corresponding rows/columns of the transposed input are updated.
          row_weights: If not None, this is the row/column weights to be used
            for the update or projection. If None, use the corresponding
            weights from the model. Note that the feature (column/row) weights
            will be determined by the model. When not None, it can either be a
            scalar or a rank-1 tensor with the same number of elements as the
            number of rows or columns to be updated/projected. It is only
            consulted when the model has row weights
            (`self._row_weights is not None`); otherwise the plain ALS update
            is used.

        Returns:
          A tuple consisting of the following elements:
          new_values: New values for the row/column factors.
          update_op: An op that assigns the newly computed values to the
            row/column factors.
          unregularized_loss: A tensor (scalar) that contains the normalized
            minibatch loss corresponding to sp_input, without the
            regularization term. Add the regularization term below to yield
            the loss.
          regularization: A tensor (scalar) that contains the normalized
            regularization term for the minibatch loss corresponding to
            sp_input.
          sum_weights: The sum of the weights corresponding to sp_input. This
            can be used with unregularized loss to calculate the root weighted
            squared error.
        """
        # NOTE(review): `assert` is stripped under `python -O`; an explicit
        # TypeError would make this input check unconditional.
        assert isinstance(sp_input, sparse_tensor.SparseTensor)

        # `left` is the factor being solved for; the other side's factors,
        # weights and gramian come from the caches. For column updates the
        # roles are swapped and the input is logically transposed.
        if update_row_factors:
            left = self._row_factors
            right_factors = self._col_factors_cache
            row_wt = self._row_wt_cache
            col_wt = self._col_wt_cache
            total_rows = self._input_rows
            total_cols = self._input_cols
            sharding_func = WALSModel._get_sharding_func(
                self._input_rows, self._num_row_shards)
            gramian = self._col_gramian_cache
        else:
            left = self._col_factors
            right_factors = self._row_factors_cache
            row_wt = self._col_wt_cache
            col_wt = self._row_wt_cache
            total_rows = self._input_cols
            total_cols = self._input_rows
            sharding_func = WALSModel._get_sharding_func(
                self._input_cols, self._num_col_shards)
            gramian = self._row_gramian_cache
            transpose_input = not transpose_input

        # Note that the row indices of sp_input are based on the original full input
        # Here we reindex the rows and give them contiguous ids starting at 0.
        # We use tf.unique to achieve this reindexing. Note that this is done so
        # that the downstream kernel can assume that the input is "dense" along the
        # row dimension.
        row_ids, col_ids = array_ops.split(value=sp_input.indices,
                                           num_or_size_splits=2,
                                           axis=1)
        update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
        update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
        col_ids = array_ops.expand_dims(
            math_ops.cast(all_col_ids, dtypes.int64), 1)
        row_ids = array_ops.expand_dims(
            math_ops.cast(all_row_ids, dtypes.int64), 1)

        # update_indices: original ids of the entities being solved for.
        # gather_indices: original ids of the fixed side, used to look up
        # `right` and the column weights.
        if transpose_input:
            update_indices = update_col_indices
            row_shape = [
                math_ops.cast(
                    array_ops.shape(update_row_indices)[0], dtypes.int64)
            ]
            gather_indices = update_row_indices
        else:
            update_indices = update_row_indices
            row_shape = [
                math_ops.cast(
                    array_ops.shape(update_col_indices)[0], dtypes.int64)
            ]
            gather_indices = update_col_indices

        num_rows = math_ops.cast(
            array_ops.shape(update_indices)[0], dtypes.int64)
        col_shape = [num_rows]
        right = embedding_ops.embedding_lookup(right_factors,
                                               gather_indices,
                                               partition_strategy="div")
        # Re-express sp_input with the compact (reindexed) ids.
        new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
        new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                        if transpose_input else array_ops.concat(
                            [col_shape, row_shape], 0))
        new_sp_input = sparse_tensor.SparseTensor(indices=new_sp_indices,
                                                  values=sp_input.values,
                                                  dense_shape=new_sp_shape)

        # Compute lhs and rhs of the normal equations
        total_lhs = (self._unobserved_weight * gramian)
        if self._regularization_matrix is not None:
            total_lhs += self._regularization_matrix
        if self._row_weights is None:
            # Special case of ALS. Use a much simpler update rule.
            total_rhs = (self._unobserved_weight *
                         sparse_ops.sparse_tensor_dense_matmul(
                             new_sp_input, right, adjoint_a=transpose_input))
            # TODO(rmlarsen): handle transposing in tf.linalg.solve instead of
            # transposing explicitly.
            # TODO(rmlarsen): multi-thread tf.matrix_solve.
            new_left_values = array_ops.transpose(
                linalg_ops.matrix_solve(total_lhs,
                                        array_ops.transpose(total_rhs)))
        else:
            if row_weights is None:
                # TODO(yifanchen): Add special handling for single shard without using
                # embedding_lookup and perform benchmarks for those cases. Same for
                # col_weights lookup below.
                row_weights_slice = embedding_ops.embedding_lookup(
                    row_wt, update_indices, partition_strategy="div")
            else:
                # Caller-supplied weights: a scalar is broadcast to one weight
                # per updated entity; a rank-1 tensor is used as-is (as float32).
                num_indices = array_ops.shape(update_indices)[0]
                with ops.control_dependencies([
                        check_ops.assert_less_equal(
                            array_ops.rank(row_weights), 1)
                ]):
                    row_weights_slice = control_flow_ops.cond(
                        math_ops.equal(array_ops.rank(row_weights), 0), lambda:
                        (array_ops.ones([num_indices]) * row_weights),
                        lambda: math_ops.cast(row_weights, dtypes.float32))

            col_weights = embedding_ops.embedding_lookup(
                col_wt, gather_indices, partition_strategy="div")
            # One lhs matrix per updated entity: the shared term plus the
            # per-entity partial lhs from the kernel, solved as a batch.
            partial_lhs, total_rhs = (
                gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
                    right,
                    col_weights,
                    self._unobserved_weight,
                    row_weights_slice,
                    new_sp_input.indices,
                    new_sp_input.values, [],
                    num_rows,
                    transpose_input,
                    name="wals_compute_partial_lhs_rhs"))
            total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
            total_rhs = array_ops.expand_dims(total_rhs, -1)
            new_left_values = array_ops.squeeze(
                linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

        update_op_name = "row_update" if update_row_factors else "col_update"
        update_op = self.scatter_update(left,
                                        update_indices,
                                        new_left_values,
                                        sharding_func,
                                        name=update_op_name)

        # Create the loss subgraph
        loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                         if transpose_input else new_sp_input)
        # sp_approx is the low rank estimate of the input matrix, formed by
        # computing the product <\\(u_i, v_j\\)> for (i, j) in loss_sp_input.indices.
        sp_approx_vals = gen_factorization_ops.masked_matmul(
            new_left_values,
            right,
            loss_sp_input.indices,
            transpose_a=False,
            transpose_b=True)
        sp_approx = sparse_tensor.SparseTensor(loss_sp_input.indices,
                                               sp_approx_vals,
                                               loss_sp_input.dense_shape)
        sp_approx_sq = math_ops.square(sp_approx)
        sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
        sp_residual_sq = math_ops.square(sp_residual)
        # NOTE(review): `row_weights_slice` and `col_weights` are only bound in
        # the weighted branch above (taken when self._row_weights is not None).
        # This assumes self._col_weights is None whenever self._row_weights is
        # None; otherwise the next lines raise NameError. Confirm the invariant.
        row_wt_mat = (constant_op.constant(0.) if self._row_weights is None
                      else array_ops.expand_dims(row_weights_slice, 1))
        col_wt_mat = (constant_op.constant(0.) if self._col_weights is None
                      else array_ops.expand_dims(col_weights, 0))

        # We return the normalized loss
        partial_row_gramian = math_ops.matmul(new_left_values,
                                              new_left_values,
                                              transpose_a=True)
        # Scale minibatch quantities by (total entities / entities in batch).
        normalization_factor = total_rows / math_ops.cast(
            num_rows, dtypes.float32)

        unregularized_loss = (
            self._unobserved_weight * (  # pyformat line break
                sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
                sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
                math_ops.trace(math_ops.matmul(partial_row_gramian, gramian)))
            + sparse_ops.sparse_reduce_sum(
                row_wt_mat *
                (sp_residual_sq * col_wt_mat))) * normalization_factor

        if self._regularization is not None:
            regularization = self._regularization * (
                math_ops.trace(partial_row_gramian) * normalization_factor +
                math_ops.trace(gramian))
        else:
            regularization = constant_op.constant(0.)

        sum_weights = self._unobserved_weight * math_ops.cast(
            total_rows * total_cols, dtypes.float32)
        if self._row_weights is not None and self._col_weights is not None:
            ones = sparse_tensor.SparseTensor(
                indices=loss_sp_input.indices,
                values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
                dense_shape=loss_sp_input.dense_shape)
            sum_weights += sparse_ops.sparse_reduce_sum(
                row_wt_mat * (ones * col_wt_mat)) * normalization_factor

        return (new_left_values, update_op, unregularized_loss, regularization,
                sum_weights)
Exemple #33
0
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
    """Looks up and combines embeddings for a batch of sparse ids and weights.

    Assumptions (not verified at run time):
      * every row of the dense tensor represented by `sp_ids` has at least one
        id (no empty rows), and the indices of `sp_ids` are in canonical
        row-major order;
      * every id lies in [0, p0), where p0 is the total size of `params`
        along dimension 0.

    Args:
      params: A single embedding tensor, a list of P shards that agree on
        every dimension except the first, or a `PartitionedVariable`
        partitioned along dimension 0. Each shard must be sized consistently
        with `partition_strategy`.
      sp_ids: N x M `SparseTensor` of int64 ids (typically from
        FeatureValueToId); N is typically the batch size, M is arbitrary.
      sp_weights: A `SparseTensor` of float / double weights with exactly the
        same shape and indices as `sp_ids`, or None to treat every weight
        as 1.
      partition_strategy: `"div"` or `"mod"` (the default); only relevant when
        `len(params) > 1`. See `tf.nn.embedding_lookup` for details.
      name: Optional name for the op.
      combiner: How each row is reduced: "sum" (weighted sum), "mean"
        (weighted sum divided by the total weight) or "sqrtn" (weighted sum
        divided by the square root of the sum of squared weights). None logs
        a deprecation warning and falls back to "mean".
      max_norm: If provided, each embedding is normalized to have l2 norm
        equal to max_norm before combining.

    Returns:
      A dense tensor holding one combined embedding per row of `sp_ids`:
      every id in a row is looked up, scaled by its weight, and the results
      are reduced with `combiner`. If

        shape(combined params) = [p0, p1, ..., pm]
        shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]

      then

        shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].

      Example: params is a 10x20 matrix, combiner="mean", and
      sp_ids / sp_weights are

        [0, 0]: id 1, weight 2.0
        [0, 1]: id 3, weight 0.5
        [1, 0]: id 0, weight 1.0
        [2, 3]: id 1, weight 3.0

      The output is then a 3x20 matrix with

        output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
        output[1, :] = (params[0, :] * 1.0) / 1.0
        output[2, :] = (params[1, :] * 3.0) / 3.0

    Raises:
      TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
        neither None nor a `SparseTensor`.
      ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
    """
    if combiner is None:
        logging.warn("The default value of combiner will change from \"mean\" "
                     "to \"sqrtn\" after 2016/11/01.")
        combiner = "mean"
    if combiner not in ("mean", "sqrtn", "sum"):
        raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")

    if isinstance(params, variables.PartitionedVariable):
        params = list(params)  # Unpack the underlying shard Variables.
    elif not isinstance(params, list):
        params = [params]

    if not isinstance(sp_ids, sparse_tensor.SparseTensor):
        raise TypeError("sp_ids must be SparseTensor")

    use_weights = sp_weights is not None
    if use_weights:
        if not isinstance(sp_weights, sparse_tensor.SparseTensor):
            raise TypeError("sp_weights must be either None or SparseTensor")
        # Only static shapes are compared here; the run-time indices and
        # dense shapes of sp_ids and sp_weights are not checked.
        # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
        # sp_weights have equal indices and shapes.
        for ids_part, weights_part in (
                (sp_ids.values, sp_weights.values),
                (sp_ids.indices, sp_weights.indices),
                (sp_ids.dense_shape, sp_weights.dense_shape)):
            ids_part.get_shape().assert_is_compatible_with(
                weights_part.get_shape())

    with ops.name_scope(name, "embedding_lookup_sparse",
                        params + [sp_ids]) as name:
        # The first index column says which output row each entry belongs to.
        segment_ids = sp_ids.indices[:, 0]
        if segment_ids.dtype != dtypes.int32:
            segment_ids = math_ops.cast(segment_ids, dtypes.int32)

        if use_weights:
            lookup_ids = sp_ids.values
        else:
            # Unweighted: look each distinct id up once; `positions` maps each
            # original entry to its row in the deduplicated lookup result.
            lookup_ids, positions = array_ops.unique(sp_ids.values)

        embeddings = embedding_lookup(params,
                                      lookup_ids,
                                      partition_strategy=partition_strategy,
                                      max_norm=max_norm)

        if not use_weights:
            sparse_reducers = {
                "sum": math_ops.sparse_segment_sum,
                "mean": math_ops.sparse_segment_mean,
                "sqrtn": math_ops.sparse_segment_sqrt_n,
            }
            return sparse_reducers[combiner](embeddings,
                                             positions,
                                             segment_ids,
                                             name=name)

        weights = sp_weights.values
        if weights.dtype != embeddings.dtype:
            weights = math_ops.cast(weights, embeddings.dtype)

        # Append (rank(embeddings) - 1) singleton dims to the weights so they
        # broadcast across each embedding.
        trailing_ones = array_ops.fill(
            array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
        broadcast_shape = array_ops.concat(
            [array_ops.shape(weights), trailing_ones], 0)
        static_weights_shape = weights.get_shape()
        weights = array_ops.reshape(weights, broadcast_shape)

        # The dynamic reshape drops the static shape; restore it when the
        # embedding rank is known.
        embedding_rank = embeddings.get_shape().ndims
        if embedding_rank is not None:
            weights.set_shape(
                static_weights_shape.concatenate([1] * (embedding_rank - 1)))

        weighted = embeddings * weights
        if combiner == "sum":
            return math_ops.segment_sum(weighted, segment_ids, name=name)

        weighted_sum = math_ops.segment_sum(weighted, segment_ids)
        if combiner == "mean":
            denominator = math_ops.segment_sum(weights, segment_ids)
        else:  # "sqrtn" -- combiner was validated above.
            squared = math_ops.pow(weights, 2)
            denominator = math_ops.sqrt(
                math_ops.segment_sum(squared, segment_ids))
        return math_ops.div(weighted_sum, denominator, name=name)
 def fn(x):
   # Return the distinct values of x (Unique is not supported by XLA).
   distinct_values, _ = array_ops.unique(x)
   return distinct_values
 def body(i, a):
   # Loop body: advances the counter (i + 1 plus the single unique value of
   # [i]) and updates the accumulator conditionally. `x` is a free variable
   # taken from the enclosing scope.
   incremented = i + 1
   next_i = incremented + array_ops.unique([i]).y[0]
   next_a = control_flow_ops.cond(i > 2, lambda: a + (x**2), lambda: a + 3)
   return next_i, next_a
 def f1(self, x):
   # Only the distinct values are returned; Unique's index output is dropped.
   distinct_values, _ = array_ops.unique(x)
   return distinct_values
 def log_2plus_unique_x(x):
   # Returns (log(2 + unique values of x), position of each x in that list).
   unique_result = array_ops.unique(x)
   shifted = 2.0 + unique_result.y
   return math_ops.log(shifted), unique_result.idx
Exemple #38
0
  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions.
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected. It is cast to float32 before use.

    Returns:
      A tuple consisting of the following two elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
    """
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full input
    # Here we reindex the rows and give them contiguous ids starting at 0.
    # We use tf.unique to achieve this reindexing. Note that this is done so
    # that the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0) if
                    transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization is not None:
      total_lhs += self._regularization
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule.
      total_rhs = (self._unobserved_weight *
                   sparse_ops.sparse_tensor_dense_matmul(
                       new_sp_input, right, adjoint_a=transpose_input))
      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
      # transposing explicitly.
      # TODO(rmlarsen): multi-thread tf.matrix_solve.
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO(yifanchen): Add special handling for single shard without using
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        num_indices = array_ops.shape(update_indices)[0]
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          # Cast once, before the cond: tf.cond builds BOTH branches at graph
          # construction time, so multiplying float32 ones by an uncast
          # non-float32 tensor would fail even when row_weights is rank 1.
          row_weights_f32 = math_ops.cast(row_weights, dtypes.float32)
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights_f32), 0),
              # Scalar weight: broadcast it to one weight per updated index.
              lambda: array_ops.ones([num_indices]) * row_weights_f32,
              lambda: row_weights_f32)

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
          right,
          col_weights,
          self._unobserved_weight,
          row_weights_slice,
          new_sp_input.indices,
          new_sp_input.values,
          num_rows,
          transpose_input,
          name="wals_compute_partial_lhs_rhs")
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    return (new_left_values, self.scatter_update(left, update_indices,
                                                 new_left_values,
                                                 sharding_func))
def scattered_embedding_lookup_sparse(params,
                                      sparse_values,
                                      dimension,
                                      combiner=None,
                                      default_value=None,
                                      name=None,
                                      hash_key=None):
  """Looks up embeddings of a sparse feature using parameter hashing.

  See `tf.contrib.layers.scattered_embedding_lookup` for embedding with hashing.

  Args:
    params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
      Each tensor must be of rank 1 with fully-defined shape.
    sparse_values: A 2-D `SparseTensor` containing the values to be embedded.
      Some rows may be empty.
    dimension: Embedding dimension
    combiner: A string specifying how to combine embedding results for each
        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
        the default.
    default_value: The value to use for an entry with no features.
    name: An optional name for this op.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crosses fingerprints on SparseFeatureCrossOp
      (optional).

  Returns:
     Dense tensor with shape [N, dimension] with N the number of rows in
       sparse_values.

  Raises:
    TypeError: If sparse_values is not a SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sparse_values, sparse_tensor.SparseTensor):
    raise TypeError("sparse_values must be SparseTensor")
  # Reject a bad combiner before any graph ops are built, so no orphan
  # fill/unique/lookup ops are left behind and no other error can mask this
  # one.
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")

  with ops.name_scope(name, "scattered_embedding_lookup_sparse",
                      params + [sparse_values]) as scope:
    # Fill in the empty rows.
    if default_value is None:
      # Random default values to reduce the risk of collision.
      if sparse_values.dtype == dtypes.string:
        default_value = "6ZxWzWOHxZ"
      else:
        default_value = 1288896567
    sparse_values, _ = sparse_ops.sparse_fill_empty_rows(
        sparse_values, default_value)

    # Row index of each entry selects the output row it is combined into.
    segment_ids = sparse_values.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    # Embed each distinct value once; `idx` maps entries back to those rows.
    values, idx = array_ops.unique(sparse_values.values)

    embeddings = scattered_embedding_lookup(
        params, values, dimension, hash_key=hash_key)

    if combiner == "sum":
      embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
                                               name=scope)
    elif combiner == "mean":
      embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
                                                name=scope)
    else:  # "sqrtn" -- combiner was validated above.
      embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, segment_ids,
                                                  name=scope)

    return embeddings
Exemple #40
0
 def _unique(x):
   # Unique values of x, zero-padded to the length of the index output and
   # cast to int64, followed by the index output itself.
   unique_out = array_ops.unique(x)
   pad_amount = _get_dim(unique_out.idx, 0) - _get_dim(unique_out.y, 0)
   padded = array_ops.pad(unique_out.y, [[0, pad_amount]])
   return [math_ops.cast(padded, dtypes.int64), unique_out.idx]
def embedding_lookup_sparse_with_distributed_aggregation(
    params,
    sp_ids,
    sp_weights,
    partition_strategy="mod",
    name=None,
    combiner=None,
    max_norm=None):
  """Computes embeddings for the given ids and weights.

  Embeddings that belong to the same param are aggregated on that device
  first, to cut data transmission and improve parallelism. Semantics and an
  example match `tf.nn.embedding_lookup_sparse`.

  Args:
    params: A single embedding tensor, a list of P shards that agree on every
      dimension except the first, or a `PartitionedVariable` partitioned along
      dimension 0. Each shard must be sized consistently with
      `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids (typically from
      FeatureValueToId); N is typically the batch size, M is arbitrary.
    sp_weights: A `SparseTensor` of float / double weights with exactly the
      same shape and indices as `sp_ids`, or None to treat every weight as 1.
    partition_strategy: `"div"` or `"mod"` (the default); only relevant when
      `len(params) > 1`. See `tf.nn.embedding_lookup`.
    name: Optional name for the op.
    combiner: "sum" (weighted sum), "mean" (weighted sum divided by the total
      weight) or "sqrtn" (weighted sum divided by the square root of the sum
      of squared weights). None logs a deprecation warning and uses "mean".
    max_norm: If not None, each embedding is normalized to have l2 norm equal
      to max_norm before combining.

  Returns:
    A dense tensor with one combined embedding per row of the dense tensor
    represented by `sp_ids`.

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither None nor a `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")

  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Unpack the underlying shard Variables.
  elif not isinstance(params, list):
    params = [params]

  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")

  has_weights = sp_weights is not None
  if has_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    # Static shape compatibility only.
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.
    for ids_part, weights_part in (
        (sp_ids.values, sp_weights.values),
        (sp_ids.indices, sp_weights.indices),
        (sp_ids.dense_shape, sp_weights.dense_shape)):
      ids_part.get_shape().assert_is_compatible_with(weights_part.get_shape())

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    if has_weights:
      lookup_ids, idx, weights = sp_ids.values, None, sp_weights.values
    else:
      # Deduplicate ids; `idx` maps each entry back to its unique id.
      lookup_ids, idx = array_ops.unique(sp_ids.values)
      weights = None

    # The helper receives weights/idx/segment_ids; the result is presumably
    # already summed per segment (it is divided by per-segment weight sums
    # below) -- confirm against _embedding_lookup_with_distributed_aggregation.
    embeddings = _embedding_lookup_with_distributed_aggregation(
        params,
        lookup_ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm,
        weights=weights,
        idx=idx,
        segment_ids=segment_ids)

    if weights is None:
      # Unweighted lookup: every entry counts with weight 1.
      weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
    if weights.dtype != embeddings.dtype:
      weights = math_ops.cast(weights, embeddings.dtype)

    # Append (rank(embeddings) - 1) singleton dims so weights broadcast.
    trailing_ones = array_ops.fill(
        array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
    broadcast_shape = array_ops.concat(
        [array_ops.shape(weights), trailing_ones], 0)
    static_weights_shape = weights.get_shape()
    weights = array_ops.reshape(weights, broadcast_shape)
    embedding_rank = embeddings.get_shape().ndims
    if embedding_rank is not None:
      weights.set_shape(
          static_weights_shape.concatenate([1] * (embedding_rank - 1)))

    if combiner == "sum":
      return embeddings
    if combiner == "mean":
      denominator = math_ops.segment_sum(weights, segment_ids)
    else:  # "sqrtn" -- combiner was validated above.
      denominator = math_ops.sqrt(
          math_ops.segment_sum(math_ops.pow(weights, 2), segment_ids))
    return math_ops.div(embeddings, denominator)
Exemple #42
0
  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions.
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected. It is cast to float32 before use.

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. Add the regularization term below to yield the loss.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
      sum_weights: The sum of the weights corresponding to sp_input. This
        can be used with unregularized loss to calculate the root weighted
        squared error.
    """
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      total_rows = self._input_rows
      total_cols = self._input_cols
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      total_rows = self._input_cols
      total_cols = self._input_rows
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full input
    # Here we reindex the rows and give them contiguous ids starting at 0.
    # We use tf.unique to achieve this reindexing. Note that this is done so
    # that the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
      total_lhs += self._regularization_matrix
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule.
      total_rhs = (
          self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
              new_sp_input, right, adjoint_a=transpose_input))
      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
      # transposing explicitly.
      # TODO(rmlarsen): multi-thread tf.matrix_solve.
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO(yifanchen): Add special handling for single shard without using
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        num_indices = array_ops.shape(update_indices)[0]
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          # Cast once, before the cond: tf.cond builds BOTH branches at graph
          # construction time, so multiplying float32 ones by an uncast
          # non-float32 tensor would fail even when row_weights is rank 1.
          row_weights_f32 = math_ops.cast(row_weights, dtypes.float32)
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights_f32), 0),
              # Scalar weight: broadcast it to one weight per updated index.
              lambda: array_ops.ones([num_indices]) * row_weights_f32,
              lambda: row_weights_f32)

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      partial_lhs, total_rhs = (
          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
              right,
              col_weights,
              self._unobserved_weight,
              row_weights_slice,
              new_sp_input.indices,
              new_sp_input.values,
              num_rows,
              transpose_input,
              name="wals_compute_partial_lhs_rhs"))
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    if self._regularization is not None:
      regularization = self._regularization * (
          math_ops.trace(partial_row_gramian) * normalization_factor +
          math_ops.trace(gramian))
    else:
      regularization = constant_op.constant(0.)

    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
      ones = sparse_tensor.SparseTensor(
          indices=loss_sp_input.indices,
          values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
          dense_shape=loss_sp_input.dense_shape)
      sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
          ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)
  def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                              cluster_centers_var, total_counts):
    """Builds the op that performs one mini-batch k-means update.

    Each input is paired with its cluster assignments. Contributions are first
    aggregated locally over the distinct cluster ids that input touches. The
    results are then scattered into `total_counts` and `cluster_centers_var`.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: A vector (or list of vectors), zipped with `inputs`.
        For every row of the matching input, the entry gives the id of the
        cluster that row is assigned to.
      cluster_centers: Tensor of cluster centers, possibly normalized. The old
        center values are gathered from here.
      cluster_centers_var: Tensor Ref of cluster centers. It receives a
        scatter_add of the per-cluster center deltas.
      total_counts: Tensor Ref of cluster counts. The old counts are gathered
        from it, and it receives a scatter_add of the count increments.

    Returns:
      A `group` op bundling all scatter updates, i.e. one mini-batch k-means
      step.
    """
    scatter_ops = []
    for batch, assignments in zip(inputs, cluster_idx_list):
      with ops.colocate_with(batch):
        assert total_counts is not None
        flat_assignments = array_ops.reshape(assignments, [-1])
        # Deduplicate the touched cluster ids so each one is aggregated once
        # locally before being pushed to the variables.
        touched_ids, segment_ids = array_ops.unique(flat_assignments)
        num_touched = array_ops.size(touched_ids)
        # Read the current counts and centers for the touched clusters.
        with ops.colocate_with(total_counts):
          prev_counts = array_ops.gather(total_counts, touched_ids)
        with ops.colocate_with(cluster_centers):
          prev_centers = array_ops.gather(cluster_centers, touched_ids)
        # k: number of new rows assigned to each touched cluster.
        new_counts = math_ops.unsorted_segment_sum(
            array_ops.ones_like(segment_ids, dtype=total_counts.dtype),
            segment_ids, num_touched)
        # A center x with old count n that receives rows d_1, ..., d_k is
        # updated as
        #   x += (sum_i(d_i) - k * x) / (n + k).
        # Step 1: compute sum_i(d_i) for each touched cluster.
        point_sums = math_ops.unsorted_segment_sum(batch, segment_ids,
                                                   num_touched)
        # Build the shape [num_touched, 1, ..., 1], with the same rank as
        # `batch`, so that per-cluster scalars broadcast against the rows.
        leading_dim = array_ops.reshape(num_touched, [1])
        trailing_ones = array_ops.ones(
            array_ops.reshape(array_ops.rank(batch) - 1, [1]),
            dtype=dtypes.int32)
        bcast_shape = array_ops.concat([leading_dim, trailing_ones], 0)
        # Step 2: subtract k * x.
        k_times_x = math_ops.cast(
            array_ops.reshape(new_counts, bcast_shape),
            batch.dtype) * prev_centers
        deltas = point_sums - k_times_x
        # Step 3: scale by 1 / (n + k).
        step = array_ops.reshape(
            math_ops.reciprocal(
                math_ops.cast(prev_counts + new_counts, batch.dtype)),
            bcast_shape)
        deltas = deltas * step
      # The scatter updates are created outside the `batch` colocation scope.
      scatter_ops += [
          state_ops.scatter_add(total_counts, touched_ids, new_counts),
          state_ops.scatter_add(cluster_centers_var, touched_ids, deltas),
      ]
    return control_flow_ops.group(*scatter_ops)
 def count_cols(self, sp_input):
   """Counts the distinct values in the second index coordinate of `sp_input`.

   Args:
     sp_input: an object with an `indices` Tensor whose second coordinate
       (`indices[:, 1]`) is presumably the column index.

   Returns:
     A scalar float32 Tensor holding the number of distinct values found.
   """
   distinct_cols, _ = array_ops.unique(sp_input.indices[:, 1])
   num_distinct = array_ops.shape(distinct_cols)[0]
   return math_ops.cast(num_distinct, dtypes.float32)