def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
    with self.session(use_gpu=True):
      params = np.ones((3, 3), dtype=np.float32)

      indices_empty = np.empty((0, 2), dtype=np.int32)
      gather_nd_ok_t = array_ops.gather_nd(params, indices_empty)
      gather_nd_ok_val = gather_nd_ok_t.eval()
      self.assertEqual([0], gather_nd_ok_t.get_shape())
      self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)

      indices_empty = np.empty((0, 1), dtype=np.int32)
      gather_nd_ok_t = array_ops.gather_nd(params, indices_empty)
      gather_nd_ok_val = gather_nd_ok_t.eval()
      self.assertEqual([0, 3], gather_nd_ok_t.get_shape())
      self.assertAllClose(np.empty((0, 3), dtype=np.float32), gather_nd_ok_val)

      params_empty = np.empty((0, 3), dtype=np.float32)
      indices_empty = np.empty((0, 2), dtype=np.int32)
      gather_nd_ok_t = array_ops.gather_nd(params_empty, indices_empty)
      gather_nd_ok_val = gather_nd_ok_t.eval()
      self.assertEqual([0], gather_nd_ok_t.get_shape())
      self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)

      params_empty = np.empty((0, 3), dtype=np.float32)
      indices_nonempty = np.zeros((1, 2), dtype=np.int32)
      gather_nd_break_t = array_ops.gather_nd(params_empty, indices_nonempty)
      with self.assertRaisesOpError(
          r"Requested more than 0 entries, but params is empty."):
        gather_nd_break_t.eval()
      self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
  def _get_coordinatewise_learning_rate(self, grad, var):
    # Compute the learning rate using a moving average for the diagonal of BB^T
    avg_first = self.get_slot(var, 'first_moment')
    avg_second = self.get_slot(var, 'second_moment')
    decay_tensor = math_ops.cast(self._decay_tensor, var.dtype)
    batch_size = math_ops.cast(self._batch_size_tensor, var.dtype)

    # Create an estimator for the moving average of gradient mean and variance
    # via Welford's algorithm
    if isinstance(grad, ops.Tensor):
      delta = grad - avg_first
      first_moment_update = avg_first.assign_add(
          array_ops.where(self._counter < 1, math_ops.cast(1, var.dtype),
                          1. - decay_tensor) * delta)

      with ops.control_dependencies([first_moment_update]):
        second_moment_update = avg_second.assign_add(
            math_ops.cast(self._counter < 1, var.dtype) *
            -(1. - decay_tensor) * (
                avg_second - decay_tensor * math_ops.square(delta)))
      diag_preconditioner = control_flow_ops.with_dependencies(
          [second_moment_update],
          clip_ops.clip_by_value(avg_second, 1e-12, 1e12))
    elif isinstance(grad, ops.IndexedSlices):
      delta = grad.values - array_ops.gather_nd(avg_first, grad.indices)
      first_moment_update = state_ops.scatter_add(
          avg_first,
          grad.indices,
          array_ops.where(self._counter < 1,
                          math_ops.cast(1., var.dtype),
                          1. - decay_tensor) * delta)

      with ops.control_dependencies([first_moment_update]):
        avg_second = state_ops.scatter_add(
            avg_second,
            grad.indices,
            math_ops.cast(self._counter < 1, var.dtype) *
            -(1. - decay_tensor) * (
                array_ops.gather_nd(avg_second, grad.indices) - decay_tensor *
                math_ops.square(delta)))
        avg_second = array_ops.gather_nd(avg_second, grad.indices)
        # TODO(b/70783772)
        diag_preconditioner = clip_ops.clip_by_value(avg_second, 1e-12, 1e12)
    else:
      raise errors.InvalidArgumentError(
          None, None, 'grad must be of type Tensor or IndexedSlices')

    diag_preconditioner *= batch_size

    if self._use_single_learning_rate:
      diag_preconditioner = math_ops.reduce_mean(diag_preconditioner)

    # From Theorem 2 Corollary 1 of Mandt et al. 2017
    return 2. * batch_size / (
        math_ops.cast(self._total_num_examples, var.dtype.base_dtype) *
        diag_preconditioner)
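The returned step size is the closed-form rate from Theorem 2, Corollary 1 of Mandt et al. 2017, scaled by the clipped per-coordinate second-moment estimate. Below is a minimal NumPy sketch of that final computation, with made-up values standing in for the optimizer's slots and hyperparameters:

import numpy as np

# Illustrative values standing in for the optimizer's state (assumptions).
batch_size = 32.0
total_num_examples = 50000.0
second_moment = np.array([0.5, 2.0, 1e-16])  # running estimate of diag(BB^T)

# Clip the preconditioner as above, then scale by the batch size.
diag_preconditioner = np.clip(second_moment, 1e-12, 1e12) * batch_size

# Learning rate from Theorem 2, Corollary 1 of Mandt et al. 2017.
learning_rate = 2.0 * batch_size / (total_num_examples * diag_preconditioner)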
 def testUnknownIndices(self):
   params = constant_op.constant([[0, 1, 2]])
   indices = array_ops.placeholder(dtypes.int32)
   gather_nd_t = array_ops.gather_nd(params, indices)
   shape = gather_nd_t.get_shape()
   self.assertEqual(None, shape.ndims)
   self.assertEqual(None, tensor_shape.dimension_value(shape[0]))
  def testGradientsRank7Elements(self):
    # Shape [1,1,2,1,1,2,2]
    indices = constant_op.constant(
        [[[
            [[[[0, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0]]]],
            [[[[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 1]]]]
        ]]],
        dtype=dtypes.int32)
    inputs = constant_op.constant(
        [[[
            [[[[1, 3], [5, 7]]]],
            [[[[2, 4], [6, 8]]]]
        ]]], dtype=dtypes.float64)
    outputs = array_ops.gather_nd(inputs, indices)

    grad_vals = constant_op.constant(
        [[[
            [[[[1, 2], [3, 4]]]],
            [[[[5, 6], [7, 8]]]]
        ]]], dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    expected_grads = np.array(
        [[[
            [[[[5, 6], [1, 2]]]],
            [[[[3, 4], [7, 8]]]]
        ]]], dtype=np.float64)
    with self.session(use_gpu=True):
      self.assertAllEqual(expected_grads, grads.eval())
Example 5
 def _single_seq_fn():
   batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
   example_inds = array_ops.reshape(
       math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
   return array_ops.gather_nd(
       array_ops.squeeze(inputs, [1]),
       array_ops.concat([example_inds, tag_indices], axis=1))
Example 6
def _TensorScatterUpdateGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.tensor_scatter_update(
      array_ops.identity(grad), indices,
      array_ops.zeros_like(op.inputs[2], dtype=grad.dtype))
  return [tensor_grad, None, updates_grad]
Example 7
def _SparseDenseCwiseMulOrDivGrad(op, grad, is_mul):
  """Common code for SparseDenseCwise{Mul,Div} gradients."""
  x_indices = op.inputs[0]
  x_shape = op.inputs[2]
  y = op.inputs[3]

  y_shape = math_ops.to_int64(array_ops.shape(y))
  num_added_dims = array_ops.expand_dims(
      array_ops.size(x_shape) - array_ops.size(y_shape), 0)
  augmented_y_shape = array_ops.concat(
      [array_ops.ones(num_added_dims, ops.dtypes.int64), y_shape], 0)

  scaling = x_shape // augmented_y_shape
  scaled_indices = x_indices // scaling
  scaled_indices = array_ops.slice(scaled_indices,
                                   array_ops.concat([[0], num_added_dims], 0),
                                   [-1, -1])
  dense_vals = array_ops.gather_nd(y, scaled_indices)

  if is_mul:
    dx = grad * dense_vals
    dy_val = grad * op.inputs[1]
  else:
    dx = grad / dense_vals
    dy_val = grad * (-op.inputs[1] / math_ops.square(dense_vals))
  # indices can repeat after scaling, so we can't use sparse_to_dense().
  dy = sparse_ops.sparse_add(
      array_ops.zeros_like(y),
      sparse_tensor.SparseTensor(scaled_indices, dy_val, y_shape))

  # (sp_indices, sp_vals, sp_shape, dense)
  return (None, dx, None, dy)
Example 8
def dense_to_sparse_tensor(dense_tensor, ignore_value=None):
  """Converts dense `Tensor` to `SparseTensor`, dropping `ignore_value` cells.

  Args:
    dense_tensor: A `Tensor`.
    ignore_value: Entries in `dense_tensor` equal to this value will be
      absent from the return `SparseTensor`. If `None`, default value of
      `dense_tensor` dtype will be used (e.g. '' for `str`, 0 for `int`).

  Returns:
    A `SparseTensor` with the same shape as `dense_tensor`.

  Raises:
    ValueError: when `dense_tensor`'s rank is `None`.
  """
  with ops.name_scope("DenseToSparseTensor"):
    dense_tensor = ops.convert_to_tensor(dense_tensor)
    ignore_value = _ignore_value_tensor(dense_tensor.dtype, ignore_value)
    indices = array_ops.where(
        math_ops.not_equal(dense_tensor, ignore_value), name="indices")
    return sparse_tensor.SparseTensor(
        indices=indices,
        values=array_ops.gather_nd(dense_tensor, indices, name="values"),
        dense_shape=array_ops.shape(
            dense_tensor, out_type=dtypes.int64, name="dense_shape"))
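The same where/gather_nd/shape pattern is available through the public TF 2.x API; a small self-contained sketch (the tensor values here are made up for illustration):

import tensorflow as tf

# Sketch of the dense-to-sparse conversion above using public TF 2.x ops.
dense = tf.constant([[1.0, 0.0], [0.0, 3.0]])
indices = tf.where(tf.not_equal(dense, 0.0))   # [[0, 0], [1, 1]]
values = tf.gather_nd(dense, indices)          # [1.0, 3.0]
sparse = tf.sparse.SparseTensor(
    indices=indices,
    values=values,
    dense_shape=tf.shape(dense, out_type=tf.int64))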
def gather_tree_from_array(t, parent_ids, sequence_length):
  """Calculates the full beams for `TensorArray`s.

  Args:
    t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
      shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
      where `s` is the depth shape.
    parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
    sequence_length: The sequence length of shape `[batch_size, beam_width]`.

  Returns:
    A `Tensor` which is a stacked `TensorArray` of the same size and type as
    `t` and where beams are sorted in each `Tensor` according to `parent_ids`.
  """
  max_time = parent_ids.shape[0].value or array_ops.shape(parent_ids)[0]
  batch_size = parent_ids.shape[1].value or array_ops.shape(parent_ids)[1]
  beam_width = parent_ids.shape[2].value or array_ops.shape(parent_ids)[2]

  # Generate beam ids that will be reordered by gather_tree.
  beam_ids = array_ops.expand_dims(
      array_ops.expand_dims(math_ops.range(beam_width), 0), 0)
  beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1])

  mask = array_ops.sequence_mask(
      sequence_length, maxlen=max_time, dtype=dtypes.int32)
  mask = array_ops.transpose(mask, perm=[2, 0, 1])

  # Use beam_width + 1 to mark the end of beam.
  masked_beam_ids = (beam_ids * mask) + (1 - mask) * (beam_width + 1)

  max_sequence_lengths = math_ops.to_int32(
      math_ops.reduce_max(sequence_length, axis=1))
  sorted_beam_ids = beam_search_ops.gather_tree(
      step_ids=masked_beam_ids,
      parent_ids=parent_ids,
      max_sequence_lengths=max_sequence_lengths,
      end_token=beam_width + 1)

  # For out of range steps, simply copy the same beam.
  sorted_beam_ids = array_ops.where(
      math_ops.cast(mask, dtypes.bool), x=sorted_beam_ids, y=beam_ids)

  # Generate indices for gather_nd.
  time_ind = array_ops.tile(array_ops.reshape(
      math_ops.range(max_time), [-1, 1, 1]), [1, batch_size, beam_width])
  batch_ind = array_ops.tile(array_ops.reshape(
      math_ops.range(batch_size), [-1, 1, 1]), [1, max_time, beam_width])
  batch_ind = array_ops.transpose(batch_ind, perm=[1, 0, 2])
  indices = array_ops.stack([time_ind, batch_ind, sorted_beam_ids], -1)

  # Gather from a tensor with collapsed additional dimensions.
  gather_from = t
  final_shape = array_ops.shape(gather_from)
  gather_from = array_ops.reshape(
      gather_from, [max_time, batch_size, beam_width, -1])
  ordered = array_ops.gather_nd(gather_from, indices)
  ordered = array_ops.reshape(ordered, final_shape)

  return ordered
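The gather_nd call above simply pairs per-axis coordinate grids with the reordered beam ids; a small NumPy sketch of the equivalent indexing, with shapes and values chosen only for illustration:

import numpy as np

# NumPy analogue of stacking [time, batch, beam] indices and gathering.
max_time, batch_size, beam_width = 2, 1, 3
gather_from = np.arange(max_time * batch_size * beam_width).reshape(
    max_time, batch_size, beam_width)
sorted_beam_ids = np.array([[[2, 0, 1]], [[1, 2, 0]]])  # [max_time, batch, beam]
time_ind, batch_ind, _ = np.meshgrid(
    np.arange(max_time), np.arange(batch_size), np.arange(beam_width),
    indexing='ij')
# Same result as gather_nd(gather_from, stack([time_ind, batch_ind, sorted_beam_ids], -1)).
ordered = gather_from[time_ind, batch_ind, sorted_beam_ids]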
Example 10
 def testGatherNdRefVariable(self):
   with self.cached_session():
     v = variables.RefVariable(constant_op.constant([[1, 2], [3, 4], [5, 6]]))
     self.evaluate(variables.global_variables_initializer())
     gather = array_ops.gather_nd(v, [[0, 1], [2, 0]])
     if not context.executing_eagerly():  # .op doesn't make sense in Eager
       self.assertEqual("GatherNd", gather.op.name)
     self.assertAllEqual([2, 5], gather)
 def testBadIndicesCPU(self):
   with self.session(use_gpu=False):
     params = [0, 1, 2]
     indices = [[[0], [7]]]  # Make this one higher rank
     gather_nd = array_ops.gather_nd(params, indices)
     with self.assertRaisesOpError(
         r"indices\[0,1\] = \[7\] does not index into param shape \[3\]"):
       self.evaluate(gather_nd)
Example 12
 def maybe_sample():
   """Perform scheduled sampling."""
   where_sampling = math_ops.cast(
       array_ops.where(sample_ids > -1), dtypes.int32)
   where_not_sampling = math_ops.cast(
       array_ops.where(sample_ids <= -1), dtypes.int32)
   sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
   inputs_not_sampling = array_ops.gather_nd(
       base_next_inputs, where_not_sampling)
   sampled_next_inputs = self._embedding_fn(sample_ids_sampling)
   base_shape = array_ops.shape(base_next_inputs)
   return (array_ops.scatter_nd(indices=where_sampling,
                                updates=sampled_next_inputs,
                                shape=base_shape)
           + array_ops.scatter_nd(indices=where_not_sampling,
                                  updates=inputs_not_sampling,
                                  shape=base_shape))
 def testBadIndicesWithSlicesCPU(self):
   with self.session(use_gpu=False):
     params = [[0, 1, 2]]
     indices = [[[0], [0], [1]]]  # Make this one higher rank
     gather_nd = array_ops.gather_nd(params, indices)
     with self.assertRaisesOpError(
         r"indices\[0,2\] = \[1\] does not index into param shape \[1,3\]"):
       gather_nd.eval()
Example 14
 def _runGather(self, params, indices):
   with self.test_session():
     paramsp = array_ops.placeholder(params.dtype)
     indicesp = array_ops.placeholder(indices.dtype)
     with self.test_scope():
       gather_nd_t = array_ops.gather_nd(paramsp, indicesp)
     feed_dict = {paramsp: params, indicesp: indices}
     return gather_nd_t.eval(feed_dict=feed_dict)
Example 15
  def _verifyLu(self, x, output_idx_type=dtypes.int64):
    # Verify that Px = LU.
    lu, perm = linalg_ops.lu(x, output_idx_type=output_idx_type)

    # Prepare the lower factor of shape num_rows x num_rows
    lu_shape = np.array(lu.shape.as_list())
    batch_shape = lu_shape[:-2]
    num_rows = lu_shape[-2]
    num_cols = lu_shape[-1]

    lower = array_ops.matrix_band_part(lu, -1, 0)

    if num_rows > num_cols:
      eye = linalg_ops.eye(
          num_rows, batch_shape=batch_shape, dtype=lower.dtype)
      lower = array_ops.concat([lower, eye[..., num_cols:]], axis=-1)
    elif num_rows < num_cols:
      lower = lower[..., :num_rows]

    # Fill the diagonal with ones.
    ones_diag = array_ops.ones(
        np.append(batch_shape, num_rows), dtype=lower.dtype)
    lower = array_ops.matrix_set_diag(lower, ones_diag)

    # Prepare the upper factor.
    upper = array_ops.matrix_band_part(lu, 0, -1)

    verification = math_ops.matmul(lower, upper)

    # Permute the rows of product of the Cholesky factors.
    if num_rows > 0:
      # Reshape the product of the triangular factors and permutation indices
      # to a single batch dimension. This makes it easy to apply
      # invert_permutation and gather_nd ops.
      perm_reshaped = array_ops.reshape(perm, [-1, num_rows])
      verification_reshaped = array_ops.reshape(verification,
                                                [-1, num_rows, num_cols])
      # Invert the permutation in each batch.
      inv_perm_reshaped = map_fn.map_fn(array_ops.invert_permutation,
                                        perm_reshaped)
      batch_size = perm_reshaped.shape.as_list()[0]
      # Prepare the batch indices with the same shape as the permutation.
      # The corresponding batch index is paired with each of the `num_rows`
      # permutation indices.
      batch_indices = math_ops.cast(
          array_ops.broadcast_to(
              math_ops.range(batch_size)[:, None], perm_reshaped.shape),
          dtype=output_idx_type)
      permuted_verification_reshaped = array_ops.gather_nd(
          verification_reshaped,
          array_ops.stack([batch_indices, inv_perm_reshaped], axis=-1))

      # Reshape the verification matrix back to the original shape.
      verification = array_ops.reshape(permuted_verification_reshaped,
                                       lu_shape)

    self._verifyLuBase(x, lower, upper, perm, verification,
                       output_idx_type)
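The gather_nd in the verification step is a batched row permutation; a minimal NumPy sketch of the same operation with made-up shapes:

import numpy as np

# NumPy analogue of the batched permutation applied with gather_nd above.
verification = np.arange(2 * 3 * 3, dtype=np.float64).reshape(2, 3, 3)
perm = np.array([[2, 0, 1], [1, 2, 0]])    # one permutation per batch element
inv_perm = np.argsort(perm, axis=1)        # invert each permutation
batch_idx = np.repeat(np.arange(2)[:, None], 3, axis=1)
# Same result as gather_nd(verification, stack([batch_idx, inv_perm], -1)).
permuted = verification[batch_idx, inv_perm]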
Example 16
 def testBadIndicesWithSlices(self):
   with self.test_session():
     params = [[0, 1, 2]]
     indices = [[[0], [0], [1]]]  # Make this one higher rank
     gather_nd = array_ops.gather_nd(params, indices)
     with self.assertRaisesOpError(
         r"flat indices\[2, :\] = \[1\] does not index into param "
         r"\(shape: \[1,3\]\)"):
       gather_nd.eval()
Example 17
 def testIndexScalar(self):
   with self.test_session(use_gpu=True):
     params = np.array(
         [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
     indices = constant_op.constant([4, 1])
     gather_nd_t = array_ops.gather_nd(params, indices)
     gather_nd_val = gather_nd_t.eval()
     self.assertEqual([], gather_nd_t.get_shape())
     self.assertAllEqual(np.array(7), gather_nd_val)
Example 18
 def _dense_to_sparse_tensor(dense_tensor):
   """Returns a SparseTensor for the input dense_tensor."""
   ignore_value = 0.0
   sparse_indices = array_ops.where(math_ops.not_equal(
       dense_tensor, math_ops.cast(ignore_value, dense_tensor.dtype)))
   sparse_values = array_ops.gather_nd(dense_tensor, sparse_indices)
   # SparseTensor needs the shape to be converted to int64.
   int64_shape = math_ops.to_int64(array_ops.shape(dense_tensor))
   return ops.SparseTensor(sparse_indices, sparse_values, shape=int64_shape)
Example 19
def _SparseReduceSumGrad(op, out_grad):
    """Similar to gradient for the Sum Op (i.e. tf.reduce_sum())."""
    sp_indices = op.inputs[0]
    sp_shape = op.inputs[2]
    output_shape_kept_dims = math_ops.reduced_shape(sp_shape, op.inputs[3])
    out_grad_reshaped = array_ops.reshape(out_grad, output_shape_kept_dims)
    scale = sp_shape // math_ops.to_int64(output_shape_kept_dims)
    # (sparse_indices, sparse_values, sparse_shape, reduction_axes)
    return (None, array_ops.gather_nd(out_grad_reshaped, sp_indices // scale), None, None)
 def testParamsRankLargerThanIndexIndexScalarSlices(self):
   with self.session(use_gpu=True):
     params = np.array(
         [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
     indices = constant_op.constant([4])
     gather_nd_t = array_ops.gather_nd(params, indices)
     gather_nd_val = gather_nd_t.eval()
     self.assertEqual([2], gather_nd_t.get_shape())
     self.assertAllEqual(np.array([-7, 7]), gather_nd_val)
  def _testSimpleDtype(self, dtype):
    with self.cached_session(use_gpu=True):
      params = constant_op.constant(np.array([8, 1, 2, 3, 7, 5], dtype=dtype))
      indices = constant_op.constant([[4], [4], [0]])
      gather_nd_t = array_ops.gather_nd(params, indices)
      gather_nd_val = gather_nd_t.eval()

    self.assertAllEqual(np.array([7, 7, 8], dtype=dtype), gather_nd_val)
    self.assertEqual([3], gather_nd_t.get_shape())
Example 22
 def testBadIndicesCPU(self):
   with self.test_session(use_gpu=False):
     params = [0, 1, 2]
     indices = [[[0], [7]]]  # Make this one higher rank
     gather_nd = array_ops.gather_nd(params, indices)
     with self.assertRaisesOpError(
         r"flat indices\[1, :\] = \[7\] does not index into param "
         r"\(shape: \[3\]\)"):
       gather_nd.eval()
Example 23
  def testGradientsRank2Slices(self):
    indices = constant_op.constant([[1], [0]], dtype=dtypes.int32)
    inputs = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    outputs = array_ops.gather_nd(inputs, indices)

    grad_vals = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64)
    with self.test_session():
      self.assertAllEqual(expected_grads, grads.eval())
  def testGradientsRank2Elements(self):
    indices = constant_op.constant([[0, 0], [1, 1]], dtype=dtypes.int32)
    inputs = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    outputs = array_ops.gather_nd(inputs, indices)

    grad_vals = constant_op.constant([1, 2], dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    expected_grads = np.array([[1, 0], [0, 2]], dtype=np.float64)
    with self.session(use_gpu=True):
      assert np.array_equal(expected_grads, grads.eval())
Example 25
 def testGatherNdResourceVariable(self):
   with compat.forward_compatibility_horizon(2019, 4, 30):
     with self.cached_session():
       v = resource_variable_ops.ResourceVariable(
           constant_op.constant([[1, 2], [3, 4], [5, 6]]))
       self.evaluate(variables.global_variables_initializer())
       gather = array_ops.gather_nd(v, [[0, 1], [2, 0]])
       if not context.executing_eagerly():  # .op doesn't make sense in Eager
         self.assertEqual("ResourceGatherNd", gather.op.inputs[0].op.type)
       self.assertAllEqual([2, 5], gather)
Example 26
      def maybe_sample():
        """Perform scheduled sampling."""
        if self._next_input_layer is None:
          return array_ops.where(sample_ids, outputs, base_next_inputs)

        where_sampling = math_ops.cast(
            array_ops.where(sample_ids), dtypes.int32)
        where_not_sampling = math_ops.cast(
            array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
        outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                  where_not_sampling)
        sampled_next_inputs = self._next_input_layer(outputs_sampling)
        base_shape = array_ops.shape(base_next_inputs)
        return (array_ops.scatter_nd(indices=where_sampling,
                                     updates=sampled_next_inputs,
                                     shape=base_shape)
                + array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape))
  def testHigherRankParams(self):
    with self.session(use_gpu=True):
      shape = (10, 20, 5, 1, 17)
      params = np.random.rand(*shape)
      indices = np.vstack([np.random.randint(0, s, size=2000) for s in shape]).T
      gather_nd_t = array_ops.gather_nd(params, indices)
      gather_nd_val = gather_nd_t.eval()

    expected = params[tuple(indices.T)]
    self.assertAllEqual(expected, gather_nd_val)
    self.assertEqual([2000], gather_nd_t.get_shape())
  def testGradientsRank2Slices(self):
    indices = constant_op.constant([[1], [0]], dtype=dtypes.int32)
    inputs = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    outputs = array_ops.gather_nd(inputs, indices)

    grad_vals = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64)
    with self.session(use_gpu=True):
      self.assertIndexedSlices(grads)
      self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
Example 29
 def _single_seq_fn():
   batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
   example_inds = array_ops.reshape(
       math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
   sequence_scores = array_ops.gather_nd(
       array_ops.squeeze(inputs, [1]),
       array_ops.concat([example_inds, tag_indices], axis=1))
   sequence_scores = array_ops.where(math_ops.less_equal(sequence_lengths, 0),
                                     array_ops.zeros_like(sequence_scores),
                                     sequence_scores)
   return sequence_scores
 def _disabledTestBadIndicesWithSlicesGPU(self):
    # TODO: disabled due to different behavior on GPU and CPU.
    # On GPU the bad indices do not raise an error but fetch 0 values.
   if not test.is_gpu_available():
     return
   with self.session(use_gpu=True):
     params = [[0, 1, 2]]
     indices = [[[0], [0], [1]]]  # Make this one higher rank
     gather_nd = array_ops.gather_nd(params, indices)
     with self.assertRaisesOpError(
         r"indices\[0,2\] = \[1\] does not index into param shape \[1,3\]"):
       gather_nd.eval()
Example 31
 def call(self, inputs):
   indices = array_ops.where(math_ops.not_equal(inputs, 0))
   values = array_ops.gather_nd(inputs, indices)
   shape = array_ops.shape(inputs, out_type=dtypes.int64)
   return sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
Example 32
def _ScatterNdGrad(op, grad):
    indices = op.inputs[0]
    updates_grad = array_ops.gather_nd(grad, indices)
    return [None, updates_grad, None]
Example 33
def _SparseTensorDenseAddGrad(op, out_grad):
    sp_indices = op.inputs[0]
    # (sparse_indices, sparse_values, sparse_shape, dense)
    return (None, array_ops.gather_nd(out_grad, sp_indices), None, out_grad)
Example 34
def sparse_multiclass_hinge_loss(
        labels,
        logits,
        weights=1.0,
        scope=None,
        loss_collection=ops.GraphKeys.LOSSES,
        reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS):
    """Adds Ops for computing the multiclass hinge loss.

  The implementation is based on the following paper:
  On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
  by Crammer and Singer.
  link: http://jmlr.csail.mit.edu/papers/volume2/crammer01a/crammer01a.pdf

  This is a generalization of standard (binary) hinge loss. For a given instance
  with correct label c*, the loss is given by:
    $$loss = max_{c != c*} logits_c - logits_{c*} + 1.$$
  or equivalently
    $$loss = max_c { logits_c - logits_{c*} + I_{c != c*} }$$
  where \\(I_{c != c*} = 1\ \text{if}\ c != c*\\) and 0 otherwise.

  Args:
    labels: `Tensor` of shape [batch_size] or [batch_size, 1]. Corresponds to
      the ground truth. Each entry must be an index in `[0, num_classes)`.
    logits: `Tensor` of shape [batch_size, num_classes] corresponding to the
      unscaled logits. Its dtype should be either `float32` or `float64`.
    weights: Optional (python) scalar or `Tensor`. If a non-scalar `Tensor`, its
      rank should be either 1 ([batch_size]) or 2 ([batch_size, 1]).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.
    reduction: Type of reduction to apply to loss.

  Returns:
    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
    shape as `labels`; otherwise, it is a scalar.

  Raises:
    ValueError: If `logits`, `labels` or `weights` have invalid or inconsistent
      shapes.
    ValueError: If `labels` tensor has invalid dtype.
  """

    with ops.name_scope(scope, 'sparse_multiclass_hinge_loss',
                        (logits, labels)) as scope:

        # Check logits Tensor has valid rank.
        logits_rank = logits.get_shape().ndims
        if logits_rank != 2:
            raise ValueError(
                'logits should have rank 2 ([batch_size, num_classes]). Given rank is'
                ' {}'.format(logits_rank))
        logits_shape = array_ops.shape(logits)
        batch_size, num_classes = logits_shape[0], logits_shape[1]
        logits = math_ops.to_float(logits)

        # Check labels have valid type.
        if labels.dtype != dtypes.int32 and labels.dtype != dtypes.int64:
            raise ValueError(
                'Invalid dtype for labels: {}. Acceptable dtypes: int32 and int64'
                .format(labels.dtype))

        # Check labels and weights have valid ranks and are consistent.
        labels_rank = labels.get_shape().ndims
        if labels_rank not in [1, 2]:
            raise ValueError(
                'labels should have rank 1 ([batch_size]) or 2 ([batch_size, 1]). '
                'Given rank is {}'.format(labels_rank))
        with ops.control_dependencies([
                check_ops.assert_less(labels,
                                      math_ops.cast(num_classes, labels.dtype))
        ]):
            labels = array_ops.reshape(labels, shape=[-1])

        weights = ops.convert_to_tensor(weights)
        weights_rank = weights.get_shape().ndims
        if weights_rank not in [0, 1, 2]:
            raise ValueError(
                'non-scalar weights should have rank 1 ([batch_size]) or 2 '
                '([batch_size, 1]). Given rank is {}'.format(weights_rank))

        if weights_rank > 0:
            weights = array_ops.reshape(weights, shape=[-1])
            # Check weights and labels have the same number of elements.
            weights.get_shape().assert_is_compatible_with(labels.get_shape())

        # Compute the logits tensor corresponding to the correct class per instance.
        example_indices = array_ops.reshape(math_ops.range(batch_size),
                                            shape=[batch_size, 1])
        indices = array_ops.concat([
            example_indices,
            array_ops.reshape(math_ops.cast(labels, example_indices.dtype),
                              shape=[batch_size, 1])
        ],
                                   axis=1)
        label_logits = array_ops.reshape(array_ops.gather_nd(params=logits,
                                                             indices=indices),
                                         shape=[batch_size, 1])

        one_cold_labels = array_ops.one_hot(indices=labels,
                                            depth=num_classes,
                                            on_value=0.0,
                                            off_value=1.0)
        margin = logits - label_logits + one_cold_labels
        margin = nn_ops.relu(margin)
        loss = math_ops.reduce_max(margin, axis=1)
        return losses.compute_weighted_loss(loss,
                                            weights,
                                            scope,
                                            loss_collection,
                                            reduction=reduction)
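A tiny worked example of the margin computation above (NumPy, with made-up logits and labels); the per-instance loss is the largest violated margin over all classes:

import numpy as np

logits = np.array([[2.0, 1.0, 0.5],
                   [1.2, 0.4, 1.5]])
labels = np.array([0, 2])
label_logits = logits[np.arange(2), labels][:, None]  # the gather_nd above
one_cold = 1.0 - np.eye(3)[labels]                     # 0 at the true class, 1 elsewhere
margin = np.maximum(logits - label_logits + one_cold, 0.0)
loss = margin.max(axis=1)                              # [0.0, 0.7]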
Example 35
  def _groupwise_dnn_v2(features, labels, mode, params, config):
    """Defines the dnn for groupwise scoring functions."""
    with ops.name_scope('transform'):
      context_features, per_example_features = _call_transform_fn(
          features, mode)

    def _score_fn(context_features, group_features, reuse):
      with variable_scope.variable_scope('group_score', reuse=reuse):
        return group_score_fn(context_features, group_features, mode, params,
                              config)

    # Scatter/Gather per-example scores through groupwise comparison. Each
    # instance in a mini-batch will form a number of groups. Each group of
    # examples is scored by 'score_fn' and scores for individual examples are
    # accumulated over groups.
    with ops.name_scope('groupwise_dnn_v2'):
      with ops.name_scope('infer_sizes'):
        if labels is not None:
          batch_size, list_size = array_ops.unstack(array_ops.shape(labels))
          is_valid = utils.is_label_valid(labels)
        else:
          # Infer batch_size and list_size from a feature.
          example_tensor_shape = array_ops.shape(
              next(six.itervalues(per_example_features)))
          batch_size = example_tensor_shape[0]
          list_size = example_tensor_shape[1]
          is_valid = utils.is_label_valid(
              array_ops.ones([batch_size, list_size]))
      if batch_size is None or list_size is None:
        raise ValueError(
            'Invalid batch_size=%s or list_size=%s' % (batch_size, list_size))

      # For each example feature, assume the shape is [batch_size, list_size,
      # feature_size], the groups are formed along the 2nd dim. Each group has a
      # 'group_size' number of indices in [0, list_size). Based on these
      # indices, we can gather the example feature into a sub-tensor for each
      # group. The total number of groups we have for a mini-batch is batch_size
      # * num_groups. Inside each group, we have a 'group_size' number of
      # examples.
      indices, mask = _form_group_indices_nd(
          is_valid, group_size,
          shuffle=(mode != model_fn.ModeKeys.PREDICT))
      num_groups = array_ops.shape(mask)[1]

      with ops.name_scope('group_features'):
        # For context features, we have shape [batch_size * num_groups, ...].
        large_batch_context_features = {}
        for name, value in six.iteritems(context_features):
          # [batch_size, 1, ...].
          value = array_ops.expand_dims(value, axis=1)
          # [batch_size, num_groups, ...].
          value = array_ops.gather(
              value, array_ops.zeros([num_groups], dtypes.int32), axis=1)
          # [batch_size * num_groups, ...]
          large_batch_context_features[name] = utils.reshape_first_ndims(
              value, 2, [batch_size * num_groups])

        # For example feature, we have shape [batch_size * num_groups,
        # group_size, ...].
        large_batch_group_features = {}
        for name, value in six.iteritems(per_example_features):
          # [batch_size, num_groups, group_size, ...].
          value = array_ops.gather_nd(value, indices)
          # [batch_size * num_groups, group_size, ...].
          large_batch_group_features[name] = utils.reshape_first_ndims(
              value, 3, [batch_size * num_groups, group_size])

      # Do the inference and get scores for the large batch.
      # [batch_size * num_groups, group_size].
      scores = _score_fn(
          large_batch_context_features, large_batch_group_features, reuse=False)

      with ops.name_scope('accumulate_scores'):
        scores = array_ops.reshape(scores, [batch_size, num_groups, group_size])
        # Reset invalid scores to 0 based on mask.
        scores = array_ops.where(
            array_ops.gather(
                array_ops.expand_dims(mask, 2),
                array_ops.zeros([group_size], dtypes.int32),
                axis=2), scores, array_ops.zeros_like(scores))
        # [batch_size, list_size].
        list_scores = array_ops.scatter_nd(indices, scores,
                                           [batch_size, list_size])
        # Use average.
        list_scores /= math_ops.to_float(group_size)

    if mode == model_fn.ModeKeys.PREDICT:
      return list_scores
    else:
      features.update(context_features)
      features.update(per_example_features)
      return list_scores
Example 36
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The reconstruct one or more matrices from their LU decomposition(s).

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
    with ops.name_scope(name or 'lu_reconstruct'):
        lower_upper = ops.convert_to_tensor(lower_upper,
                                            dtype_hint=dtypes.float32,
                                            name='lower_upper')
        perm = ops.convert_to_tensor(perm,
                                     dtype_hint=dtypes.int32,
                                     name='perm')

        assertions = lu_reconstruct_assertions(lower_upper, perm,
                                               validate_args)
        if assertions:
            with ops.control_dependencies(assertions):
                lower_upper = array_ops.identity(lower_upper)
                perm = array_ops.identity(perm)

        shape = array_ops.shape(lower_upper)

        lower = set_diag(band_part(lower_upper, num_lower=-1, num_upper=0),
                         array_ops.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = band_part(lower_upper, num_lower=0, num_upper=-1)
        x = math_ops.matmul(lower, upper)

        if (lower_upper.shape is None or lower_upper.shape.rank is None
                or lower_upper.shape.rank != 2):
            # We either don't know the batch rank or there are >0 batch dims.
            batch_size = math_ops.reduce_prod(shape[:-2])
            d = shape[-1]
            x = array_ops.reshape(x, [batch_size, d, d])
            perm = array_ops.reshape(perm, [batch_size, d])
            perm = map_fn.map_fn(array_ops.invert_permutation, perm)
            batch_indices = array_ops.broadcast_to(
                math_ops.range(batch_size)[:, array_ops.newaxis],
                [batch_size, d])
            x = array_ops.gather_nd(
                x, array_ops.stack([batch_indices, perm], axis=-1))
            x = array_ops.reshape(x, shape)
        else:
            x = array_ops.gather(x, array_ops.invert_permutation(perm))

        x.set_shape(lower_upper.shape)
        return x
Example 37
def _ScatterNdNonAliasingAddGrad(op, grad):
    indices = op.inputs[1]
    updates_grad = array_ops.gather_nd(grad, indices)
    return [grad, None, updates_grad]
Example 38
def scheduled_sampling_vocab_dist(hps,
                                  sampling_probability,
                                  output,
                                  embedding,
                                  inp,
                                  alpha=0):
    # borrowed ideas from https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper

    def soft_argmax(alpha, output):
        # alpha_exp = tf.exp(alpha * output) # (batch_size, vocab_size)
        # one_hot_scores = alpha_exp / tf.reshape(tf.reduce_sum(alpha_exp, axis=1),[-1,1]) #(batch_size, vocab_size)
        one_hot_scores = tf.nn.softmax(alpha * output)
        return one_hot_scores

    def soft_top_k(alpha, output, K):
        copy = tf.identity(output)
        p = []
        arg_top_k = []
        for k in range(K):
            sargmax = soft_argmax(alpha, copy)
            copy = (1 - sargmax) * copy
            p.append(tf.reduce_sum(sargmax * output, axis=1))
            arg_top_k.append(sargmax)

        return tf.stack(p, axis=1), tf.stack(arg_top_k)

    with variable_scope.variable_scope("ScheduledEmbedding"):
        # Return -1s where we did not sample, and sample_ids elsewhere
        select_sampler = bernoulli.Bernoulli(probs=sampling_probability,
                                             dtype=tf.bool)
        select_sample = select_sampler.sample(sample_shape=hps.batch_size)
        sample_id_sampler = categorical.Categorical(
            probs=output
        )  # equivalent to argmax{ Multinomial(output, total_count=1) }, our greedy search selection
        sample_ids = array_ops.where(select_sample,
                                     sample_id_sampler.sample(seed=123),
                                     gen_array_ops.fill([hps.batch_size], -1))

        where_sampling = math_ops.cast(array_ops.where(sample_ids > -1),
                                       tf.int32)
        where_not_sampling = math_ops.cast(array_ops.where(sample_ids <= -1),
                                           tf.int32)

        if hps.greedy_scheduled_sampling:
            sample_ids = tf.argmax(output, axis=1, output_type=tf.int32)

        sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(inp, where_not_sampling)

        if hps.E2EBackProp:
            if hps.hard_argmax:
                greedy_search_prob, greedy_search_sample = tf.nn.top_k(
                    output, k=hps.k)  # (batch_size, k)
                greedy_search_prob_normalized = greedy_search_prob / tf.reshape(
                    tf.reduce_sum(greedy_search_prob, axis=1), [-1, 1])
                greedy_embedding = tf.nn.embedding_lookup(
                    embedding, greedy_search_sample)
                normalized_embedding = tf.multiply(
                    tf.reshape(greedy_search_prob_normalized,
                               [hps.batch_size, hps.k, 1]), greedy_embedding)
                e2e_embedding = tf.reduce_mean(normalized_embedding, axis=1)
            else:
                e = []
                greedy_search_prob, greedy_search_sample = soft_top_k(
                    alpha, output,
                    K=hps.k)  # (batch_size, k), (k, batch_size, vocab_size)
                greedy_search_prob_normalized = greedy_search_prob / tf.reshape(
                    tf.reduce_sum(greedy_search_prob, axis=1), [-1, 1])

                for _ in range(hps.k):
                    a_k = greedy_search_sample[_]
                    e_k = tf.matmul(
                        tf.reshape(greedy_search_prob_normalized[:, _],
                                   [-1, 1]) * a_k, embedding)
                    e.append(e_k)
                e2e_embedding = tf.reduce_sum(e,
                                              axis=0)  # (batch_size, emb_dim)
            sampled_next_inputs = array_ops.gather_nd(e2e_embedding,
                                                      where_sampling)
        else:
            if hps.hard_argmax:
                sampled_next_inputs = tf.nn.embedding_lookup(
                    embedding, sample_ids_sampling)
            else:  # using soft argmax (greedy) proposed in: https://arxiv.org/abs/1704.06970
                # alpha_exp = tf.exp(alpha * (output_not_extended + G)) # (batch_size, vocab_size)
                # one_hot_scores = alpha_exp / tf.reduce_sum(alpha_exp, axis=1) #(batch_size, vocab_size)
                one_hot_scores = soft_argmax(
                    alpha, output)  # (batch_size, vocab_size)
                soft_argmax_embedding = tf.matmul(
                    one_hot_scores, embedding)  # (batch_size, emb_size)
                sampled_next_inputs = array_ops.gather_nd(
                    soft_argmax_embedding, where_sampling)

        base_shape = array_ops.shape(inp)
        result1 = array_ops.scatter_nd(indices=where_sampling,
                                       updates=sampled_next_inputs,
                                       shape=base_shape)
        result2 = array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape)
        return result1 + result2
def gather_tree_from_array(t, parent_ids, sequence_length):
    """Calculates the full beams for `TensorArray`s.

  Args:
    t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
      shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
      where `s` is the depth shape.
    parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
    sequence_length: The sequence length of shape `[batch_size, beam_width]`.

  Returns:
    A `Tensor` which is a stacked `TensorArray` of the same size and type as
    `t` and where beams are sorted in each `Tensor` according to `parent_ids`.
  """
    max_time = parent_ids.shape[0].value or array_ops.shape(parent_ids)[0]
    batch_size = parent_ids.shape[1].value or array_ops.shape(parent_ids)[1]
    beam_width = parent_ids.shape[2].value or array_ops.shape(parent_ids)[2]

    # Generate beam ids that will be reordered by gather_tree.
    beam_ids = array_ops.expand_dims(
        array_ops.expand_dims(math_ops.range(beam_width), 0), 0)
    beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1])

    mask = array_ops.sequence_mask(sequence_length,
                                   maxlen=max_time,
                                   dtype=dtypes.int32)
    mask = array_ops.transpose(mask, perm=[2, 0, 1])

    # Use beam_width + 1 to mark the end of beam.
    masked_beam_ids = (beam_ids * mask) + (1 - mask) * (beam_width + 1)

    max_sequence_lengths = math_ops.to_int32(
        math_ops.reduce_max(sequence_length, axis=1))
    sorted_beam_ids = beam_search_ops.gather_tree(
        step_ids=masked_beam_ids,
        parent_ids=parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=beam_width + 1)

    # For out of range steps, simply copy the same beam.
    sorted_beam_ids = array_ops.where(math_ops.cast(mask, dtypes.bool),
                                      x=sorted_beam_ids,
                                      y=beam_ids)

    # Generate indices for gather_nd.
    time_ind = array_ops.tile(
        array_ops.reshape(math_ops.range(max_time), [-1, 1, 1]),
        [1, batch_size, beam_width])
    batch_ind = array_ops.tile(
        array_ops.reshape(math_ops.range(batch_size), [-1, 1, 1]),
        [1, max_time, beam_width])
    batch_ind = array_ops.transpose(batch_ind, perm=[1, 0, 2])
    indices = array_ops.stack([time_ind, batch_ind, sorted_beam_ids], -1)

    # Gather from a tensor with collapsed additional dimensions.
    gather_from = t
    final_shape = array_ops.shape(gather_from)
    gather_from = array_ops.reshape(gather_from,
                                    [max_time, batch_size, beam_width, -1])
    ordered = array_ops.gather_nd(gather_from, indices)
    ordered = array_ops.reshape(ordered, final_shape)

    return ordered
 def _gather_states(self, data, indices, batch_size):
     """Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""
     return array_ops.gather_nd(
         data,
         array_ops.stack([indices, math_ops.range(batch_size)], axis=1))
Example 41
 def dense_to_sparse_non_scalar(tensor):
     indices = array_ops.where(
         array_ops.ones_like(tensor, dtype=dtypes.bool))
     values = array_ops.gather_nd(tensor, indices)
     shape = array_ops.shape(tensor, out_type=dtypes.int64)
     return sparse_tensor.SparseTensorValue(indices, values, shape)
Example 42
def fill_lower_triangular(x,
                          validate_args=False,
                          name="fill_lower_triangular"):
    """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Although the non-batch complexity is O(n^2), large constants and sub-optimal
  vectorization mean this function is roughly 5x slower than zeroing out the
  upper triangular, i.e., `tf.matrix_band_part(X, -1, 0)`.  This function
  becomes competitive only when several matmul/cholesky/etc ops can be elided
  in constructing the input.  Example: wiring a fully connected layer as a
  covariance matrix; this function reduces the final layer by 2x and possibly
  reduces the network architecture complexity considerably.  In most cases it
  is better to simply build a full matrix and zero out the upper triangular
  elements, e.g., `tril = tf.matrix_band_part(full, -1, 0)`, rather than
  directly construct a lower triangular matrix.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `distribution_util_test.py`, function `_fill_lower_triangular`.

  Args:
    x: `Tensor` representing lower triangular elements.
    validate_args: `Boolean`, default `False`.  Whether to ensure the shape of
      `x` can be mapped to a lower triangular matrix (controls non-static checks
      only).
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.

  Raises:
    ValueError: if `x` has a static shape which cannot be mapped to a
      lower triangular matrix.
  """
    # TODO(jvdillon): Replace this code with dedicated op when it exists.
    with ops.name_scope(name, values=(x, )):
        x = ops.convert_to_tensor(x, name="x")
        if (x.get_shape().ndims is not None
                and x.get_shape()[-1].value is not None):
            d = x.get_shape()[-1].value
            # d = n(n+1)/2 implies n is:
            n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
            d_inferred = n * (n + 1) / 2
            if d != d_inferred:
                raise ValueError(
                    "Input cannot be mapped to a lower triangular; "
                    "n*(n+1)/2 = %d != %d" % (d_inferred, d))
            final_shape = x.get_shape()[:-1].concatenate(
                tensor_shape.TensorShape([n, n]))
        else:
            d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
            # d = n(n+1)/2 implies n is:
            n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
                              dtype=dtypes.int32)
            if validate_args:
                is_valid_input_shape = check_ops.assert_equal(
                    n * (n + 1) / 2,
                    d,
                    message="Input cannot be mapped to a lower triangular.")
                n = control_flow_ops.with_dependencies([is_valid_input_shape],
                                                       n)
            final_shape = x.get_shape()[:-1].concatenate(
                tensor_shape.TensorShape([None, None]))

        def tril_ids(n):
            """Internal helper to create vector of linear indices into y."""
            # Build the ids statically; chose 512 because it implies 1MiB.
            if not contrib_framework.is_tensor(n) and n <= 512:
                ids = np.arange(n**2, dtype=np.int32)
                rows = (ids / n).astype(np.int32)  # Implicit floor.
                # We need to stop incrementing the index when we encounter
                # upper-triangular elements.  The idea here is to compute the
                # lower-right number of zeros then by "symmetry" subtract this from the
                # total number of zeros, n(n-1)/2.
                # Then we note that: n(n-1)/2 - (n-r)*(n-r-1)/2 = r(2n-r-1)/2
                offset = (rows * (2 * n - rows - 1) / 2).astype(np.int32)
                # We could also zero out when (rows < cols) == (rows < ids-n*rows).
                # mask = (ids <= (n + 1) * rows).astype(np.int32)
            else:
                ids = math_ops.range(n**2)
                rows = math_ops.cast(ids / n, dtype=dtypes.int32)
                offset = math_ops.cast(rows * (2 * n - rows - 1) / 2,
                                       dtype=dtypes.int32)
            return ids - offset

        # Special-case non-batch case.
        if x.get_shape().ndims == 1:
            y = array_ops.gather(x, array_ops.reshape(tril_ids(n), [n, n]))
            y = array_ops.matrix_band_part(y, -1, 0)
            y.set_shape(y.get_shape().merge_with(final_shape))
            return y

        # Make ids for each batch dim.
        if (x.get_shape().ndims is not None
                and x.get_shape()[:-1].is_fully_defined()):
            batch_shape = np.asarray(x.get_shape()[:-1].as_list(),
                                     dtype=np.int32)
            m = np.prod(batch_shape).astype(np.int32)
        else:
            batch_shape = array_ops.shape(x)[:-1]
            m = array_ops.reduce_prod(array_ops.shape(x)[:-1])
        batch_ids = math_ops.range(m)

        # Assemble the tril_ids into batch,tril_id pairs.
        idx = array_ops.stack([
            array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n * n]),
            array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1])
        ])
        idx = array_ops.transpose(idx, [1, 2, 0])

        # Gather up, reshape, and return.
        y = array_ops.reshape(x, [-1, d])
        y = array_ops.gather_nd(y, idx)
        y = array_ops.reshape(y, array_ops.concat_v2([batch_shape, [n, n]], 0))
        y = array_ops.matrix_band_part(y, -1, 0)
        y.set_shape(y.get_shape().merge_with(final_shape))
        return y
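A pure NumPy sketch of the same vector-to-lower-triangular mapping (written independently of the `_fill_lower_triangular` test helper the docstring mentions):

import numpy as np

def fill_lower_triangular_np(x):
  """NumPy sketch: map a [..., n(n+1)/2] vector to [..., n, n] lower-triangular."""
  x = np.asarray(x)
  d = x.shape[-1]
  n = int(0.5 * (np.sqrt(1. + 8. * d) - 1.))
  out = np.zeros(x.shape[:-1] + (n, n), dtype=x.dtype)
  rows, cols = np.tril_indices(n)
  out[..., rows, cols] = x  # row-major fill of the lower triangle
  return out

fill_lower_triangular_np([1, 2, 3, 4, 5, 6])
# [[1, 0, 0],
#  [2, 3, 0],
#  [4, 5, 6]]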
Example 43
def _list_mle_loss(labels,
                   logits,
                   weights=None,
                   lambda_weight=None,
                   reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                   name=None,
                   seed=None):
    """Computes the ListMLE loss [Xia et al.

  2008] for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the ListMLE loss for the given list.

  The `lambda_weight` re-weights examples based on l_i and r_i.
  The recommended weighting scheme is the formulation presented in the
  "Position-Aware ListMLE" paper (Lan et. al) and available using
  create_p_list_mle_lambda_weight() factory function above.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `DCGLambdaWeight` instance.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    seed: A randomization seed used when shuffling ground truth permutations.

  Returns:
    An op for the ListMLE loss.
  """
    with ops.name_scope(name, 'list_mle_loss', (labels, logits, weights)):
        is_label_valid = utils.is_label_valid(labels)
        # Reset the invalid labels to 0 and reset the invalid logits to a logit with
        # ~= 0 contribution.
        labels = array_ops.where(is_label_valid, labels,
                                 array_ops.zeros_like(labels))
        logits = array_ops.where(
            is_label_valid, logits,
            math_ops.log(_EPSILON) * array_ops.ones_like(logits))
        weights = 1.0 if weights is None else ops.convert_to_tensor(weights)
        weights = array_ops.squeeze(weights)

        # Shuffle labels and logits to add randomness to sort.
        shuffled_indices = utils.shuffle_valid_indices(is_label_valid, seed)
        shuffled_labels = array_ops.gather_nd(labels, shuffled_indices)
        shuffled_logits = array_ops.gather_nd(logits, shuffled_indices)

        sorted_labels, sorted_logits = utils.sort_by_scores(
            shuffled_labels, [shuffled_labels, shuffled_logits])

        raw_max = math_ops.reduce_max(sorted_logits, axis=1, keepdims=True)
        sorted_logits = sorted_logits - raw_max
        sums = math_ops.cumsum(math_ops.exp(sorted_logits),
                               axis=1,
                               reverse=True)
        sums = math_ops.log(sums) - sorted_logits

        if lambda_weight is not None and isinstance(lambda_weight,
                                                    ListMLELambdaWeight):
            sums *= lambda_weight.individual_weights(sorted_labels)

        negative_log_likelihood = math_ops.reduce_sum(sums, 1)

        return core_losses.compute_weighted_loss(negative_log_likelihood,
                                                 weights=weights,
                                                 reduction=reduction)
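The core of the loss above is a reverse cumulative log-sum-exp over scores sorted by relevance; a small NumPy sketch with made-up labels and logits:

import numpy as np

labels = np.array([2.0, 0.0, 1.0])
logits = np.array([1.5, 0.1, 0.7])
order = np.argsort(-labels)          # sort items by decreasing relevance
s = logits[order]
s = s - s.max()                      # the raw_max shift above, for stability
# Reverse cumulative sum of exp(s): the normalizer over each suffix.
sums = np.log(np.cumsum(np.exp(s[::-1]))[::-1]) - s
negative_log_likelihood = sums.sum()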
Example 44
def _TensorScatterSubGrad(op, grad):
    indices = op.inputs[1]
    updates_grad = array_ops.gather_nd(grad, indices)
    tensor_grad = array_ops.identity(grad)
    return [tensor_grad, None, -updates_grad]
Example 45
def tridiagonal_solve(diagonals,
                      rhs,
                      diagonals_format='compact',
                      transpose_rhs=False,
                      conjugate_rhs=False,
                      name=None):
    r"""Solves tridiagonal systems of equations.

  The solution is computed via Gaussian elimination with partial pivoting.

  The input can be supplied in various formats: `matrix`, `sequence` and `compact`,
  specified by the `diagonals_format` arg.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` are supplied as a tuple or list of three
  tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing
  superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either
  `M-1` or `M`; in the latter case, the last element of superdiagonal and the
  first element of subdiagonal will be ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `compact` format is recommended as the one with the best performance. If
  you need to convert a tensor into the compact format manually, use
  `tf.gather_nd`. An example for a tensor of shape `[m, m]`:

  ```python
  rhs = tf.constant([...])
  matrix = tf.constant([[...]])
  m = matrix.shape[0]
  dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
  indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
             [[i, i] for i in range(m)],                        # Diagonal
             [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]  # Subdiagonal
  diagonals = tf.gather_nd(matrix, indices)
  x = tf.linalg.tridiagonal_solve(diagonals, rhs)
  ```

  Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]` or
  `[..., M, K]`. The latter allows solving K systems simultaneously with the
  same left-hand sides and K different right-hand sides. If `transpose_rhs`
  is set to `True` the expected shape is `[..., M]` or `[..., K, M]`.

  The batch dimensions, denoted as `...`, must be the same in `diagonals` and
  `rhs`.

  The output is a tensor of the same shape as `rhs`: either `[..., M]` or
  `[..., M, K]`.

  The op isn't guaranteed to raise an error if the input matrix is not
  invertible. `tf.debugging.check_numerics` can be applied to the output to
  detect invertibility problems.

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as
      `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect
      if the shape of rhs is [..., M]).
    conjugate_rhs: If `True`, `rhs` is conjugated before solving.
    name:  A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, K] containing the solutions.

  Raises:
    ValueError: If an unsupported type is provided as input, or if the input
      tensors have incorrect shapes.

  """
    if diagonals_format == 'compact':
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs, name)

    if diagonals_format == 'sequence':
        if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3:
            raise ValueError(
                'Expected diagonals to be a sequence of length 3.')

        superdiag, maindiag, subdiag = diagonals
        if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])
                or not superdiag.shape[:-1].is_compatible_with(
                    maindiag.shape[:-1])):
            raise ValueError(
                'Tensors representing the three diagonals must have the same shape, '
                'except for the last dimension, got {}, {}, {}'.format(
                    subdiag.shape, maindiag.shape, superdiag.shape))

        m = tensor_shape.dimension_value(maindiag.shape[-1])

        def pad_if_necessary(t, name, last_dim_padding):
            n = tensor_shape.dimension_value(t.shape[-1])
            if not n or n == m:
                return t
            if n == m - 1:
                paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] +
                            [last_dim_padding])
                return array_ops.pad(t, paddings)
            raise ValueError(
                'Expected {} to have length {} or {}, got {}.'.format(
                    name, m, m - 1, n))

        subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0])
        superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1])

        diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2)
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs, name)

    if diagonals_format == 'matrix':
        m1 = tensor_shape.dimension_value(diagonals.shape[-1])
        m2 = tensor_shape.dimension_value(diagonals.shape[-2])
        if m1 and m2 and m1 != m2:
            raise ValueError(
                'Expected the last two dimensions of diagonals to be the same, '
                'got {} and {}'.format(m1, m2))
        m = m1 or m2
        if not m:
            raise ValueError('The size of the matrix needs to be known for '
                             'diagonals_format="matrix"')

        # Extract diagonals; use input[..., 0, 0] as "dummy" m-th elements of sub-
        # and superdiagonal.
        # gather_nd slices into first indices, whereas we need to slice into the
        # last two, so transposing back and forth is necessary.
        dummy_idx = [0, 0]
        indices = ([[[1, 0], [0, 0], dummy_idx]] +
                   [[[i + 1, i], [i, i], [i - 1, i]]
                    for i in range(1, m - 1)] +
                   [[dummy_idx, [m - 1, m - 1], [m - 2, m - 1]]])
        diagonals = array_ops.transpose(
            array_ops.gather_nd(array_ops.transpose(diagonals), indices))
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs, name)

    raise ValueError(
        'Unrecognized diagonals_format: {}'.format(diagonals_format))
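# Hedged usage sketch (illustrative, not from the docs above): solve a small
# system in the `compact` format and cross-check against a dense solve.
# Assumes the public `tf.linalg` API.
import tensorflow as tf

superdiag = tf.constant([-1.0, -1.0, 0.0], dtype=tf.float64)  # last entry ignored
maindiag = tf.constant([2.0, 2.0, 2.0], dtype=tf.float64)
subdiag = tf.constant([0.0, -1.0, -1.0], dtype=tf.float64)    # first entry ignored
diagonals = tf.stack([superdiag, maindiag, subdiag], axis=0)  # shape [3, M]
rhs = tf.constant([1.0, 1.0, 1.0], dtype=tf.float64)

x = tf.linalg.tridiagonal_solve(diagonals, rhs)               # shape [M]

# The same system as a dense matrix, for comparison.
dense = tf.constant([[2.0, -1.0, 0.0],
                     [-1.0, 2.0, -1.0],
                     [0.0, -1.0, 2.0]], dtype=tf.float64)
x_dense = tf.linalg.solve(dense, rhs[:, tf.newaxis])[:, 0]    # should match x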
# Example 46
def gather_nd(params, indices, batch_dims=0, name=None):
    """Gather slices from `params` using `n`-dimensional indices.

  This operation is similar to `gather`, but it uses the innermost dimension
  of `indices` to define a slice into `params`.  In particular, if:

  * `indices` has shape `[A1...AN, I]`
  * `params` has shape `[B1...BM]`

  Then:

  * `result` has shape `[A1...AN, B_{I+1}...BM]`.
  * `result[a1...aN] = params[indices[a1...aN, :]]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BM]`.
    indices: A potentially ragged tensor with shape `[A1...AN, I]`.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

  #### Examples:
    ```python
    >>> params = tf.compat.v1.ragged.constant_value(
    ...     [ [ ['000', '001'], ['010'              ]          ],
    ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
    ...       [ [            ], ['210'              ]          ] ])

    >>> # Gather 2D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2], [0]])
    [ [ [            ], ['210'] ]
      [ ['000', '001'], ['010'] ] ]

    >>> # Gather 1D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2, 1], [0, 0]])
    [['210'], ['000', '001']]

    >>> # Gather scalars from a 3D tensor
    >>> ragged.gather_nd(params, [[0, 0, 1], [1, 1, 2]])
    ['001', '112']
    ```
  """
    if not isinstance(batch_dims, int) or batch_dims != 0:
        raise ValueError(
            'batch_dims != 0 is not supported for ragged gather yet.')
    if not (ragged_tensor.is_ragged(params)
            or ragged_tensor.is_ragged(indices)):
        return array_ops.gather_nd(params, indices, name)

    with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

        params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            params, name='params')
        indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            indices, name='indices')
        params, indices = ragged_tensor.match_row_splits_dtypes(
            params, indices)
        indices_shape = indices.shape
        indices_ndims = indices_shape.ndims
        if indices_ndims is None:
            raise ValueError('indices.rank must be statically known.')
        if indices_ndims == 0:
            raise ValueError('indices.rank must be at least 1.')
        if (ragged_tensor.is_ragged(indices)
                and indices_ndims == indices.ragged_rank + 1):
            raise ValueError(
                'The innermost dimension of indices may not be ragged')

        # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
        # that each index slices into.
        index_size = tensor_shape.dimension_value(indices_shape[-1])
        if index_size is None:
            raise ValueError('indices.shape[-1] must be statically known.')

        # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
        # dense, then we convert it to ragged before recursing, and then convert
        # the result back to `dense` if appropriate.
        if indices_ndims > 2:
            indices_is_dense = not ragged_tensor.is_ragged(indices)
            if indices_is_dense:
                indices = ragged_tensor.RaggedTensor.from_tensor(
                    indices,
                    ragged_rank=indices_ndims - 2,
                    row_splits_dtype=params.row_splits.dtype)
            result = indices.with_flat_values(
                gather_nd(params, indices.flat_values))
            if (indices_is_dense and ragged_tensor.is_ragged(result)
                    and result.ragged_rank == indices_ndims - 2):
                result = ragged_tensor.RaggedTensor.to_tensor(result)
            return result

        # indices_ndims <= 2, and the innermost dimension of indices may not be
        # ragged, so `indices` must not be ragged.
        assert not ragged_tensor.is_ragged(indices)
        assert ragged_tensor.is_ragged(params)

        # Handle corner case: An empty index tuple selects the entire `params`
        # value.  So if `index_size` is zero, then tile `params`.
        if index_size == 0:
            params_ndims = params.ragged_rank + array_ops.rank(
                params.flat_values)
            for dim in range(indices_ndims - 1):
                params = ragged_array_ops.expand_dims(params, axis=0)
            multiples = array_ops.concat([
                array_ops.shape(indices)[:-1],
                array_ops.ones([params_ndims], dtypes.int32)
            ], axis=0)
            return ragged_array_ops.tile(params, multiples)

        # When index_size=1, we can just flatten the index tuples and use gather.
        elif index_size == 1:
            flattened_index_tuples = array_ops.reshape(indices, [-1])
            return gather(params, flattened_index_tuples)

        # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor.
        # Flatten both the index tuples and the params, such that the flattened
        # index tuples point to the correct values in the flattened params; and
        # then use ragged.gather on the flattened index tuples & params.
        else:
            indices = math_ops.cast(indices, params.row_splits.dtype)

            # Flatten the outermost 2 dimensions of the index tuples & params.
            flattened_index_tuples = array_ops.gather(params.row_splits,
                                                      indices[..., 0])
            flattened_index_tuples += indices[..., 1]
            flattened_params = params.values

            # Flatten any remaining dimensions.
            for dim in range(2, index_size):
                if not ragged_tensor.is_ragged(flattened_params):
                    flattened_index_tuples = array_ops.expand_dims(
                        flattened_index_tuples, axis=1)
                    flattened_index_tuples = array_ops.concat(
                        [flattened_index_tuples, indices[..., dim:]], axis=1)
                    return array_ops.gather_nd(flattened_params,
                                               flattened_index_tuples)

                flattened_index_tuples = array_ops.gather(
                    flattened_params.row_starts(), flattened_index_tuples)
                flattened_index_tuples += indices[..., dim]
                flattened_params = flattened_params.values

            # Gather using the flattened index tuples and params.
            return gather(flattened_params, flattened_index_tuples)
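# Hedged sketch (not part of the function above) of the flattening trick used
# for 2-D index tuples into a ragged_rank-1 RaggedTensor: the flat position of
# params[i, j] is row_splits[i] + j, so the 2-D gather_nd reduces to a 1-D
# gather on flat_values. Assumes the public `tf.ragged` / `tf.RaggedTensor` API.
import tensorflow as tf

params = tf.ragged.constant([['a', 'b'], ['c'], ['d', 'e', 'f']])
indices = tf.constant([[0, 1], [2, 0]])  # selects params[0, 1] and params[2, 0]

flat_positions = tf.gather(params.row_splits, indices[:, 0]) + tf.cast(
    indices[:, 1], params.row_splits.dtype)
result = tf.gather(params.flat_values, flat_positions)  # [b'b', b'd']
# Should agree with tf.gather_nd(params, indices).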
# Example 47
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
    r"""Multiplies tridiagonal matrix by matrix.

  `diagonals` is a representation of a tridiagonal NxN matrix, whose form
  depends on `diagonals_format`.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` is a list or tuple of three tensors:
  `[superdiag, maindiag, subdiag]`, each having shape `[..., M]`. The last
  element of `superdiag` and the first element of `subdiag` are ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `sequence` format is recommended as the one with the best performance.

  `rhs` is the matrix on the right-hand side of the multiplication. It has
  shape `[..., M, N]`.

  Example:

  ```python
  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
  diagonals = [superdiag, maindiag, subdiag]
  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
  ```

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    name:  A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M, N] containing the result of multiplication.

  Raises:
    ValueError: If an unsupported type is provided as input, or if the input
      tensors have incorrect shapes.
  """
    if diagonals_format == 'compact':
        superdiag = diagonals[..., 0, :]
        maindiag = diagonals[..., 1, :]
        subdiag = diagonals[..., 2, :]
    elif diagonals_format == 'sequence':
        superdiag, maindiag, subdiag = diagonals
    elif diagonals_format == 'matrix':
        m1 = tensor_shape.dimension_value(diagonals.shape[-1])
        m2 = tensor_shape.dimension_value(diagonals.shape[-2])
        if not m1 or not m2:
            raise ValueError('The size of the matrix needs to be known for '
                             'diagonals_format="matrix"')
        if m1 != m2:
            raise ValueError(
                'Expected the last two dimensions of diagonals to be the same, '
                'got {} and {}'.format(m1, m2))

        # TODO(b/131695260): use matrix_diag_part when it supports extracting
        # arbitrary diagonals.
        maindiag = array_ops.matrix_diag_part(diagonals)
        diagonals = array_ops.transpose(diagonals)
        dummy_index = [0, 0]
        superdiag_indices = [[i + 1, i]
                             for i in range(0, m1 - 1)] + [dummy_index]
        subdiag_indices = [dummy_index] + [[i - 1, i] for i in range(1, m1)]
        superdiag = array_ops.transpose(
            array_ops.gather_nd(diagonals, superdiag_indices))
        subdiag = array_ops.transpose(
            array_ops.gather_nd(diagonals, subdiag_indices))
    else:
        raise ValueError('Unrecognized diagonals_format: %s' %
                         diagonals_format)

    # C++ backend requires matrices.
    # Converting 1-dimensional vectors to matrices with 1 row.
    superdiag = array_ops.expand_dims(superdiag, -2)
    maindiag = array_ops.expand_dims(maindiag, -2)
    subdiag = array_ops.expand_dims(subdiag, -2)

    return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs,
                                          name)
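# Hedged sketch (illustrative only) of the diagonal-extraction trick used in
# the `matrix` branch above: transpose so gather_nd slices the matrix
# dimensions first, gather [row, col] pairs padded with a dummy index, then
# transpose back. Assumes the public `tf` API; shown without batch dimensions.
import tensorflow as tf

matrix = tf.constant([[2.0, -1.0, 0.0],
                      [-1.0, 2.0, -1.0],
                      [0.0, -1.0, 2.0]])
m = 3
dummy = [0, 0]
superdiag_idx = [[i + 1, i] for i in range(m - 1)] + [dummy]
subdiag_idx = [dummy] + [[i - 1, i] for i in range(1, m)]

t = tf.transpose(matrix)
superdiag = tf.transpose(tf.gather_nd(t, superdiag_idx))  # [-1., -1., 2.]; last ignored
subdiag = tf.transpose(tf.gather_nd(t, subdiag_idx))      # [2., -1., -1.]; first ignored
maindiag = tf.linalg.diag_part(matrix)                    # [2., 2., 2.]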
# Example 48
def fill_lower_triangular(x, name="fill_lower_triangular"):
    """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Note: This function is very slow; possibly 10x slower than zeroing out the
  upper-triangular portion of a full matrix.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  Args:
    x: `Tensor` representing lower triangular elements.
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.
  """
    with ops.name_scope(name, values=(x, )):
        x = ops.convert_to_tensor(x, name="x")
        ndims = x.get_shape().ndims
        if ndims is not None and x.get_shape()[-1].value is not None:
            d = x.get_shape()[-1].value
            # d = n^2/2 + n/2 implies n is:
            n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
            final_shape = x.get_shape()[:-1].concatenate(
                tensor_shape.TensorShape([n, n]))
        else:
            ndims = array_ops.rank(x)
            # Keep `d` integer-valued so it can be reused in the reshape below.
            d = array_ops.shape(x)[-1]
            # d = n^2/2 + n/2 implies n is:
            n = math_ops.cast(
                0.5 * (math_ops.sqrt(
                    1. + 8. * math_ops.cast(d, dtypes.float32)) - 1.),
                dtype=dtypes.int32)
            final_shape = x.get_shape()[:-1].concatenate(
                tensor_shape.TensorShape([None, None]))

        # Compute the batch shape and the total number of batch matrices.
        if (x.get_shape().ndims is not None
                and x.get_shape()[:-1].is_fully_defined()):
            batch_shape = np.asarray(x.get_shape()[:-1].as_list(),
                                     dtype=np.int32)
            m = np.prod(batch_shape)
        else:
            batch_shape = array_ops.shape(x)[:-1]
            m = math_ops.reduce_prod(batch_shape)

        # Flatten batch dims.
        y = array_ops.reshape(x, [-1, d])

        # Prepend a zero to each row.
        y = array_ops.pad(y, paddings=[[0, 0], [1, 0]])

        # Make ids for each batch dim (`m` was computed above).
        batch_ids = math_ops.range(m)

        def make_tril_ids(n):
            """Internal helper to create vector of linear indices into y."""
            cols = array_ops.reshape(array_ops.tile(math_ops.range(n), [n]),
                                     [n, n])
            rows = array_ops.tile(array_ops.expand_dims(math_ops.range(n), -1),
                                  [1, n])
            pred = math_ops.greater(cols, rows)
            tril_ids = array_ops.tile(
                array_ops.reshape(math_ops.cumsum(math_ops.range(n)), [n, 1]),
                [1, n]) + cols
            tril_ids = math_ops.select(
                pred, array_ops.zeros([n, n], dtype=dtypes.int32),
                tril_ids + 1)
            tril_ids = array_ops.reshape(tril_ids, [-1])
            return tril_ids

        tril_ids = make_tril_ids(n)

        # Assemble the ids into pairs.
        idx = array_ops.pack([
            array_ops.tile(array_ops.expand_dims(batch_ids, -1), [1, n * n]),
            array_ops.tile([tril_ids], [m, 1])
        ])
        idx = array_ops.transpose(idx, [1, 2, 0])

        y = array_ops.gather_nd(y, idx)
        y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))

        y.set_shape(y.get_shape().merge_with(final_shape))

        return y
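# Hedged sketch (illustrative only) of the size arithmetic above, with numpy as
# an independent reference: d = n(n+1)/2 lower-triangular entries imply
# n = (sqrt(1 + 8d) - 1) / 2, and a row-major fill reproduces the docstring example.
import math
import numpy as np

x = np.array([1., 2., 3., 4., 5., 6.])
d = x.shape[-1]                               # 6
n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))  # 3

expected = np.zeros((n, n))
expected[np.tril_indices(n)] = x              # fills rows left-to-right, top-to-bottom
# expected == [[1, 0, 0],
#              [2, 3, 0],
#              [4, 5, 6]]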