Example #1
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input of
    TopK, and the second being the gradient w.r.t. the indices (all zero).
  """
  in_shape = array_ops.shape(op.inputs[0])
  ind_shape = array_ops.shape(op.outputs[1])

  ind_lastdim = array_ops.gather(ind_shape, array_ops.size(ind_shape) - 1)
  # Flatten indices to 2D.
  ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))

  in_lastdim = array_ops.gather(in_shape, array_ops.size(in_shape) - 1)
  outerdim = array_ops.shape(ind_2d)[0]
  # Compute linear indices (flattened to 1D).
  ind = array_ops.reshape(ind_2d + array_ops.expand_dims(
      math_ops.range(0, outerdim * in_lastdim, in_lastdim), -1), [-1])

  # Substitute grad to appropriate locations and fill the rest with zeros,
  # finally reshaping it to the original input shape.
  return [array_ops.reshape(
      sparse_ops.sparse_to_dense(ind,
                                 array_ops.reshape(
                                     math_ops.reduce_prod(in_shape), [1]),
                                 array_ops.reshape(grad, [-1]),
                                 validate_indices=False),
      in_shape), array_ops.zeros(
          [], dtype=dtypes.int32)]
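To make the index arithmetic above concrete, here is a minimal NumPy sketch (shapes and values are hypothetical) of how the 2-D TopK indices are offset per row and flattened into indices of the flattened input, which is where `sparse_to_dense` then scatters `grad`:

```python
import numpy as np

# Hypothetical input of shape (2, 4) with k=2 TopK indices per row.
in_lastdim = 4
ind_2d = np.array([[1, 3],
                   [0, 2]])                     # op.outputs[1] reshaped to 2-D
outerdim = ind_2d.shape[0]

# Offset row i by i * in_lastdim, then flatten to 1-D linear indices.
offsets = np.arange(0, outerdim * in_lastdim, in_lastdim)[:, None]
ind = (ind_2d + offsets).reshape(-1)
print(ind)  # [1 3 4 6] -- positions of the top-k entries in the flat input
```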
Example #2
 def testInstantError(self):
   if context.num_gpus():
     # TODO(nareshmodi): make this test better
     self.skipTest("Gather doesn't do index checking on GPUs")
   with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                r'indices = 7 is not in \[0, 3\)'):
     array_ops.gather([0, 1, 2], 7)
Example #3
  def quantiles_ready():
    """The subgraph for when the quantiles are ready."""
    quantized_feature = quantile_ops.quantiles([sparse_column_values], [],
                                               [quantile_buckets], [])
    quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64)
    quantized_feature = array_ops.reshape(quantized_feature, [-1])
    example_indices, _ = array_ops.split(
        sparse_column_indices, num_or_size_splits=2, axis=1)
    example_indices = array_ops.squeeze(example_indices, [1])
    filtered_gradients = array_ops.gather(gradients, example_indices)
    filtered_hessians = array_ops.gather(hessians, example_indices)
    filtered_partition_ids = array_ops.gather(example_partition_ids,
                                              example_indices)
    unique_partitions, mapped_partitions = array_ops.unique(
        example_partition_ids)

    # Compute aggregate stats for each partition.
    per_partition_gradients = math_ops.unsorted_segment_sum(
        gradients, mapped_partitions, array_ops.size(unique_partitions))
    per_partition_hessians = math_ops.unsorted_segment_sum(
        hessians, mapped_partitions, array_ops.size(unique_partitions))

    # Prepend a bias feature per partition that accumulates the stats for all
    # examples in that partition.
    bias_feature_ids = array_ops.fill(
        array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
    bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
    partition_ids = array_ops.concat(
        [unique_partitions, filtered_partition_ids], 0)
    filtered_gradients = array_ops.concat(
        [per_partition_gradients, filtered_gradients], 0)
    filtered_hessians = array_ops.concat(
        [per_partition_hessians, filtered_hessians], 0)
    bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0)
    return partition_ids, bucket_ids, filtered_gradients, filtered_hessians
Example #4
 def _apply_sparse_shared(self, grad, var, indices,
                          scatter_add, scatter_update):
   beta1_power = self._get_beta_accumulators()
   beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
   lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
   beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
   beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
   epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
   # m_t = beta1 * m + (1 - beta1) * g_t
   m = self.get_slot(var, "m")
   m_slice = array_ops.gather(m, indices)
   m_t_slice = m_slice * beta1_t + grad * (1 - beta1_t)
   with ops.control_dependencies([m_t_slice]):
     m_t = scatter_update(m, indices, m_t_slice)
   # u_t = max(beta2 * u, abs(g_t))
   v = self.get_slot(var, "v")
   v_slice = array_ops.gather(v, indices)
   v_t_slice = math_ops.maximum(v_slice * beta2_t, math_ops.abs(grad))
   with ops.control_dependencies([v_t_slice]):
     v_t = scatter_update(v, indices, v_t_slice)
   # theta_t = theta - lr / (1 - beta1^t) * m_t / u_t
   var_slice = -lr_t / (1 - beta1_power) * (m_t_slice /
                                            (v_t_slice + epsilon_t))
   with ops.control_dependencies([var_slice]):
     var_update = scatter_add(var, indices, var_slice)
   return control_flow_ops.group(*[var_update, m_t, v_t])
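The comments above spell out an AdaMax-style update. A one-coordinate sketch of a single step (all constants hypothetical) shows how the max-based second moment replaces Adam's squared average:

```python
beta1, beta2, lr, eps = 0.9, 0.999, 0.001, 1e-8
t, m, u, g = 1, 0.0, 0.0, 0.5                 # step count, slot values, gradient

m = beta1 * m + (1 - beta1) * g               # m_t = beta1 * m + (1 - beta1) * g_t
u = max(beta2 * u, abs(g))                    # u_t = max(beta2 * u, |g_t|)
delta = -lr / (1 - beta1**t) * m / (u + eps)  # the slice passed to scatter_add
```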
Example #5
    def _apply_sparse(self, grad, var):
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)

        # The following equations are given in [1].
        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_t = state_ops.scatter_update(m, grad.indices,
                                       beta1_t * array_ops.gather(m, grad.indices) +
                                       (1. - beta1_t) * grad.values,
                                       use_locking=self._use_locking)
        m_t_slice = array_ops.gather(m_t, grad.indices)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_t = state_ops.scatter_update(v, grad.indices,
                                       beta2_t * array_ops.gather(v, grad.indices) +
                                       (1. - beta2_t) * math_ops.square(grad.values),
                                       use_locking=self._use_locking)
        v_prime = self.get_slot(var, "v_prime")
        v_t_slice = array_ops.gather(v_t, grad.indices)
        v_prime_slice = array_ops.gather(v_prime, grad.indices)
        v_t_prime = state_ops.scatter_update(v_prime, grad.indices,
                                             math_ops.maximum(v_prime_slice, v_t_slice))

        v_t_prime_slice = array_ops.gather(v_t_prime, grad.indices)
        var_update = state_ops.scatter_sub(var, grad.indices,
                                           lr_t * m_t_slice / (math_ops.sqrt(v_t_prime_slice) + epsilon_t),
                                           use_locking=self._use_locking)

        return control_flow_ops.group(*[var_update, m_t, v_t, v_t_prime])
Example #6
 def f(params):
   index_values = [1, 3, 5, 6]
   indices = constant_op.constant(index_values, name="i")
   y = array_ops.gather(params, indices, name="y")
   index_values2 = [0, 2]
   indices2 = constant_op.constant(index_values2, name="i2")
   return array_ops.gather(y, indices2, name="y2")
Example #7
 def _sparse_moving_average(self, x_tm1, idxs, b_t_, name, beta=.9):
   """
   Creates a moving average for a sparse variable.
   Inputs:
     x_tm1: the associated parameter (e.g. a weight matrix)
     idxs: the tensor representing the indices used
     b_t_: the value to accumulate (e.g. slices of the gradient)
     name: a string to use to retrieve it later (e.g. 'm')
     beta: the decay factor (defaults to .9)
   Outputs:
     a_t: the average after moving (same shape as x_tm1, not b_t_)
     t: the internal timestep (used to correct initialization bias)
   """
   
   a_tm1 = self._zeros_slot(x_tm1, '%s' % name, self._name)
   a_tm1_ = array_ops.gather(a_tm1, idxs)
   tm1 = self._zeros_idx_slot(x_tm1, '%s/tm1' % name, self._name)
   tm1_ = array_ops.gather(tm1, idxs)
   t = state_ops.scatter_add(tm1, idxs, tm1_*0+1, use_locking=self._use_locking)
   t_ = array_ops.gather(t, idxs)
   if beta < 1:
     beta_t = ops.convert_to_tensor(beta, name='%s/decay' % name)
     beta_t_ = beta_t * (1-beta_t**tm1_) / (1-beta_t**t_)
   else:
     beta_t_ = tm1_/t_
   a_t = state_ops.scatter_update(a_tm1, idxs, beta_t_*a_tm1_, use_locking=self._use_locking)
   a_t = state_ops.scatter_add(a_t, idxs, (1-beta_t_)*b_t_, use_locking=self._use_locking)
   return a_t, t
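The `beta_t_` rescaling above is the bias-corrected exponential moving average written as a recurrence: with `beta_t_ = beta * (1 - beta**(t-1)) / (1 - beta**t)`, the update `a_t = beta_t_ * a_tm1 + (1 - beta_t_) * b_t` keeps `a_t` equal to `ema_t / (1 - beta**t)`. A quick NumPy check with hypothetical values:

```python
import numpy as np

beta, stream = 0.9, [1.0, 2.0, 0.5, 3.0]        # decay and incoming values
a, ema = 0.0, 0.0
for t, b_t in enumerate(stream, start=1):
    beta_t_ = beta * (1 - beta**(t - 1)) / (1 - beta**t)
    a = beta_t_ * a + (1 - beta_t_) * b_t       # the scatter_update/scatter_add pair
    ema = beta * ema + (1 - beta) * b_t         # plain exponential moving average
    assert np.isclose(a, ema / (1 - beta**t))   # a is the debiased average
```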
Example #8
  def _resource_apply_sparse(self, grad, var, indices):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

    # \\(m := beta1 * m + (1 - beta1) * g_t\\)
    m = self.get_slot(var, "m")
    m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad
    m_update_op = resource_variable_ops.resource_scatter_update(m.handle,
                                                                indices,
                                                                m_t_slice)

    # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
    v = self.get_slot(var, "v")
    v_t_slice = (beta2_t * array_ops.gather(v, indices) +
                 (1 - beta2_t) * math_ops.square(grad))
    v_update_op = resource_variable_ops.resource_scatter_update(v.handle,
                                                                indices,
                                                                v_t_slice)

    # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
    var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
    var_update_op = resource_variable_ops.resource_scatter_sub(var.handle,
                                                               indices,
                                                               var_slice)

    return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
Example #9
 def loop_fn(i):
   x1 = array_ops.gather(x, i)
   y1 = array_ops.gather(y, i)
   outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
   del output_dtypes[:]
   output_dtypes.extend([t.dtype for t in outputs])
   return outputs
Example #10
 def _apply_sparse(self, g_t, x_tm1, prepare):
   """"""
   
   idxs, idxs_ = array_ops.unique(g_t.indices)
   g_t_ = math_ops.unsorted_segment_sum(g_t.values, idxs_, array_ops.size(idxs))
   updates = []
   
   if self._mu > 0:
     m_and_t = self._sparse_moving_average(x_tm1, idxs, g_t_, 'm', self._mu)
     m_t_ = array_ops.gather(m_and_t[0], idxs)
     gamma_t = ops.convert_to_tensor(self._gamma)
     m_bar_t_ = (1-gamma_t)*m_t_ + gamma_t*g_t_
     updates.extend(m_and_t)
   else:
     m_bar_t_ = g_t_
   
   if self._ups > 0:
     v_and_t = self._sparse_moving_average(x_tm1, idxs, g_t_**2, 'v', self._ups)
     v_t_ = array_ops.gather(v_and_t[0], idxs)
     eps_t = ops.convert_to_tensor(self._eps)
     v_bar_t_ = math_ops.sqrt(v_t_ + eps_t)
     updates.extend(v_and_t)
   else:
     v_bar_t_ = 1.
   
   lr_t = ops.convert_to_tensor(self._lr)
   s_t_ = lr_t * m_bar_t_ / v_bar_t_
   return [[s_t_, x_tm1, idxs, g_t]] + updates
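The `unique` plus `unsorted_segment_sum` prologue above collapses duplicate indices in the sparse gradient before any slot is touched. A standalone sketch of just that step, using the public `tf` API with hypothetical values:

```python
import tensorflow as tf

indices = tf.constant([3, 1, 3])             # g_t.indices, with a duplicate
values = tf.constant([1.0, 2.0, 4.0])        # g_t.values
idxs, idxs_ = tf.unique(indices)             # idxs == [3, 1], idxs_ == [0, 1, 0]
g_t_ = tf.math.unsorted_segment_sum(values, idxs_, tf.size(idxs))
# g_t_ == [5.0, 2.0]: both updates aimed at row 3 are summed before scattering.
```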
Example #11
 def loop_fn(i):
   image = array_ops.gather(images, i)
   label = array_ops.gather(labels, i)
   logits = array_ops.reshape(model(image, training=training), [-1])
   loss = losses.softmax_cross_entropy(
       logits=logits, onehot_labels=label, reduction=losses.Reduction.NONE)
   return gradient_ops.gradients(loss, variables.trainable_variables())
Example #12
 def loop_fn(i):
   loop_inputs = [
       array_ops.expand_dims(array_ops.gather(x, i), 0) for x in inputs
   ]
   loop_init_state = rnn_cell.LSTMStateTuple(
       *[array_ops.expand_dims(array_ops.gather(x, i), 0) for x in init_state])
   return model_fn(loop_inputs, loop_init_state)
Example #13
 def _apply_sparse(self, grad, var):
   if len(grad.indices.get_shape()) == 1:
     grad_indices = grad.indices
     grad_values = grad.values
   else:
     grad_indices = array_ops.reshape(grad.indices, [-1])
     grad_values = array_ops.reshape(grad.values, [-1, grad.values.get_shape()[-1].value])
   gidxs, metagidxs = array_ops.unique(grad_indices)
   sizegidxs = array_ops.size(gidxs)
   gvals = math_ops.unsorted_segment_sum(grad_values, metagidxs, sizegidxs)
   # m_t = mu * m + (1 - mu) * g_t
   m = self.get_slot(var, "m")
   m_scaled_g_values = gvals * (1 - self._mu_t)
   m_t = state_ops.scatter_update(m, gidxs,
                                  array_ops.gather(m, gidxs) * self._mu_t,
                                  use_locking=self._use_locking)
   m_t = state_ops.scatter_add(m_t, gidxs, m_scaled_g_values,
                               use_locking=self._use_locking)
   m_t_ = array_ops.gather(m_t, gidxs) / (1 - self._mu2_t * self._mu_power)
   # m_bar = mu * m_t + (1 - mu) * g_t
   m_bar = self._mu2_t * m_t_ + m_scaled_g_values / (1 - self._mu_power)
   var_update = state_ops.scatter_sub(var, gidxs,
                                      self._lr_t * m_bar,
                                      use_locking=self._use_locking)
   return control_flow_ops.group(*[var_update, m_t])
Example #14
 def testHigherRankMaxNorm(self):
   np.random.seed(8)
   with self.test_session():
     for params_shape in (12,), (6, 3), (6, 2, 3):
       # Test embedding rank 0, 1, 2.
       # Note: the first dimension must be a common multiple of procs below.
       params = 2 * np.ones(params_shape)
       params_norm = params / np.sqrt(
           np.sum(
               params * params, tuple(range(params.ndim)[1:]), keepdims=True))
       for ids_shape in (), (3,), (4, 3), (2, 3, 4):
         ids = np.random.randint(
             params.shape[0], size=np.prod(ids_shape,
                                           dtype=np.int64)).reshape(ids_shape)
         # Compare nonsharded to gather
         simple = embedding_ops.embedding_lookup(
             params, ids, max_norm=1.0).eval()
         self.assertAllEqual(simple, array_ops.gather(params_norm, ids).eval())
         # Run a few different sharded versions.
         for procs in 1, 2, 3:
           stride = procs * math_ops.range(params.shape[0] // procs)
           split_params = [
               array_ops.gather(params, stride + p) for p in xrange(procs)
           ]
           sharded = embedding_ops.embedding_lookup(
               split_params, ids, max_norm=1.0).eval()
           self.assertAllEqual(simple, sharded)
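What the test compares: every row of `params` has norm 2, so `max_norm=1.0` rescales each looked-up row to unit norm, and the lookup must equal a plain gather from the pre-normalized table. A NumPy sketch of that reference computation with hypothetical shapes:

```python
import numpy as np

params = 2 * np.ones((6, 3))                 # every row has L2 norm > 1
params_norm = params / np.linalg.norm(params, axis=-1, keepdims=True)
ids = np.array([0, 2, 5])
expected = params_norm[ids]                  # what max_norm=1.0 should return here
```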
Example #15
 def testTransform(self):
   # This tests all combinations of:
   #   - ids rank 0, 1, >1
   #   - params sharded/unsharded
   # It always applies max_norm.
   np.random.seed(8)
   l2_norm = 2.
   with self.test_session():
     # Param values are in [l2_norm, l2_norm+1) so it will always clip.
     params = np.random.rand(6, 3) + l2_norm
     params_norm = l2_norm * params / np.sqrt(
         np.sum(params * params, axis=1, keepdims=True))
     # Compute the norm of each embedding. This will change the embedding
     # rank to 0.
     params_norm = np.linalg.norm(params_norm, axis=1)
     transform = lambda x: linalg_ops.norm(x, axis=1)
     for ids_shape in (), (3,), (4, 3), (2, 3, 4):
       # Test ids rank 0, 1, 2, 3.
       ids = np.random.randint(
           params.shape[0], size=np.prod(ids_shape,
                                         dtype=np.int64)).reshape(ids_shape)
       # Compare nonsharded to gather.
       simple = embedding_ops._embedding_lookup_and_transform(
           params, ids, max_norm=l2_norm, transform_fn=transform).eval()
       self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
       # Run a few different sharded versions.
       for procs in 1, 2, 3:
         stride = procs * math_ops.range(params.shape[0] // procs)
         split_params = [
             array_ops.gather(params, stride + p) for p in xrange(procs)
         ]
         sharded = embedding_ops._embedding_lookup_and_transform(
             split_params, ids, max_norm=l2_norm,
             transform_fn=transform).eval()
         self.assertAllEqual(simple, sharded)
Example #16
  def _apply_sparse(self, grad, var):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

    # m := beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_t = state_ops.scatter_update(m, grad.indices,
                                   beta1_t * array_ops.gather(m, grad.indices) +
                                   (1 - beta1_t) * grad.values,
                                   use_locking=self._use_locking)

    # v := beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_t = state_ops.scatter_update(v, grad.indices,
                                   beta2_t * array_ops.gather(v, grad.indices) +
                                   (1 - beta2_t) * math_ops.square(grad.values),
                                   use_locking=self._use_locking)

    # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
    m_t_slice = array_ops.gather(m_t, grad.indices)
    v_t_slice = array_ops.gather(v_t, grad.indices)
    denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
    var_update = state_ops.scatter_sub(var, grad.indices,
                                       lr * m_t_slice / denominator_slice,
                                       use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
Example #17
  def testHigherRank(self):
    # We check that scalar and empty indices shapes work as well
    shape = (2, 1, 3, 2)
    for indices_shape in (), (0,), (2, 0), (2, 3):
      for dtype in _TEST_TYPES:
        for axis in range(len(shape)):
          params = self._buildParams(np.random.randn(*shape), dtype)
          indices = np.random.randint(shape[axis], size=indices_shape)
          with self.cached_session(use_gpu=True) as sess:
            tf_params = constant_op.constant(params)
            tf_indices = constant_op.constant(indices)
            # Check that both positive and negative indices for axis work.
            tf_axis = constant_op.constant(axis)
            tf_negative_axis = constant_op.constant(-len(shape) + axis)
            gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
            gather_negative_axis = array_ops.gather(
                tf_params, tf_indices, axis=tf_negative_axis)
            gather_value, gather_negative_axis_value = sess.run(
                [gather, gather_negative_axis])
            gather_np = np.take(params, indices, axis)
            self.assertAllEqual(gather_np, gather_value)
            self.assertAllEqual(gather_np, gather_negative_axis_value)
            expected_shape = (params.shape[:axis] + indices.shape +
                              params.shape[axis + 1:])
            self.assertEqual(expected_shape, gather.shape)
            self.assertEqual(expected_shape, gather_negative_axis.shape)

            # Test gradients
            gather_grad = np.random.randn(
                *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
            if dtype.is_complex:
              gather_grad -= 1j * gather_grad
            params_grad, indices_grad, axis_grad = gradients_impl.gradients(
                gather, [tf_params, tf_indices, tf_axis], gather_grad)
            self.assertEqual(indices_grad, None)
            self.assertEqual(axis_grad, None)
            if dtype.is_integer:
              self.assertEqual(params_grad, None)
              continue
            # For axis 0, we are able to create an efficient IndexedSlices for
            # the gradient.
            if axis == 0:
              self.assertEqual(type(params_grad), ops.IndexedSlices)
              params_grad = ops.convert_to_tensor(params_grad)
            correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
            outer_dims = axis
            inner_dims = len(shape) - axis - 1
            gather_grad = gather_grad.reshape(
                shape[:axis] + (indices.size,) + shape[axis + 1:])
            for source_index, dest_index in enumerate(indices.flat):
              dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
                            (slice(None),) * inner_dims)
              source_slice = ((slice(None),) * outer_dims + (source_index,) +
                              (slice(None),) * inner_dims)
              correct_params_grad[dest_slice] += gather_grad[source_slice]
            self.assertAllClose(
                correct_params_grad,
                self.evaluate(params_grad),
                atol=2e-6,
                rtol=2e-6)
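The `expected_shape` assertion above encodes the general gather shape rule: the indexed axis of `params` is replaced by the whole shape of `indices`. The same rule holds for `np.take`, as this small check with hypothetical shapes shows:

```python
import numpy as np

params = np.zeros((2, 1, 3, 2))
indices = np.zeros((2, 3), dtype=np.int64)   # all zeros: valid for an axis of size 3
axis = 2
out = np.take(params, indices, axis=axis)
assert out.shape == params.shape[:axis] + indices.shape + params.shape[axis + 1:]
# (2, 1) + (2, 3) + (2,) == (2, 1, 2, 3, 2)
```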
Example #18
  def _resource_apply_sparse(self, grad, var, indices):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)

    beta_1_t = self._get_hyper('beta_1', var_dtype)
    beta_2_t = self._get_hyper('beta_2', var_dtype)
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    epsilon_t = self._get_hyper('epsilon', var_dtype)

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, 'm')
    m_slice = array_ops.gather(m, indices)
    m_t_slice = m_slice * beta_1_t + grad * (1 - beta_1_t)
    with ops.control_dependencies([m_t_slice]):
      m_t = self._resource_scatter_update(m, indices, m_t_slice)

    # u_t = max(beta2 * u, abs(g_t))
    v = self.get_slot(var, 'v')
    v_slice = array_ops.gather(v, indices)
    v_t_slice = math_ops.maximum(v_slice * beta_2_t, math_ops.abs(grad))
    with ops.control_dependencies([v_t_slice]):
      v_t = self._resource_scatter_update(v, indices, v_t_slice)
    # theta_t = theta - lr / (1 - beta1^t) * m_t / u_t
    var_slice = -lr_t / (1 - beta_1_power) * (
        m_t_slice / (v_t_slice + epsilon_t))
    with ops.control_dependencies([var_slice]):
      var_update = self._resource_scatter_add(var, indices, var_slice)
    return control_flow_ops.group(*[var_update, m_t, v_t])
Example #19
 def testBadIndicesCPU(self):
   with test_util.force_cpu():
     params = [[0, 1, 2], [3, 4, 5]]
     with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
       self.evaluate(array_ops.gather(params, [[7]], axis=0))
     with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
       self.evaluate(array_ops.gather(params, [[7]], axis=1))
Example #20
 def _prepare(self, grads_and_vars):
   """"""
   
   if self._lr is None:
     sTy = 0
     sTs = 0
     yTy = 0
     for g_t, x_tm1 in grads_and_vars:
       if g_t is None:
         continue
       with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
         if isinstance(g_t, ops.Tensor):
           g_tm1 = self.get_slot(x_tm1, 'g')
           s_tm1 = self.get_slot(x_tm1, 's')
           y_t = (g_t-g_tm1)
           sTy += math_ops.reduce_sum(s_tm1*y_t)
           sTs += math_ops.reduce_sum(s_tm1**2)
           yTy += math_ops.reduce_sum(y_t**2)
         else:
           idxs, idxs_ = array_ops.unique(g_t.indices)
           g_t_ = math_ops.unsorted_segment_sum(g_t.values, idxs_, array_ops.size(idxs))
           g_tm1 = self.get_slot(x_tm1, 'g')
           g_tm1_ = array_ops.gather(g_tm1, idxs)
           s_tm1 = self.get_slot(x_tm1, 's')
           s_tm1_ = array_ops.gather(s_tm1, idxs)
           y_t_ = (g_t_-g_tm1_)
           sTy += math_ops.reduce_sum(s_tm1_*y_t_)
           sTs += math_ops.reduce_sum(s_tm1_**2)
           yTy += math_ops.reduce_sum(y_t_**2)
     sTy = math_ops.abs(sTy)
     self._lr = sTs / (sTy + self._eps)
Example #21
 def testString(self):
   params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
   with self.cached_session():
     self.assertAllEqual([b"qwer", b"uiop"],
                         array_ops.gather(params, 1, axis=0).eval())
     self.assertAllEqual([b"asdf", b"qwer"],
                         array_ops.gather(params, 0, axis=1).eval())
Example #22
  def _resource_apply_sparse(self, grad, var, indices):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    beta_1_t = self._get_hyper('beta_1', var_dtype)
    beta_2_t = self._get_hyper('beta_2', var_dtype)
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    epsilon_t = self._get_hyper('epsilon', var_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, 'm')
    m_scaled_g_values = grad * (1 - beta_1_t)
    m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
      # m_bar = (1 - beta1) * g_t + beta1 * m_t
      m_bar = m_scaled_g_values + beta_1_t * array_ops.gather(m_t, indices)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, 'v')
    v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
    v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

    v_t_slice = array_ops.gather(v_t, indices)
    v_sqrt = math_ops.sqrt(v_t_slice)
    var_update = self._resource_scatter_add(var, indices,
                                            -lr * m_bar / (v_sqrt + epsilon_t))
    return control_flow_ops.group(*[var_update, m_bar, v_t])
Example #23
 def testBadIndicesCPU(self):
   with self.session(use_gpu=False):
     params = [[0, 1, 2], [3, 4, 5]]
     with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
       array_ops.gather(params, [[7]], axis=0).eval()
     with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
       array_ops.gather(params, [[7]], axis=1).eval()
Example #24
 def testHigherRankMaxNorm(self):
   np.random.seed(8)
   with self.cached_session():
     for params_shape in (12,), (6, 3), (6, 2, 3):
       # Test embedding rank 0, 1, 2.
       # Note: the first dimension must be a common multiple of procs below.
       params = 2 * np.ones(params_shape)
       params_norm = params / np.sqrt(
           np.sum(
               params * params, tuple(range(params.ndim)[1:]), keepdims=True))
       for ids_shape in (), (3,), (4, 3), (2, 3, 4):
         ids = np.random.randint(
             params.shape[0], size=np.prod(ids_shape,
                                           dtype=np.int64)).reshape(ids_shape)
         # Compare nonsharded to gather
         simple = embedding_ops.embedding_lookup(
             params, ids, max_norm=1.0).eval()
         # assertAllClose is used here as different implementations of sqrt may
         # be used to compute each of the values being compared.  For example,
         # on AVX512 builds the embedding operation makes use of Eigen's fast
         # vectorized square root algorithm for doubles.  These different
         # implementations of sqrt are not guaranteed to produce exactly the
         # same results. Therefore, an exact comparison cannot be made.
         self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
         # Run a few different sharded versions.
         for procs in 1, 2, 3:
           stride = procs * math_ops.range(params.shape[0] // procs)
           split_params = [
               array_ops.gather(params, stride + p) for p in xrange(procs)
           ]
           sharded = embedding_ops.embedding_lookup(
               split_params, ids, max_norm=1.0).eval()
            self.assertAllEqual(simple, sharded)
Example #25
  def _check_shapes_dynamic(self, operator, v, diag):
    """Return (v, diag) with Assert dependencies, which check shape."""
    checks = []
    with ops.op_scope([operator, v, diag], 'check_shapes'):
      s_v = array_ops.shape(v)
      r_op = operator.rank()
      r_v = array_ops.rank(v)
      if diag is not None:
        s_d = array_ops.shape(diag)
        r_d = array_ops.rank(diag)

      # Check tensor rank.
      checks.append(check_ops.assert_rank(v, r_op))
      if diag is not None:
        checks.append(check_ops.assert_rank(diag, r_op - 1))

      # Check batch shape
      checks.append(check_ops.assert_equal(
          operator.batch_shape(), array_ops.slice(s_v, [0], [r_v - 2])))
      if diag is not None:
        checks.append(check_ops.assert_equal(
            operator.batch_shape(), array_ops.slice(s_d, [0], [r_d - 1])))

      # Check event shape
      checks.append(check_ops.assert_equal(
          operator.vector_space_dimension(), array_ops.gather(s_v, r_v - 2)))
      if diag is not None:
        checks.append(check_ops.assert_equal(
            array_ops.gather(s_v, r_v - 1), array_ops.gather(s_d, r_d - 1)))

      v = control_flow_ops.with_dependencies(checks, v)
      if diag is not None:
        diag = control_flow_ops.with_dependencies(checks, diag)
      return v, diag
Example #26
 def _apply_sparse_shared(self, grad, var, indices, scatter_add):
   beta1_power, beta2_power = self._get_beta_accumulators()
   beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
   beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
   lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
   beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
   beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
   epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
   lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
   # m_t = beta1 * m + (1 - beta1) * g_t
   m = self.get_slot(var, "m")
   m_scaled_g_values = grad * (1 - beta1_t)
   m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
   with ops.control_dependencies([m_t]):
     m_t = scatter_add(m, indices, m_scaled_g_values)
     # m_bar = (1 - beta1) * g_t + beta1 * m_t
     m_bar = m_scaled_g_values + beta1_t * array_ops.gather(m_t, indices)
   # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
   v = self.get_slot(var, "v")
   v_scaled_g_values = (grad * grad) * (1 - beta2_t)
   v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
   with ops.control_dependencies([v_t]):
     v_t = scatter_add(v, indices, v_scaled_g_values)
   v_t_slice = array_ops.gather(v_t, indices)
   v_sqrt = math_ops.sqrt(v_t_slice)
   var_update = scatter_add(var, indices, -lr * m_bar / (v_sqrt + epsilon_t))
   return control_flow_ops.group(*[var_update, m_bar, v_t])
Example #27
  def _apply_sparse_shared(self,
                           grad,
                           var,
                           indices,
                           scatter_update,
                           scatter_sub):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

    # \\(m := beta1 * m + (1 - beta1) * g_t\\)
    m = self.get_slot(var, "m")
    m_t = scatter_update(m, indices,
                         beta1_t * array_ops.gather(m, indices) +
                         (1 - beta1_t) * grad)

    # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
    v = self.get_slot(var, "v")
    v_t = scatter_update(v, indices,
                         beta2_t * array_ops.gather(v, indices) +
                         (1 - beta2_t) * math_ops.square(grad))

    # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
    m_t_slice = array_ops.gather(m_t, indices)
    v_t_slice = array_ops.gather(v_t, indices)
    denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
    var_update = scatter_sub(var, indices,
                             lr * m_t_slice / denominator_slice)
    return control_flow_ops.group(var_update, m_t, v_t)
Example #28
def _SparseUpdate(variable, gradients, accum, linear, base_lr,
                  lr_power, l1, l2):
  """Sparse Update "variable", "accum", "linear" based on sparse "gradients".

  See the description in _Update.

  Args:
    variable: A Variable.
    gradients: An IndexedSlices holding the sparse gradients.
    accum: A Variable containing the sum of the squares of gradients.
    linear: A Variable containing approximation info.
    base_lr: A constant representing the base learning rate.
    lr_power: A constant used to adjust the learning rate.
    l1: A constant representing the l1 regularization strength.
    l2: A constant representing the l2 regularization strength.

  Returns:
    A group op including three ScatterUpdate ops:
      1. ScatterUpdate for "accum"
      2. ScatterUpdate for "linear"
      3. ScatterUpdate for "variable"
  """
  assert isinstance(gradients, ops.IndexedSlices)
  with ops.name_scope("sparse_update_" + variable.op.name) as scope:
    dtype = variable.dtype.base_dtype
    base_lr = ops.convert_to_tensor(base_lr, dtype=dtype)
    lr_power = ops.convert_to_tensor(lr_power, dtype=dtype)
    l1 = ops.convert_to_tensor(l1, dtype=dtype)
    l2 = ops.convert_to_tensor(l2, dtype=dtype)

    # Compute the new value for the accumulator
    previous_accum = array_ops.gather(accum, gradients.indices)
    sqr_grad = gradients.values * gradients.values
    accum_updated = sqr_grad + previous_accum

    # Compute the new linear
    neg_lr_power = math_ops.neg(lr_power)
    sigma = math_ops.pow(accum_updated, neg_lr_power) - math_ops.pow(
        previous_accum, neg_lr_power)
    sigma /= base_lr
    variable_slice = array_ops.gather(variable, gradients.indices)
    proximal_adjust = sigma * variable_slice
    linear_slice = array_ops.gather(linear, gradients.indices)
    linear_updated = linear_slice + gradients.values - proximal_adjust

    # Compute the new "variable"
    variable_updated = _Compute(accum_updated, linear_updated, base_lr,
                                lr_power, l1, l2)

    with ops.control_dependencies([sigma]):
      accum_update_op = state_ops.scatter_update(accum, gradients.indices,
                                                 accum_updated)
    linear_update_op = state_ops.scatter_update(linear, gradients.indices,
                                                linear_updated)
    variable_update_op = state_ops.scatter_update(variable, gradients.indices,
                                                  variable_updated)
    group_op = control_flow_ops.group(linear_update_op, accum_update_op,
                                      variable_update_op, name=scope)
    return group_op
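A per-coordinate sketch of the same FTRL bookkeeping for one gathered coordinate (all constants hypothetical; `lr_power = -0.5` is the common choice, which makes `accum ** -lr_power` a square root):

```python
base_lr, lr_power = 0.1, -0.5
accum, linear, w, g = 1.0, 0.0, 0.2, 0.3     # slot values, weight, gradient

accum_new = accum + g * g                    # accumulate the squared gradient
sigma = (accum_new ** -lr_power - accum ** -lr_power) / base_lr
linear_new = linear + g - sigma * w          # the proximal adjustment
# _Compute(accum_new, linear_new, ...) would then produce the new weight.
```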
Example #29
 def testUInt32AndUInt64(self):
   for unsigned_type in (dtypes.uint32, dtypes.uint64):
     params = self._buildParams(
         np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
     with self.cached_session():
       self.assertAllEqual([7, 8, 9],
                           array_ops.gather(params, 1, axis=0).eval())
       self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1).eval())
Example #30
 def testSkipEagerErrors(self):
   if context.executing_eagerly():
     return
   with self.assertRaisesRegexp(ValueError, r"tf\.gather does not allow.*"):
     array_ops.gather(
         params=[1, 2],
         batch_dims=1,
         indices=array_ops.placeholder(dtypes.int32))
Example #31
 def loop_fn(i, use_pfor):
   image = array_ops.gather(images, i)
   logits = array_ops.reshape(model(image, training=training), [-1])
   return gradients.jacobian(
       logits, variables.trainable_variables(), use_pfor=use_pfor)
Example #32
 def _forward(self, x):
     return array_ops.gather(x, self.permutation, axis=-1)
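Gathering with a permutation along the last axis is invertible, which is what makes it usable as a bijector's forward pass; the inverse simply gathers with the inverted permutation. A small sketch with the public `tf` API and hypothetical values:

```python
import tensorflow as tf

perm = tf.constant([2, 0, 1])
x = tf.constant([[10., 20., 30.]])
y = tf.gather(x, perm, axis=-1)              # forward: [[30., 10., 20.]]
inv_perm = tf.math.invert_permutation(perm)  # [1, 2, 0]
x_back = tf.gather(y, inv_perm, axis=-1)     # recovers [[10., 20., 30.]]
```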
Example #33
 def loop_fn(i, use_pfor):
   inp_i = array_ops.expand_dims(array_ops.gather(inp, i), 0)
   output = array_ops.reshape(model(inp_i), [-1])
   return gradients.jacobian(
       output, variables.trainable_variables(), use_pfor=use_pfor)
Example #34
    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        var_dtype = var.dtype.base_dtype
        lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
        beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
        beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        next_step = math_ops.cast(self.iterations + 2, var_dtype)
        decay_base = math_ops.cast(0.96, var_dtype)

        # Learning rate multipliers
        if self.lr_multipliers is not None:
            lr_t = _apply_lr_multiplier(self, lr_t, var)

        momentum_cache_t = beta_1_t * (
            1. - 0.5 *
            (math_ops.pow(decay_base, self._initial_decay * local_step)))
        momentum_cache_t_1 = beta_1_t * (
            1. - 0.5 *
            (math_ops.pow(decay_base, self._initial_decay * next_step)))
        m_schedule_new = math_ops.cast(self._m_cache_read,
                                       var_dtype) * momentum_cache_t
        if var_dtype is self._m_cache.dtype:
            m_schedule_new = array_ops.identity(
                state_ops.assign(self._m_cache,
                                 m_schedule_new,
                                 use_locking=self._use_locking))
        m_schedule_next = m_schedule_new * momentum_cache_t_1

        m_scaled_g_values = grad * (1. - beta_1_t)
        m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
        with ops.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
            m_t_slice = array_ops.gather(m_t, indices)

        m_t_prime = m_t_slice / (1. - m_schedule_next)
        g_prime = grad / (1. - m_schedule_new)
        m_t_bar = (1. - momentum_cache_t) * g_prime + (momentum_cache_t_1 *
                                                       m_t_prime)

        v_scaled_g_values = (grad * grad) * (1. - beta_2_t)
        v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)

        with ops.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
            v_t_slice = array_ops.gather(v_t, indices)

        v_t_prime_denominator = 1. - math_ops.pow(beta_2_t, local_step)
        v_t_prime = v_t_slice / v_t_prime_denominator
        v_prime_sqrt_plus_eps = math_ops.sqrt(v_t_prime) + epsilon_t

        var_t = self._resource_scatter_add(
            var, indices, -self.eta_t * lr_t * m_t_bar / v_prime_sqrt_plus_eps)

        # Weight decays
        if var.name in self.weight_decays.keys():
            var_t = _apply_weight_decays(self, var, var_t)

        var_update = state_ops.assign(var,
                                      var_t,
                                      use_locking=self._use_locking)

        # Cosine annealing
        (iteration_done, t_cur_update,
         eta_t_update) = _update_t_cur_eta_t_v2(self, lr_t, var)
        if iteration_done and not self._init_notified:
            self._init_notified = True

        updates = [var_update, m_t_bar, v_t]
        if iteration_done:
            updates += [t_cur_update]
        if self.use_cosine_annealing and iteration_done:
            updates += [eta_t_update]
        return control_flow_ops.group(*updates)
Example #35
    def training_graph(self,
                       input_data,
                       input_labels,
                       num_trainers=1,
                       trainer_id=0,
                       **tree_kwargs):
        """Constructs a TF graph for training a random forest.

    Args:
      input_data: A tensor or dict of string->Tensor for input data.
      input_labels: A tensor or placeholder for labels associated with
        input_data.
      num_trainers: Number of parallel trainers to split trees among.
      trainer_id: Which trainer this instance is.
      **tree_kwargs: Keyword arguments passed to each tree's training_graph.

    Returns:
      The last op in the random forest training graph.

    Raises:
      NotImplementedError: If trying to use bagging with sparse features.
    """
        processed_dense_features, processed_sparse_features, data_spec = (
            data_ops.ParseDataTensorOrDict(input_data))

        if input_labels is not None:
            labels = data_ops.ParseLabelTensorOrDict(input_labels)

        data_spec = data_spec or self.get_default_data_spec(input_data)

        tree_graphs = []
        trees_per_trainer = self.params.num_trees / num_trainers
        tree_start = int(trainer_id * trees_per_trainer)
        tree_end = int((trainer_id + 1) * trees_per_trainer)
        for i in range(tree_start, tree_end):
            with ops.device(self.variables.device_dummies[i].device):
                seed = self.params.base_random_seed
                if seed != 0:
                    seed += i
                # If using bagging, randomly select some of the input.
                tree_data = processed_dense_features
                tree_labels = labels
                if self.params.bagging_fraction < 1.0:
                    # TODO(gilberth): Support bagging for sparse features.
                    if processed_sparse_features is not None:
                        raise NotImplementedError(
                            'Bagging not supported with sparse features.')
                    # TODO(thomaswc): This does sampling without replacement.  Consider
                    # also allowing sampling with replacement as an option.
                    batch_size = array_ops.strided_slice(
                        array_ops.shape(processed_dense_features), [0], [1])
                    r = random_ops.random_uniform(batch_size, seed=seed)
                    mask = math_ops.less(
                        r,
                        array_ops.ones_like(r) * self.params.bagging_fraction)
                    gather_indices = array_ops.squeeze(array_ops.where(mask),
                                                       axis=[1])
                    # TODO(thomaswc): Calculate out-of-bag data and labels, and store
                    # them for use in calculating statistics later.
                    tree_data = array_ops.gather(processed_dense_features,
                                                 gather_indices)
                    tree_labels = array_ops.gather(labels, gather_indices)
                if self.params.bagged_features:
                    if processed_sparse_features is not None:
                        raise NotImplementedError(
                            'Feature bagging not supported with sparse features.'
                        )
                    tree_data = self._bag_features(i, tree_data)

                tree_graphs.append(self.trees[i].training_graph(
                    tree_data,
                    tree_labels,
                    seed,
                    data_spec=data_spec,
                    sparse_features=processed_sparse_features,
                    **tree_kwargs))

        return control_flow_ops.group(*tree_graphs, name='train')
Example #36
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =`
      X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use: `lu_solve(..., rhs[...,
        tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
        actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tf.linalg.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """

  with ops.name_scope(name or 'lu_solve'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
    rhs = ops.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)
        rhs = array_ops.identity(rhs)

    if (rhs.shape.rank == 2 and perm.shape.rank == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = array_ops.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = array_ops.shape(rhs)
      broadcast_batch_shape = array_ops.broadcast_dynamic_shape(
          rhs_shape[:-2],
          array_ops.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = array_ops.concat([broadcast_batch_shape, [d, m]],
                                             axis=0)

      # Tile out rhs.
      broadcast_rhs = array_ops.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = array_ops.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = array_ops.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = array_ops.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = math_ops.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = array_ops.broadcast_to(
          math_ops.range(broadcast_batch_size)[:, array_ops.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = array_ops.stack(
          [broadcast_batch_indices, broadcast_perm], axis=-1)

      permuted_rhs = array_ops.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = array_ops.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(
            array_ops.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return triangular_solve(
        lower_upper,  # Only upper is accessed.
        triangular_solve(lower, permuted_rhs),
        lower=False)
Example #37
def boolean_mask(data, mask, name=None):
    """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
    with ops.name_scope(name, 'RaggedMask', [data, mask]):
        # Convert inputs to tensors.
        data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data,
                                                                name='data')
        mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(mask,
                                                                dtypes.bool,
                                                                name='mask')
        row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
            data, mask, return_dtype=True)

        # Get static rank of mask.
        if mask.shape.ndims is None:
            raise ValueError('mask.shape.ndims must be known statically.')
        elif mask.shape.ndims == 0:
            raise ValueError('mask cannot be scalar.')

        # If mask is ragged, then recurse with a non-ragged mask.
        if ragged_tensor.is_ragged(mask):
            if not ragged_tensor.is_ragged(data):
                data = ragged_tensor.RaggedTensor.from_tensor(
                    data,
                    ragged_rank=mask.ragged_rank,
                    row_splits_dtype=mask.row_splits.dtype)
            # Check that mask.nested_row_splits is a prefix of
            # data.nested_row_splits.
            splits_list = [
                mask.nested_row_splits,
                data.nested_row_splits[:mask.ragged_rank]
            ]
            with ops.control_dependencies(
                    ragged_util.assert_splits_match(splits_list)):
                # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
                # that we strip off in `splits`, so we can add them back on after
                # we recursively mask the non-ragged data.
                splits = []
                while ragged_tensor.is_ragged(mask):
                    if mask.shape.ndims > 2:
                        splits.append(mask.row_splits)
                    else:
                        # Count the number of True mask values in each row to find the
                        # lengths of the filtered rows; then convert to splits.
                        int_mask = ragged_functional_ops.map_flat_values(
                            math_ops.cast, mask, dtype=row_splits_dtype)
                        masked_row_lengths = ragged_math_ops.reduce_sum(
                            int_mask, axis=1)
                        splits.append(
                            ragged_util.lengths_to_splits(masked_row_lengths))
                    mask = mask.values
                    data = data.values

                # Recursively apply the nested non-ragged mask to the nested data.
                masked_values = boolean_mask(data, mask)

                # Add the ragged `splits` back to the result.
                masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
                    masked_values, splits, validate=False)

                return masked_values

        # If mask is non-ragged and has rank 1, and data is ragged, then build a
        # ragged tensor with the indicated rows.
        elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
            # Get the masked splits: first get the length of each row, then filter
            # out the rows that we are deleting, and convert that filtered set of
            # masks back to a splits tensor.
            lengths = data.row_lengths()
            masked_lengths = array_ops.boolean_mask(lengths, mask)
            masked_splits = ragged_util.lengths_to_splits(masked_lengths)

            # Get the masked values: first get row ids corresponding to each
            # value, then use tf.gather to build a boolean mask that's false for
            # values that come from rows that we are deleting, and use that mask to
            # construct the masked values tensor.
            segment_ids = segment_id_ops.row_splits_to_segment_ids(
                data.row_splits)
            segment_mask = array_ops.gather(mask, segment_ids)
            masked_values = boolean_mask(data.values, segment_mask)

            return ragged_tensor.RaggedTensor.from_row_splits(masked_values,
                                                              masked_splits,
                                                              validate=False)

        # If mask is non-ragged and has rank>1, then convert it to be ragged,
        # with a ragged rank matching data.
        if ragged_tensor.is_ragged(data):
            mask = ragged_tensor.RaggedTensor.from_tensor(
                mask,
                ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
                row_splits_dtype=data.row_splits.dtype)
            return boolean_mask(data, mask)

        # Otherwise, data and mask are both `Tensor`s.
        else:
            # Apply `boolean_mask` to get the masked values.
            masked_values = array_ops.boolean_mask(data, mask)

            if mask.shape.ndims >= 2:
                # Add the innermost ragged dimension.  For each innermost cell, get the
                # number of values it contains.  Then flatten that to get a list of
                # cell lengths, and convert it to splits.  Finally, combine the splits
                # and values to get the innermost ragged tensor.
                masked_lengths = math_ops.count_nonzero(mask,
                                                        axis=-1,
                                                        dtype=row_splits_dtype)
                flattened_masked_lengths = array_ops.reshape(
                    masked_lengths, [-1])
                masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
                    masked_values, flattened_masked_lengths, validate=False)

                # Wrap remaining ragged dimensions.
                if mask.shape.ndims > 2:
                    mask_shape = array_ops.shape(mask,
                                                 out_type=row_splits_dtype)
                    split_size = math_ops.cumprod(mask_shape) + 1
                    for dim in range(mask.shape.ndims - 3, -1, -1):
                        elt_size = mask_shape[dim + 1]
                        masked_splits = math_ops.range(
                            split_size[dim]) * elt_size
                        masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                            masked_values, masked_splits, validate=False)

            return masked_values
Example #38
    def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
        """Decoder function used in the `dynamic_rnn_decoder` for inference.
        The main difference between this decoder function and the `decoder_fn` in
        `attention_decoder_fn_train` is how `next_cell_input` is calculated. In
        decoder function we calculate the next input by applying an argmax across
        the feature dimension of the output from the decoder. This is a
        greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014)
        use beam-search instead.
        Args:
            time: positive integer constant reflecting the current timestep.
            cell_state: state of RNNCell.
            cell_input: input provided by `dynamic_rnn_decoder`.
            cell_output: output of RNNCell.
            context_state: context state provided by `dynamic_rnn_decoder`.
        Returns:
            A tuple (done, next state, next input, emit output, next context state)
            where:
            done: A boolean vector indicating which sentences have reached an
            `end_of_sequence_id`. This is used for early stopping by the
            `dynamic_rnn_decoder`. When `time>=maximum_length`, a boolean vector
            with all elements set to `true` is returned.
            next state: `cell_state`, this decoder function does not modify the
            given state.
            next input: The embedding from argmax of the `cell_output` is used as
            `next_input`.
            emit output: If `output_fn is None` the supplied `cell_output` is
            returned, else the `output_fn` is used to update the `cell_output`
            before calculating `next_input` and returning `cell_output`.
            next context state: `context_state`, this decoder function does not
            modify the given context state. The context state could be modified when
            applying e.g. beam search.
        Raises:
            ValueError: if cell_input is not None.
        """
        with ops.name_scope(
                name, "attention_decoder_fn_inference",
            [time, cell_state, cell_input, cell_output, context_state]):
            if cell_input is not None:
                raise ValueError(
                    "Expected cell_input to be None, but saw: %s" % cell_input)
            if cell_output is None:
                # invariant that this is time == 0
                next_input_id = array_ops.ones([
                    batch_size,
                ], dtype=dtype) * (start_of_sequence_id)
                done = array_ops.zeros([
                    batch_size,
                ], dtype=dtypes.bool)
                cell_state = encoder_state
                cell_output = array_ops.zeros([num_decoder_symbols],
                                              dtype=dtypes.float32)
                cell_input = array_ops.gather(embeddings, next_input_id)

                # init attention
                attention = _init_attention(encoder_state)
                # init context state
                log_beam_probs = tensor_array_ops.TensorArray(
                    dtype=dtypes.float32,
                    tensor_array_name="log_beam_probs",
                    size=maximum_length,
                    dynamic_size=True,
                    infer_shape=False)
                beam_parents = tensor_array_ops.TensorArray(
                    dtype=dtypes.int32,
                    tensor_array_name="beam_parents",
                    size=maximum_length,
                    dynamic_size=True,
                    infer_shape=False)
                beam_symbols = tensor_array_ops.TensorArray(
                    dtype=dtypes.int32,
                    tensor_array_name="beam_symbols",
                    size=maximum_length,
                    dynamic_size=True,
                    infer_shape=False)
                result_probs = tensor_array_ops.TensorArray(
                    dtype=dtypes.float32,
                    tensor_array_name="result_probs",
                    size=maximum_length,
                    dynamic_size=True,
                    infer_shape=False)
                result_parents = tensor_array_ops.TensorArray(
                    dtype=dtypes.int32,
                    tensor_array_name="result_parents",
                    size=maximum_length,
                    dynamic_size=True,
                    infer_shape=False)
                result_symbols = tensor_array_ops.TensorArray(
                    dtype=dtypes.int32,
                    tensor_array_name="result_symbols",
                    size=maximum_length,
                    dynamic_size=True,
                    infer_shape=False)
                context_state = (log_beam_probs, beam_parents, beam_symbols,
                                 result_probs, result_parents, result_symbols)
            else:
                # construct attention
                attention = attention_construct_fn(cell_output, attention_keys,
                                                   attention_values)
                cell_output = attention

                # beam search decoder
                (log_beam_probs, beam_parents, beam_symbols, result_probs,
                 result_parents, result_symbols) = context_state

                cell_output = output_fn(cell_output)  # logits
                cell_output = nn_ops.softmax(cell_output)

                cell_output = array_ops.split(cell_output,
                                              [2, num_decoder_symbols - 2],
                                              1)[1]

                tmp_output = array_ops.gather(
                    cell_output,
                    math_ops.range(origin_batch) * beam_size)

                probs = control_flow_ops.cond(
                    math_ops.equal(time, ops.convert_to_tensor(1, dtype)),
                    lambda: math_ops.log(tmp_output + ops.convert_to_tensor(
                        1e-20, dtypes.float32)),
                    lambda: math_ops.log(cell_output + ops.convert_to_tensor(
                        1e-20, dtypes.float32)) + array_ops.reshape(
                            log_beam_probs.read(time - 2), [-1, 1]))

                probs = array_ops.reshape(probs, [origin_batch, -1])
                best_probs, indices = nn_ops.top_k(probs, beam_size * 2)
                #indices = array_ops.reshape(indices, [-1])
                indices_flatten = array_ops.reshape(indices, [
                    -1
                ]) + array_ops.reshape(
                    array_ops.concat([
                        array_ops.reshape(
                            math_ops.range(origin_batch) *
                            ((num_decoder_symbols - 2) * beam_size), [-1, 1])
                    ] * (beam_size * 2), 1), [origin_batch * beam_size * 2])
                best_probs_flatten = array_ops.reshape(best_probs, [-1])

                symbols = indices_flatten % (num_decoder_symbols - 2)
                symbols = symbols + 2
                parents = indices_flatten // (num_decoder_symbols - 2)

                probs_wo_eos = best_probs + 1e5 * math_ops.cast(
                    math_ops.cast(
                        (indices % (num_decoder_symbols - 2) + 2) -
                        end_of_sequence_id, dtypes.bool), dtypes.float32)

                best_probs_wo_eos, indices_wo_eos = nn_ops.top_k(
                    probs_wo_eos, beam_size)

                indices_wo_eos = array_ops.reshape(
                    indices_wo_eos, [-1]) + array_ops.reshape(
                        array_ops.concat([
                            array_ops.reshape(
                                math_ops.range(origin_batch) *
                                (beam_size * 2), [-1, 1])
                        ] * beam_size, 1), [origin_batch * beam_size])

                _probs = array_ops.gather(best_probs_flatten, indices_wo_eos)
                _symbols = array_ops.gather(symbols, indices_wo_eos)
                _parents = array_ops.gather(parents, indices_wo_eos)

                log_beam_probs = log_beam_probs.write(time - 1, _probs)
                beam_symbols = beam_symbols.write(time - 1, _symbols)
                beam_parents = beam_parents.write(time - 1, _parents)
                result_probs = result_probs.write(time - 1, best_probs_flatten)
                result_symbols = result_symbols.write(time - 1, symbols)
                result_parents = result_parents.write(time - 1, parents)

                next_input_id = array_ops.reshape(_symbols, [batch_size])

                state_size = int(cell_state[0].get_shape().with_rank(2)[1])
                attn_size = int(attention.get_shape().with_rank(2)[1])
                state = []
                for j in cell_state:
                    state.append(
                        array_ops.reshape(array_ops.gather(j, _parents),
                                          [-1, state_size]))
                cell_state = tuple(state)
                attention = array_ops.reshape(
                    array_ops.gather(attention, _parents), [-1, attn_size])

                done = math_ops.equal(next_input_id, end_of_sequence_id)
                cell_input = array_ops.gather(embeddings, next_input_id)

            # combine cell_input and attention
            next_input = array_ops.concat([cell_input, attention], 1)

            # if time > maxlen, return all true vector
            done = control_flow_ops.cond(
                math_ops.greater(time, maximum_length),
                lambda: array_ops.ones([
                    batch_size,
                ], dtype=dtypes.bool),
                lambda: array_ops.zeros([
                    batch_size,
                ], dtype=dtypes.bool))
            return (done, cell_state, next_input, cell_output,
                    (log_beam_probs, beam_parents, beam_symbols, result_probs,
                     result_parents, result_symbols))  #context_state)
Example #39
 def loop_fn(i):
   return model_fn(array_ops.expand_dims(array_ops.gather(inp, i), 0))
Example #40
def _SparseSegmentSumGrad(op, grad):
    """Gradient for SparseSegmentSum."""
    input_rows = array_ops.shape(op.inputs[0])[0]
    return (math_ops.unsorted_segment_sum(array_ops.gather(grad, op.inputs[2]),
                                          op.inputs[1],
                                          input_rows), None, None)
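Tracing the shapes makes the formula clear: the forward op gathers input rows by `op.inputs[1]` and sums them per segment given by `op.inputs[2]`, so the backward pass must route each segment's upstream gradient back to every input row that fed it. A hedged NumPy sketch with illustrative values:

import numpy as np

grad = np.array([[1., 1.], [2., 2.]])  # upstream grad, one row per segment
indices = np.array([0, 2, 1])          # op.inputs[1]: which input rows were gathered
segment_ids = np.array([0, 0, 1])      # op.inputs[2]: segment of each gathered row
input_rows = 4

gathered = grad[segment_ids]              # gather(grad, segment_ids)
input_grad = np.zeros((input_rows, 2))
np.add.at(input_grad, indices, gathered)  # unsorted_segment_sum over `indices`
print(input_grad)
# [[1. 1.]    row 0 fed segment 0
#  [2. 2.]    row 1 fed segment 1
#  [1. 1.]    row 2 fed segment 0
#  [0. 0.]]   row 3 was never gathered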
Example #41
 def get(self, iteration):
     return array_ops.gather(self._index, iteration)
Example #42
def _UnsortedSegmentSumGrad(op, grad):
    """Gradient for SegmentSum."""
    return array_ops.gather(grad, op.inputs[1]), None, None
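Each input element is added into exactly one segment, so its partial derivative with respect to that segment's output is 1; gathering the upstream gradient by segment id says exactly that. A one-line NumPy check:

import numpy as np

grad = np.array([10., 20.])        # upstream grad, one value per segment
segment_ids = np.array([0, 1, 0])  # op.inputs[1]
print(grad[segment_ids])           # [10. 20. 10.]: each input gets its segment's grad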
Example #43
 def _event_shape(self):
     return array_ops.gather(array_ops.shape(self._mean_val),
                             [array_ops.rank(self._mean_val) - 1])
Example #44
    def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
        """Multivariate Normal distributions on `R^k`.

    User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
    with the last dimension having length `k`.

    User must provide exactly one of `sigma` (the covariance matrices) or
    `sigma_chol` (the Cholesky decompositions of the covariance matrices).
    `sigma` or `sigma_chol` must be of rank `N+2`.  The last two dimensions
    must both have length `k`.  The first `N` dimensions correspond to batch
    indices.

    If `sigma_chol` is not provided, the batch Cholesky factorization of `sigma`
    is calculated for you.

    The shapes of `mu` and `sigma` must match for the first `N` dimensions.

    Regardless of which parameter is provided, the covariance matrices must all
    be **positive definite** (an error is raised if one of them is not).

    Args:
      mu: (N+1)-D.  `float` or `double` tensor, the means of the distributions.
      sigma: (N+2)-D.  (optional) `float` or `double` tensor, the covariances
        of the distribution(s).  The first `N+1` dimensions must match
        those of `mu`.  Must be batch-positive-definite.
      sigma_chol: (N+2)-D.  (optional) `float` or `double` tensor, a
        lower-triangular factorization of `sigma`
        (`sigma = sigma_chol . sigma_chol^*`).  The first `N+1` dimensions
        must match those of `mu`.  The tensor itself need not be batch
        lower triangular: we ignore the upper triangular part.  However,
        the batch diagonals must be positive (i.e., sigma_chol must be
        batch-positive-definite).
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if neither sigma nor sigma_chol is provided.
      TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
    """
        if (sigma is None) == (sigma_chol is None):
            raise ValueError(
                "Exactly one of sigma and sigma_chol must be provided")

        with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
            sigma_or_half = sigma_chol if sigma is None else sigma

            mu = ops.convert_to_tensor(mu)
            sigma_or_half = ops.convert_to_tensor(sigma_or_half)

            contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))

            with ops.control_dependencies(
                [_assert_compatible_shapes(mu, sigma_or_half)]):
                mu = array_ops.identity(mu, name="mu")

                # Store the dimensionality of the MVNs
                self._k = array_ops.gather(array_ops.shape(mu),
                                           array_ops.rank(mu) - 1)

                if sigma_chol is not None:
                    # Ensure we only keep the lower triangular part.
                    sigma_chol = array_ops.batch_matrix_band_part(sigma_chol,
                                                                  num_lower=-1,
                                                                  num_upper=0)
                    sigma_det = _determinant_from_sigma_chol(sigma_chol)
                    with ops.control_dependencies(
                        [_assert_batch_positive_definite(sigma_chol)]):
                        self._sigma = math_ops.batch_matmul(sigma_chol,
                                                            sigma_chol,
                                                            adj_y=True,
                                                            name="sigma")
                        self._sigma_chol = array_ops.identity(
                            sigma_chol, "sigma_chol")
                        self._sigma_det = array_ops.identity(
                            sigma_det, "sigma_det")
                        self._mu = array_ops.identity(mu, "mu")
                else:  # sigma is not None
                    sigma_chol = linalg_ops.batch_cholesky(sigma)
                    sigma_det = _determinant_from_sigma_chol(sigma_chol)
                    # batch_cholesky checks for PSD, so we can just use it here.
                    with ops.control_dependencies([sigma_chol]):
                        self._sigma = array_ops.identity(sigma, "sigma")
                        self._sigma_chol = array_ops.identity(
                            sigma_chol, "sigma_chol")
                        self._sigma_det = array_ops.identity(
                            sigma_det, "sigma_det")
                        self._mu = array_ops.identity(mu, "mu")
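Both branches compute `sigma_det` from the Cholesky factor; for `sigma = L @ L^T` the determinant is the squared product of the diagonal of `L`, which is presumably the identity `_determinant_from_sigma_chol` exploits. A hedged NumPy check of that identity:

import numpy as np

sigma = np.array([[4., 2.], [2., 3.]])  # a positive-definite covariance
chol = np.linalg.cholesky(sigma)        # lower-triangular L with sigma = L @ L.T
det_from_chol = np.prod(np.diag(chol)) ** 2
assert np.allclose(det_from_chol, np.linalg.det(sigma))  # both equal 8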
Example #45
 def loop_fn(i):
   image = array_ops.gather(images, i)
   return model(image, training=training)
Example #46
def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
                   resampling_temperature, partition_strategy):
  """A helper function for rank_sampled_softmax_loss.

  This computes, for each i in `sampled_values`,

      log(sum_j exp((w_i * x_j + b_i) / resampling_temperature))

  where w_i, b_i are the weight and bias of the i-th class, respectively,
  and j ranges over the rows of `inputs`. For efficiency, we rearrange the
  computation to

      log(sum_j exp(w_i * (x_j / resampling_temperature))) +
          b_i / resampling_temperature.

  This translates to the following batched computation using tensorflow ops:

      reduce_logsumexp(matmul(embeddings,
                       transpose(inputs / resampling_temperature))) +
          biases / resampling_temperature

  The computation of the first term is colocated with the embeddings using
  `transform_fn` in `embedding_ops._embedding_lookup_and_transform`. The second
  term, not the bottleneck, is computed at the worker.

  Args:
    weights: From `rank_sampled_softmax_loss`.
    biases: From `rank_sampled_softmax_loss`.
    inputs: From `rank_sampled_softmax_loss`.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
    num_resampled: An `int`. This many values are selected from
        `sampled_values` using the adaptive resampling algorithm. The caller
        must ensure that `num_resampled` is less than the size of
        `sampled_values`.
    resampling_temperature: A scalar `Tensor` with the temperature parameter
        for the adaptive resampling algorithm.
    partition_strategy: From `rank_sampled_softmax_loss`.

  Returns:
    A tuple of (`resampled_candidates`, `true_expected_count`,
        `resampled_expected_count`), similar to `sampled_values` but sampled
        down to `num_resampled` values.
  """
  # This code supports passing a Tensor for num_resampled, but since it is only
  # called with an int, that's what we specify in the arg list. If this
  # function is ever externalized, we should change the doc to support Tensor.

  sampled, true_expected_count, sampled_expected_count = sampled_values

  sampled = math_ops.cast(array_ops.stop_gradient(sampled), dtypes.int64)
  true_expected_count = array_ops.stop_gradient(true_expected_count)
  sampled_expected_count = array_ops.stop_gradient(sampled_expected_count)

  reweighted_inputs = inputs / resampling_temperature

  def logsumexp_logit(embeddings):
    return math_ops.reduce_logsumexp(
        math_ops.matmul(embeddings, reweighted_inputs, transpose_b=True),
        axis=1,
        keep_dims=False)

  # Calling this protected form of embedding_lookup allows co-locating
  # the logsumexp computation with the partitioned weights, which yields
  # a large speedup in practice.
  sampled_logits = embedding_ops._embedding_lookup_and_transform(  # pylint: disable=protected-access
      weights, sampled, partition_strategy, transform_fn=logsumexp_logit)
  sampled_b = array_ops.reshape(
      embedding_ops.embedding_lookup(biases, sampled, partition_strategy), [-1])
  sampled_logits += sampled_b / resampling_temperature

  _, resampled_indices = nn.top_k(sampled_logits, k=num_resampled, sorted=False)
  resampled = array_ops.gather(sampled, indices=resampled_indices)
  resampled_expected_count = array_ops.gather(
      sampled_expected_count, indices=resampled_indices)

  return resampled, true_expected_count, resampled_expected_count
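The rearrangement described in the docstring is the identity log(sum_j exp((w.x_j + b)/t)) = log(sum_j exp(w.(x_j/t))) + b/t: dividing the inputs by the temperature once lets the bias be added outside the expensive reduction. A quick NumPy check with arbitrary values:

import numpy as np

w = np.array([0.5, -1.0])  # one class's weight row
b = 0.3                    # its bias
x = np.random.randn(4, 2)  # rows of `inputs`
t = 2.0                    # resampling_temperature

lhs = np.log(np.sum(np.exp((x @ w + b) / t)))
rhs = np.log(np.sum(np.exp((x / t) @ w))) + b / t
assert np.allclose(lhs, rhs)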
Example #47
    def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
        """Decoder function used in the `dynamic_rnn_decoder` for inference.

        The main difference between this decoder function and the `decoder_fn` in
        `attention_decoder_fn_train` is how `next_cell_input` is calculated. In
        this decoder function we calculate the next input by applying an argmax
        across the feature dimension of the output from the decoder. This is a
        greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014)
        use beam-search instead.

        Args:
            time: positive integer constant reflecting the current timestep.
            cell_state: state of RNNCell.
            cell_input: input provided by `dynamic_rnn_decoder`.
            cell_output: output of RNNCell.
            context_state: context state provided by `dynamic_rnn_decoder`.

        Returns:
            A tuple (done, next state, next input, emit output, next context state)
            where:

            done: A boolean vector indicating which sentences have reached an
            `end_of_sequence_id`. This is used for early stopping by the
            `dynamic_rnn_decoder`. When `time>=maximum_length`, a boolean vector
            with all elements set to `true` is returned.

            next state: `cell_state`, this decoder function does not modify the
            given state.

            next input: The embedding from argmax of the `cell_output` is used as
            `next_input`.

            emit output: If `output_fn is None` the supplied `cell_output` is
            returned, else the `output_fn` is used to update the `cell_output`
            before calculating `next_input` and returning `cell_output`.

            next context state: `context_state`, this decoder function does not
            modify the given context state. The context state could be modified when
            applying e.g. beam search.

        Raises:
            ValueError: if cell_input is not None.

        """
        with ops.name_scope(
                name, "attention_decoder_fn_inference",
            [time, cell_state, cell_input, cell_output, context_state]):
            if cell_input is not None:
                raise ValueError(
                    "Expected cell_input to be None, but saw: %s" % cell_input)
            if cell_output is None:
                # invariant that this is time == 0
                next_input_id = array_ops.ones([
                    batch_size,
                ], dtype=dtype) * (start_of_sequence_id)
                done = array_ops.zeros([
                    batch_size,
                ], dtype=dtypes.bool)
                cell_state = encoder_state
                cell_output = array_ops.zeros([num_decoder_symbols],
                                              dtype=dtypes.float32)
                word_input = array_ops.gather(embeddings, next_input_id)
                naf_triple_id = array_ops.zeros([batch_size, 2], dtype=dtype)
                triple_input = array_ops.gather_nd(imem[1], naf_triple_id)
                cell_input = array_ops.concat([word_input, triple_input],
                                              axis=1)

                # init attention
                attention = _init_attention(encoder_state)
                if imem is not None:
                    context_state = tensor_array_ops.TensorArray(
                        dtype=dtypes.int32,
                        tensor_array_name="output_ids_ta",
                        size=maximum_length,
                        dynamic_size=True,
                        infer_shape=False)
            else:
                # construct attention
                attention = attention_construct_fn(cell_output, attention_keys,
                                                   attention_values)
                if type(attention) is tuple:
                    attention, alignment = attention
                    cell_output = attention
                    alignment = tf.reshape(alignment, [batch_size, -1])
                    selector = selector_fn(cell_output)
                    logit = output_fn(cell_output)
                    word_prob = nn_ops.softmax(logit) * (1 - selector)
                    entity_prob = alignment * selector
                    mask = array_ops.reshape(
                        math_ops.cast(math_ops.greater(
                            tf.reduce_max(word_prob, 1),
                            tf.reduce_max(entity_prob, 1)),
                                      dtype=dtypes.float32), [-1, 1])
                    word_input = mask * array_ops.gather(
                        embeddings,
                        math_ops.cast(math_ops.argmax(word_prob, 1),
                                      dtype=dtype)
                    ) + (1 - mask) * array_ops.gather_nd(
                        imem[0],
                        array_ops.concat([
                            array_ops.reshape(
                                math_ops.range(batch_size, dtype=dtype),
                                [-1, 1]),
                            array_ops.reshape(
                                math_ops.cast(math_ops.argmax(entity_prob, 1),
                                              dtype=dtype), [-1, 1])
                        ],
                                         axis=1))
                    indices = array_ops.concat([
                        array_ops.reshape(
                            math_ops.range(batch_size, dtype=dtype), [-1, 1]),
                        math_ops.cast(1 - mask, dtype=dtype) * tf.reshape(
                            math_ops.cast(math_ops.argmax(alignment, 1),
                                          dtype=dtype), [-1, 1])
                    ],
                                               axis=1)
                    triple_input = array_ops.gather_nd(imem[1], indices)
                    cell_input = array_ops.concat([word_input, triple_input],
                                                  axis=1)
                    mask = array_ops.reshape(math_ops.cast(mask, dtype=dtype),
                                             [-1])
                    input_id = mask * math_ops.cast(
                        math_ops.argmax(word_prob, 1),
                        dtype=dtype) + (mask - 1) * math_ops.cast(
                            math_ops.argmax(entity_prob, 1), dtype=dtype)
                    context_state = context_state.write(time - 1, input_id)
                    done = array_ops.reshape(
                        math_ops.equal(input_id, end_of_sequence_id), [-1])
                    cell_output = logit

                else:
                    cell_output = attention

                    # argmax decoder
                    cell_output = output_fn(cell_output)  # logits
                    next_input_id = math_ops.cast(math_ops.argmax(
                        cell_output, 1),
                                                  dtype=dtype)
                    done = math_ops.equal(next_input_id, end_of_sequence_id)
                    cell_input = array_ops.gather(embeddings, next_input_id)

            # combine cell_input and attention
            next_input = array_ops.concat([cell_input, attention], 1)

            # if time > maxlen, return all true vector
            done = control_flow_ops.cond(
                math_ops.greater(time, maximum_length),
                lambda: array_ops.ones([
                    batch_size,
                ], dtype=dtypes.bool), lambda: done)
            return (done, cell_state, next_input, cell_output, context_state)
Example #48
 def loop_fn(i):
     y = array_ops.gather(output, i, axis=1)
     return gradient_ops.gradients(y, inp)[0]
Example #49
 def tpu_function(sparse):
     # Assumes dense_shape is (2, *)
     looked_up = array_ops.gather(table, sparse.values)
     segment_sum = math_ops.unsorted_segment_sum(
         looked_up, sparse.indices[:, 0], 2)
     return {"sparse": sparse, "segment_sum": segment_sum}
Example #50
  def create_batch(self):
    """Create queues to window and batch time series data.

    Returns:
      A dictionary of Tensors corresponding to the output of `self._reader`
      (from the `time_series_reader` constructor argument), each with shapes
      prefixed by [`batch_size`, `window_size`].
    """
    features = self._reader.read()
    if self._jitter:
      # TODO(agarwal, allenl): Figure out if more jitter is needed here.
      jitter = random_ops.random_uniform(shape=[], maxval=2, dtype=dtypes.int32)
    else:
      jitter = 0
    # To keep things efficient, we pass from the windowing batcher to the
    # batch-of-windows batcher in batches. This avoids the need for huge numbers
    # of threads, but does mean that jitter is only applied occasionally.
    # TODO(allenl): Experiment with different internal passing sizes.
    internal_passing_size = self._batch_size
    features_windowed = input_lib.batch(
        features,
        batch_size=self._window_size * internal_passing_size + jitter,
        enqueue_many=True,
        capacity=(self._queue_capacity_multiplier
                  * internal_passing_size * self._window_size),
        num_threads=self._num_threads)
    raw_features_windowed = features_windowed
    if self._jitter:
      features_windowed = {
          key: value[jitter:]
          for key, value in features_windowed.items()}
    features_windowed = {
        key: array_ops.reshape(
            value,
            array_ops.concat(
                [[internal_passing_size, self._window_size],
                 array_ops.shape(value)[1:]],
                axis=0))
        for key, value in features_windowed.items()}
    batch_and_window_shape = tensor_shape.TensorShape(
        [internal_passing_size, self._window_size])
    for key in features_windowed.keys():
      features_windowed[key].set_shape(
          batch_and_window_shape.concatenate(
              raw_features_windowed[key].get_shape()[1:]))
    # When switching files, we may end up with windows where the time is not
    # increasing, even if times within each file are sorted (and even if those
    # files are visited in order, when looping back around to the beginning of
    # the first file). This is hard for models to deal with, so we either
    # discard such examples, creating a bias where the beginning and end of the
    # series are under-sampled, or we sort the window, creating large gaps.
    times = features_windowed[feature_keys.TrainEvalFeatures.TIMES]
    if self._discard_out_of_order:
      non_decreasing = math_ops.reduce_all(
          times[:, 1:] >= times[:, :-1], axis=1)
      # Ensure that no more than self._discard_limit complete batches are
      # discarded contiguously (resetting the count when we find a single clean
      # window). This prevents infinite looping when the dataset is smaller than
      # the window size.
      # TODO(allenl): Figure out a way to return informative errors from
      # count_up_to.
      discarded_windows_limiter = variable_scope.variable(
          initial_value=constant_op.constant(0, dtype=dtypes.int64),
          name="discarded_windows_limiter",
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES])
      def _initialized_limit_check():
        return control_flow_ops.cond(
            math_ops.reduce_any(non_decreasing),
            lambda: state_ops.assign(discarded_windows_limiter, 0),
            lambda: discarded_windows_limiter.count_up_to(self._discard_limit))
      discard_limit_op = control_flow_ops.cond(
          state_ops.is_variable_initialized(discarded_windows_limiter),
          _initialized_limit_check,
          lambda: constant_op.constant(0, dtype=dtypes.int64))
      with ops.control_dependencies([discard_limit_op]):
        non_decreasing = array_ops.identity(non_decreasing)
    else:
      _, indices_descending = nn.top_k(
          times, k=array_ops.shape(times)[-1], sorted=True)
      indices = array_ops.reverse(indices_descending, axis=[0])
      features_windowed = {
          key: array_ops.gather(params=value, indices=indices)
          for key, value in features_windowed.items()
      }
      non_decreasing = True
    features_batched = input_lib.maybe_shuffle_batch(
        features_windowed,
        num_threads=self._num_threads,
        seed=self._shuffle_seed,
        batch_size=self._batch_size,
        capacity=self._queue_capacity_multiplier * self._batch_size,
        min_after_dequeue=(self._shuffle_min_after_dequeue_multiplier *
                           self._batch_size),
        keep_input=non_decreasing,
        enqueue_many=True)
    return (features_batched, None)
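The jitter mechanism enqueues `window_size * internal_passing_size + jitter` records and then drops the first `jitter` of them, so window boundaries shift by one step on some passes. A NumPy sketch of just that slice-and-reshape step, with illustrative sizes:

import numpy as np

window_size, passing_size, jitter = 3, 2, 1
flat = np.arange(window_size * passing_size + jitter)  # [0 1 2 3 4 5 6]
windows = flat[jitter:].reshape(passing_size, window_size)
print(windows)
# [[1 2 3]
#  [4 5 6]]   boundaries shifted by `jitter` relative to a jitter of 0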
Example #51
def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
    """Stacks dynamic partitions of a Tensor or RaggedTensor.

  Returns a RaggedTensor `output` with `num_partitions` rows, where the row
  `output[i]` is formed by stacking all slices `data[j1...jN]` such that
  `partitions[j1...jN] = i`.  Slices of `data` are stacked in row-major
  order.

  If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent to
  `tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions))`.

  #### Example:

  >>> data           = ['a', 'b', 'c', 'd', 'e']
  >>> partitions     = [  3,   0,   2,   2,   3]
  >>> num_partitions = 5
  >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions)
  <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]>

  Args:
    data: A `Tensor` or `RaggedTensor` containing the values to stack.
    partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
      partition that each slice of `data` should be added to. `partitions.shape`
      must be a prefix of `data.shape`.  Values must be greater than or equal to
      zero, and less than `num_partitions`. `partitions` is not required to be
      sorted.
    num_partitions: An `int32` or `int64` scalar specifying the number of
      partitions to output.  This determines the number of rows in `output`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the stacked partitions.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a
    ragged dimension whose length is the number of data slices stacked for
    each `partition`.
  """
    with ops.name_scope(name, 'SegmentStack',
                        [data, partitions, num_partitions]):
        # Convert inputs to tensors.
        data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data,
                                                                name='data')
        row_splits_dtype = (data.row_splits.dtype if isinstance(
            data, ragged_tensor.RaggedTensor) else None)
        partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            partitions, name='partitions', preferred_dtype=row_splits_dtype)
        num_partitions = ops.convert_to_tensor(
            num_partitions,
            name='num_partitions',
            preferred_dtype=partitions.dtype)
        if row_splits_dtype is not None:
            partitions = math_ops.cast(partitions, row_splits_dtype)
        num_partitions = math_ops.cast(num_partitions, partitions.dtype)

        # Sanity-checks for shapes.
        partitions_rank = partitions.shape.ndims
        if partitions_rank is None:
            raise ValueError('partitions must have known rank.')
        num_partitions.shape.assert_has_rank(0)
        partitions.shape.assert_is_compatible_with(
            data.shape[:partitions_rank])

        if partitions_rank == 0:
            # If partitions is a scalar, then just create a RaggedTensor
            # containing the single `data` value in the specified row.
            return ragged_tensor.RaggedTensor.from_value_rowids(
                values=array_ops.stack([data]),
                value_rowids=array_ops.stack([partitions]),
                nrows=num_partitions,
                validate=False)

        elif partitions_rank == 1:
            # If partitions is a vector (the typical case): we can just use data and
            # partitions as the `values` and `value_rowids` for `from_value_rowids`,
            # as long as we sort them first.
            permutation = sort_ops.argsort(partitions, stable=True)
            value_rowids = array_ops.gather(partitions, permutation)
            values = array_ops.gather(data, permutation)
            check = check_ops.assert_less(
                value_rowids[-1:],
                num_partitions,
                message='partitions must be less than num_partitions')
            with ops.control_dependencies([check]):
                return ragged_tensor.RaggedTensor.from_value_rowids(
                    values, value_rowids, nrows=num_partitions, validate=False)

        else:
            # Handle higher-dimensional partitions via recursion.
            if not isinstance(data, ragged_tensor.RaggedTensor):
                data = ragged_tensor.RaggedTensor.from_tensor(
                    data, row_splits_dtype=partitions.dtype, ragged_rank=1)
            if not isinstance(partitions, ragged_tensor.RaggedTensor):
                partitions = ragged_tensor.RaggedTensor.from_tensor(
                    partitions,
                    row_splits_dtype=partitions.dtype,
                    ragged_rank=max(data.ragged_rank, partitions_rank - 1))
            check = check_ops.assert_equal(
                data.row_splits,
                partitions.row_splits,
                message='data and partitions have incompatible ragged shapes')
            with ops.control_dependencies([check]):
                return stack_dynamic_partitions(data.values, partitions.values,
                                                num_partitions)
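For the common rank-1 case, the docstring's equivalence with `tf.dynamic_partition` is easy to confirm against the example above using the public TF 2.x API; the argsort-based path in this function just produces the same grouping in a single pass:

import tensorflow as tf

data = tf.constant(['a', 'b', 'c', 'd', 'e'])
partitions = tf.constant([3, 0, 2, 2, 3])
print(tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions=5)))
# <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]>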
Example #52
def stratified_sample(tensors, labels, target_probs, batch_size,
                      init_probs=None, enqueue_many=False, queue_capacity=16,
                      threads_per_queue=1, name=None):
  """Stochastically creates batches based on per-class probabilities.

  This method discards examples. Internally, it creates one queue to amortize
  the cost of disk reads, and one queue to hold the properly-proportioned
  batch.

  Args:
    tensors: List of tensors for data. All tensors are either one item or a
        batch, according to enqueue_many.
    labels: Tensor for label of data. Label is a single integer or a batch,
        depending on enqueue_many. It is not a one-hot vector.
    target_probs: Target class proportions in batch. An object whose type has a
        registered Tensor conversion function.
    batch_size: Size of batch to be returned.
    init_probs: Class proportions in the data. An object whose type has a
        registered Tensor conversion function, or `None` for estimating the
        initial distribution.
    enqueue_many: Bool. If true, interpret input tensors as having a batch
        dimension.
    queue_capacity: Capacity of the large queue that holds input examples.
    threads_per_queue: Number of threads for the large queue that holds input
        examples and for the final queue with the proper class proportions.
    name: Optional prefix for ops created by this function.
  Raises:
    ValueError: enqueue_many is True and labels doesn't have a batch
        dimension, or if enqueue_many is False and labels isn't a scalar.
    ValueError: enqueue_many is True, and the batch dimensions of data and
        labels don't match.
    ValueError: if probs don't sum to one.
    ValueError: if a zero initial probability class has a nonzero target
        probability.
    TFAssertion: if labels aren't integers in [0, num classes).
  Returns:
    (data_batch, label_batch), where data_batch is a list of tensors of the same
        length as `tensors`

  Example:
    # Get tensor for a single data and label example.
    data, label = data_provider.Get(['data', 'label'])

    # Get stratified batch according to per-class probabilities.
    target_probs = [...distribution you want...]
    [data_batch], labels = tf.contrib.training.stratified_sample(
        [data], label, target_probs)

    # Run batch through network.
    ...
  """
  with ops.name_scope(name, 'stratified_sample', tensors + [labels]):
    tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
    labels = ops.convert_to_tensor(labels)
    target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
    # Reduce the case of a single example to that of a batch of size 1.
    if not enqueue_many:
      tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
      labels = array_ops.expand_dims(labels, 0)

    # If `init_probs` is `None`, set up online estimation of data distribution.
    if init_probs is None:
      # We use `target_probs` to get the number of classes, so its shape must be
      # fully defined at graph construction time.
      target_probs.get_shape().assert_is_fully_defined()
      init_probs = _estimate_data_distribution(
          labels, target_probs.get_shape().num_elements())
    else:
      init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)

    # Validate that input is consistent.
    tensor_list, labels, [init_probs, target_probs] = _verify_input(
        tensor_list, labels, [init_probs, target_probs])

    # Check that all zero initial probabilities also have zero target
    # probabilities.
    assert_op = control_flow_ops.Assert(
        math_ops.reduce_all(math_ops.logical_or(
            math_ops.not_equal(init_probs, 0),
            math_ops.equal(target_probs, 0))),
        ['All classes with zero initial probability must also have zero target '
         'probability: ', init_probs, target_probs])
    init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)

    # Calculate acceptance sampling probabilities.
    accept_probs = _calculate_acceptance_probabilities(init_probs, target_probs)
    proportion_rejected = math_ops.reduce_sum((1 - accept_probs) * init_probs)
    accept_probs = control_flow_ops.cond(
        math_ops.less(proportion_rejected, .5),
        lambda: accept_probs,
        lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
            accept_probs, [accept_probs],
            message='Proportion of examples rejected by sampler is high.',
            first_n=10))

    # Make a single queue to hold input examples. Reshape output so examples
    # don't have singleton batch dimension.
    batched = input_ops.batch(tensor_list + [labels],
                              batch_size=1,
                              num_threads=threads_per_queue,
                              capacity=queue_capacity,
                              enqueue_many=True)
    val_list = [array_ops.squeeze(x, [0]) for x in batched[:-1]]
    label = array_ops.squeeze(batched[-1], [0])

    # Set up second queue containing batches that have the desired class
    # proportions.
    cur_prob = array_ops.gather(accept_probs, label)
    keep_input = random_ops.random_uniform([]) < cur_prob
    batched = _conditional_batch(
        val_list + [label],
        keep_input,
        batch_size,
        num_threads=threads_per_queue)
    return batched[:-1], batched[-1]
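The acceptance probabilities make rejection sampling transform the initial class mix into the target one: accepting class i with probability proportional to target_i / init_i (scaled so the largest is 1) leaves the accepted examples distributed as target_probs. The NumPy sketch below is an assumption about what `_calculate_acceptance_probabilities` computes, not its actual code:

import numpy as np

init_probs = np.array([0.7, 0.2, 0.1])    # class mix in the raw data
target_probs = np.array([0.4, 0.4, 0.2])  # desired class mix in the batch

ratio = target_probs / init_probs
accept_probs = ratio / ratio.max()  # scale so the largest acceptance prob is 1
kept = init_probs * accept_probs    # class mix among accepted examples
print(kept / kept.sum())            # ~[0.4 0.4 0.2], i.e. target_probs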
Example #53
def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
    """Builds nested_split tensors for a tiled `RaggedTensor`.

  Returns a list of split tensors that can be used to construct the
  `RaggedTensor` that tiles `rt_input` as specified by `multiples`.

  Args:
    rt_input: The `RaggedTensor` that is being tiled.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples.  Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A list of 1-D integer `Tensor`s (one for each ragged dimension in
    `rt_input`).

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_splits(rt, [3, 2])
  [<tf.Tensor: shape=(7,), dtype=int64,
  numpy=array([ 0,  4,  6, 10, 12, 16, 18])>]
  """
    ragged_rank = rt_input.ragged_rank
    nested_splits = rt_input.nested_row_splits

    # projected_splits[src_axis, dst_axis] contains the split points that divide
    # the rows from src_axis in the list of dst_axis values.  E.g.,
    # projected_splits[i, i] = nested_splits[i], and
    # projected_splits[i, i+1] = gather(nested_splits[i+1], nested_splits[i]).
    projected_splits = [{i: nested_splits[i]} for i in range(ragged_rank)]
    for src_axis in range(ragged_rank):
        for dst_axis in range(src_axis + 1, ragged_rank - 1):
            projected_splits[src_axis][dst_axis] = array_ops.gather(
                nested_splits[dst_axis],
                projected_splits[src_axis][dst_axis - 1])

    # For each ragged dimension: nested_splits[axis] -> result_splits[axis].
    result_splits = []
    for axis in range(ragged_rank):
        # Get the length of each row for the input tensor for this dimension.
        input_lengths = nested_splits[axis][1:] - nested_splits[axis][:-1]

        # Multiply those lengths by the `multiples` of dimension axis+1, since
        # each value will be repeated that number of times.
        output_lengths = input_lengths * multiples[axis + 1]

        # Repeat ranges of the row lengths as necessary for them to be tiled in
        # each ragged dimension `d < axis`.  (Start with dimension d=axis-1, and
        # work our way up to dimension d=0.)
        repeats = 1
        for d in range(axis - 1, -1, -1):
            if const_multiples is None or const_multiples[d + 1] != 1:
                splits = projected_splits[d][axis - 1] * repeats
                output_lengths = ragged_util.repeat_ranges(
                    output_lengths, splits, multiples[d + 1])
            repeats *= multiples[d + 1]

        # Tile splits for the outermost (uniform) dimension.
        output_lengths = array_ops.tile(output_lengths, multiples[:1])

        # Convert to splits.
        result_splits.append(ragged_util.lengths_to_splits(output_lengths))

    return result_splits
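The `projected_splits` recurrence re-expresses outer row boundaries in units of inner values by gathering the next level's splits at the current boundaries. A small NumPy illustration for a two-level ragged tensor like [[[1], [2]], [[3]]]:

import numpy as np

splits0 = np.array([0, 2, 3])     # outer rows -> middle rows
splits1 = np.array([0, 1, 2, 3])  # middle rows -> innermost values
print(splits1[splits0])           # [0 2 3]: outer row boundaries in innermost values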
Example #54
 def func(self, x):
   return array_ops.gather(self.shared_weights, x)
Example #55
def fill_lower_triangular(x,
                          validate_args=False,
                          name="fill_lower_triangular"):
    """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Although the non-batch complexity is O(n^2), large constants and sub-optimal
  vectorization mean this function is roughly 5x slower than zeroing out the
  upper triangle, i.e., `tf.matrix_band_part(X, -1, 0)`.  This function becomes
  competitive only when several matmul/cholesky/etc. ops can be elided in
  constructing the input.  Example: wiring a fully connected layer as a
  covariance matrix; this function reduces the final layer by 2x and possibly
  reduces the network architecture complexity considerably.  In most cases it
  is better to simply build a full matrix and zero out the upper triangular
  elements, e.g., `tril = tf.matrix_band_part(full, -1, 0)`, rather than
  directly construct a lower triangular matrix.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `distribution_util_test.py`, function `_fill_lower_triangular`.

  Args:
    x: `Tensor` representing lower triangular elements.
    validate_args: `Boolean`, default `False`.  Whether to ensure the shape of
      `x` can be mapped to a lower triangular matrix (controls non-static checks
      only).
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.

  Raises:
    ValueError: if `x` has a static shape which cannot be mapped to a
      lower triangular matrix.
  """
    # TODO(jvdillon): Replace this code with dedicated op when it exists.
    with ops.name_scope(name, values=(x, )):
        x = ops.convert_to_tensor(x, name="x")
        if (x.get_shape().ndims is not None
                and x.get_shape()[-1].value is not None):
            d = x.get_shape()[-1].value
            # d = n(n+1)/2 implies n is:
            n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
            d_inferred = n * (n + 1) / 2
            if d != d_inferred:
                raise ValueError(
                    "Input cannot be mapped to a lower triangular; "
                    "n*(n+1)/2 = %d != %d" % (d_inferred, d))
            final_shape = x.get_shape()[:-1].concatenate(
                tensor_shape.TensorShape([n, n]))
        else:
            d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
            # d = n(n+1)/2 implies n is:
            n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
                              dtype=dtypes.int32)
            if validate_args:
                is_valid_input_shape = check_ops.assert_equal(
                    n * (n + 1) / 2,
                    d,
                    message="Input cannot be mapped to a lower triangular.")
                n = control_flow_ops.with_dependencies([is_valid_input_shape],
                                                       n)
            final_shape = x.get_shape()[:-1].concatenate(
                tensor_shape.TensorShape([None, None]))

        def tril_ids(n):
            """Internal helper to create vector of linear indices into y."""
            # Build the ids statically; chose 512 because it implies 1MiB.
            if not contrib_framework.is_tensor(n) and n <= 512:
                ids = np.arange(n**2, dtype=np.int32)
                rows = (ids / n).astype(np.int32)  # Implicit floor.
                # We need to stop incrementing the index when we encounter
                # upper-triangular elements.  The idea here is to compute the
                # lower-right number of zeros then by "symmetry" subtract this from the
                # total number of zeros, n(n-1)/2.
                # Then we note that: n(n-1)/2 - (n-r)*(n-r-1)/2 = r(2n-r-1)/2
                offset = (rows * (2 * n - rows - 1) / 2).astype(np.int32)
                # We could also zero out when (rows < cols) == (rows < ids-n*rows).
                # mask = (ids <= (n + 1) * rows).astype(np.int32)
            else:
                ids = math_ops.range(n**2)
                rows = math_ops.cast(ids / n, dtype=dtypes.int32)
                offset = math_ops.cast(rows * (2 * n - rows - 1) / 2,
                                       dtype=dtypes.int32)
            return ids - offset

        # Special-case non-batch case.
        if x.get_shape().ndims == 1:
            y = array_ops.gather(x, array_ops.reshape(tril_ids(n), [n, n]))
            y = array_ops.matrix_band_part(y, -1, 0)
            y.set_shape(y.get_shape().merge_with(final_shape))
            return y

        # Make ids for each batch dim.
        if (x.get_shape().ndims is not None
                and x.get_shape()[:-1].is_fully_defined()):
            batch_shape = np.asarray(x.get_shape()[:-1].as_list(),
                                     dtype=np.int32)
            m = np.prod(batch_shape).astype(np.int32)
        else:
            batch_shape = array_ops.shape(x)[:-1]
            m = math_ops.reduce_prod(array_ops.shape(x)[:-1])
        batch_ids = math_ops.range(m)

        # Assemble the tril_ids into batch,tril_id pairs.
        idx = array_ops.pack([
            array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n * n]),
            array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1])
        ])
        idx = array_ops.transpose(idx, [1, 2, 0])

        # Gather up, reshape, and return.
        y = array_ops.reshape(x, [-1, d])
        y = array_ops.gather_nd(y, idx)
        y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
        y = array_ops.matrix_band_part(y, -1, 0)
        y.set_shape(y.get_shape().merge_with(final_shape))
        return y
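The `tril_ids` trick maps a flat [n*n] grid of indices onto the d = n(n+1)/2 packed inputs by subtracting, for each row, the number of upper-triangular slots skipped so far; a plain gather plus `matrix_band_part` then fills the triangle. A NumPy re-derivation for n = 3 that reproduces the docstring example:

import numpy as np

n = 3
ids = np.arange(n * n)
rows = ids // n
offset = rows * (2 * n - rows - 1) // 2  # upper-triangular slots skipped per row
tril_ids = (ids - offset).reshape(n, n)
x = np.array([1, 2, 3, 4, 5, 6])
print(np.tril(x[tril_ids]))  # zero the upper triangle, like matrix_band_part(y, -1, 0)
# [[1 0 0]
#  [2 3 0]
#  [4 5 6]]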
Example #56
 def _done(t):
   # Note that we don't use tf.control_dependencies since that will not make
   # sure that the computation on GPU has actually finished. So we fetch the
   # first element of the output, and assume that this will not be called on
   # empty tensors.
   return array_ops.gather(array_ops.reshape(t, [-1]), 0)
Example #57
def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
    """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect
      to each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index of the last value in the op.inputs.
    dim_index: An integer index of the concat_dim or axis parameter in op.inputs.

  Returns:
    Tensors representing the partial gradients with respect to each input
    of the op.

  Raises:
    ValueError: if concat_dim/axis is not statically known.
  """
    def _CreateDenseMaskAndBegin(sizes, concat_dim):
        """Create variables for iteratively slicing a dense gradients tensor."""
        # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
        shape_of_shape = array_ops.shape(sizes[0])
        # Make a vector of length equal to the input's dimensions,
        # with 0's everywhere and 1 in the concat dim position.
        # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
        mask = array_ops.concat([
            array_ops.fill(array_ops.expand_dims(concat_dim, 0), 0), [1],
            array_ops.fill(shape_of_shape - concat_dim - 1, 0)
        ], 0)
        begin = array_ops.fill(shape_of_shape, 0)
        return mask, begin

    def _ExtractInputShapes(inputs):
        """Extract the shapes of a set of input tensors."""
        if not context.in_graph_mode():
            return array_ops.shape_n(inputs)
        sizes = []
        fully_known = True
        for x in inputs:
            input_shape = array_ops.shape(x)
            if not isinstance(input_shape,
                              ops.Tensor) or input_shape.op.type != "Const":
                fully_known = False
                break
            sizes.append(input_shape)

        if fully_known:
            return sizes
        else:
            return array_ops.shape_n(inputs)

    # Degenerate concatenation, just return grad.
    if len(op.inputs) == 2:
        return grad + [None] if end_value_index <= dim_index else [None] + grad

    concat_dim = op.inputs[dim_index]
    input_values = op.inputs[start_value_index:end_value_index]

    out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.in_eager_mode():
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (concat_dim._numpy().item(0) %
                            input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values,
                                                        non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX.  A speedup was seen in a large majority
      # of cases when switching implementations at N=16, but it is possible
      # that there will be a small number of performance regressions.
      # pylint: disable=protected-access
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops._concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
      # pylint: enable=protected-access
  elif isinstance(grad, ops.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced
      # accordingly. This is like the Tensor case, except that
      # shape(grad.values)[0] is not equal to shape(sizes[i])[0], since only
      # a subset of the dim-0 values are stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(ops.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices
      # gradients only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(size_concat_dim,
                                          dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            squeeze_dims=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(ops.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  return (out_grads + [None] if end_value_index <= dim_index
          else [None] + out_grads)
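For orientation: the dense (`ops.Tensor`) branch above splits the upstream gradient of a concat back along the concat axis, one piece per input, while the `IndexedSlices` branch performs the analogous split on the sparse indices instead. A minimal sketch of the dense behavior, using the public `tf.concat`/`tf.GradientTape` API rather than the internal modules above (variable names here are illustrative):

```python
import tensorflow as tf

a = tf.Variable([[1., 2.]])        # shape [1, 2]
b = tf.Variable([[3., 4., 5.]])    # shape [1, 3]
with tf.GradientTape() as tape:
  c = tf.concat([a, b], axis=1)    # shape [1, 5]
  # Weight each column so the upstream gradient is easy to recognize.
  loss = tf.reduce_sum(c * tf.constant([[1., 1., 1., 2., 2.]]))
da, db = tape.gradient(loss, [a, b])
print(da.numpy())  # [[1. 1.]]    -- the first 2 columns of dloss/dc
print(db.numpy())  # [[1. 2. 2.]] -- the remaining 3 columns
```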
Example #58
0
 def loop_fn(i):
     y = array_ops.gather(output, i)
     return gradient_ops.gradients(y, flat_inputs)
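This `loop_fn` is the per-output-element gradient pattern used by vectorized loops: gather one element of `output` at a time and differentiate it, producing one Jacobian row per iteration. A hedged plain-Python equivalent of that pattern (the function `f` and input `x` below are illustrative, not from the source):

```python
import tensorflow as tf

x = tf.constant([1., 2., 3.])

def f(x):
  return tf.stack([x[0] * x[1], x[1] * x[2]])  # output shape [2]

rows = []
for i in range(2):
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.gather(f(x), i)  # the i-th output element, a scalar
  rows.append(tape.gradient(y, x))
jacobian = tf.stack(rows)   # shape [2, 3], one row per output element
print(jacobian.numpy())     # [[2. 1. 0.]
                            #  [0. 3. 2.]]
```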
Example #59
0
 def _inverse(self, y):
     return array_ops.gather(y,
                             array_ops.invert_permutation(self.permutation),
                             axis=-1)
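The `_inverse` method above works because gathering with the inverted permutation undoes the forward gather. A minimal standalone sketch (the bijector context is assumed; `perm` and `x` are illustrative):

```python
import tensorflow as tf

perm = tf.constant([2, 0, 1])
x = tf.constant([10., 20., 30.])
y = tf.gather(x, perm, axis=-1)  # forward permute: [30. 10. 20.]
x_back = tf.gather(y, tf.math.invert_permutation(perm), axis=-1)
print(x_back.numpy())            # [10. 20. 30.]
```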
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """The reconstruct one or more matrices from their LU decomposition(s).

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
  with ops.name_scope(name or 'lu_reconstruct'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)

    shape = array_ops.shape(lower_upper)

    # Split the packed LU factor: L is the strictly lower triangle with a
    # unit diagonal, U the upper triangle including the diagonal.
    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = band_part(lower_upper, num_lower=0, num_upper=-1)
    x = math_ops.matmul(lower, upper)

    if (lower_upper.shape is None or lower_upper.shape.rank is None or
        lower_upper.shape.rank != 2):
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = math_ops.reduce_prod(shape[:-2])
      d = shape[-1]
      x = array_ops.reshape(x, [batch_size, d, d])
      perm = array_ops.reshape(perm, [batch_size, d])
      perm = map_fn.map_fn(array_ops.invert_permutation, perm)
      # Un-permute the rows of each batch member: gather_nd with stacked
      # (batch_index, row_index) pairs applies one permutation per matrix.
      batch_indices = array_ops.broadcast_to(
          math_ops.range(batch_size)[:, array_ops.newaxis], [batch_size, d])
      x = array_ops.gather_nd(x, array_ops.stack([batch_indices, perm],
                                                 axis=-1))
      x = array_ops.reshape(x, shape)
    else:
      x = array_ops.gather(x, array_ops.invert_permutation(perm))

    x.set_shape(lower_upper.shape)
    return x
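The batched branch of `lu_reconstruct` applies one inverse row permutation per batch member by pairing batch indices with permuted row indices and calling `gather_nd`. A hedged sketch of just that indexing trick, with illustrative shapes and the public tf API:

```python
import tensorflow as tf

x = tf.reshape(tf.range(8., dtype=tf.float32), [2, 2, 2])  # two 2x2 matrices
perm = tf.constant([[1, 0], [0, 1]])  # one row permutation per batch member
inv_perm = tf.map_fn(tf.math.invert_permutation, perm)
batch_idx = tf.broadcast_to(tf.range(2)[:, tf.newaxis], [2, 2])
idx = tf.stack([batch_idx, inv_perm], axis=-1)  # (batch, row) pairs, [2, 2, 2]
print(tf.gather_nd(x, idx).numpy())
# The first matrix's rows are swapped back; the second (identity perm) is
# unchanged.
```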