Example #1
  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    # Check the shape
    if len(shape) < 3 or len(shape) > 5:
      raise ValueError("The tensor to initialize must be at least "
                       "three-dimensional and at most five-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    # Generate a random matrix
    a = random_ops.random_normal([shape[-1], shape[-1]],
                                 dtype=dtype, seed=self.seed)
    # Compute the qr factorization
    q, r = linalg_ops.qr(a, full_matrices=False)
    # Make Q uniform
    d = array_ops.diag_part(r)
    q *= math_ops.sign(d)
    q = q[:shape[-2], :]
    q *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))
    if len(shape) == 3:
      weight = array_ops.scatter_nd([[(shape[0]-1)//2]],
                                    array_ops.expand_dims(q, 0), shape)
    elif len(shape) == 4:
      weight = array_ops.scatter_nd([[(shape[0]-1)//2, (shape[1]-1)//2]],
                                    array_ops.expand_dims(q, 0), shape)
    else:
      weight = array_ops.scatter_nd([[(shape[0]-1)//2, (shape[1]-1)//2,
                                      (shape[2]-1)//2]],
                                    array_ops.expand_dims(q, 0), shape)
    return weight
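The scatter_nd call above writes the orthogonal block q into the spatial center of an otherwise zero kernel. A minimal sketch of that placement with the public tf API and made-up shapes:

import tensorflow as tf

# Hypothetical 1-D conv kernel: length 3, 2 input channels, 4 output channels.
kernel_shape = [3, 2, 4]
q = tf.ones([2, 4])  # stands in for the orthogonal matrix computed above

center = (kernel_shape[0] - 1) // 2
weight = tf.scatter_nd(indices=[[center]],
                       updates=tf.expand_dims(q, 0),
                       shape=kernel_shape)
# weight[1] equals q; weight[0] and weight[2] are all zeros.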
  def testInvalidShape(self):
    # TODO(apassos) figure out how to unify these errors
    with self.assertRaises(errors.InvalidArgumentError
                           if context.executing_eagerly() else ValueError):
      array_ops.scatter_nd(indices=[0],  # this should be indices=[[0]]
                           updates=[0.0],
                           shape=[1])
  def testEmptyOutputShape1(self):
    indices = array_ops.zeros([2, 2, 2], dtypes.int32)
    updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    shape = constant_op.constant([0, 3, 2], dtypes.int32)

    with self.assertRaisesWithPredicateMatch(
        ValueError, "Indices and updates specified for empty output shape"):
      array_ops.scatter_nd(indices, updates, shape)
Example #4
def _ctc_state_trans(label_seq):
  """Compute CTC alignment model transition matrix.

  Args:
    label_seq: tensor of shape [batch_size, max_seq_length]

  Returns:
    tensor of shape [batch_size, states, states] with a state transition matrix
    computed for each sequence of the batch.
  """

  with ops.name_scope("ctc_state_trans"):
    label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
    batch_size = _get_dim(label_seq, 0)
    num_labels = _get_dim(label_seq, 1)

    num_label_states = num_labels + 1
    num_states = 2 * num_label_states

    label_states = math_ops.range(num_label_states)
    blank_states = label_states + num_label_states

    # Start state to first label.
    start_to_label = [[1, 0]]

    # Blank to label transitions.
    blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)

    # Label to blank transitions.
    label_to_blank = array_ops.stack([blank_states, label_states], 1)

    # Scatter transitions that don't depend on sequence.
    indices = array_ops.concat(
        [start_to_label, blank_to_label, label_to_blank], 0)
    values = array_ops.ones([_get_dim(indices, 0)])
    trans = array_ops.scatter_nd(
        indices, values, shape=[num_states, num_states])
    trans += linalg_ops.eye(num_states)  # Self-loops.

    # Label to label transitions. Disallow transitions between repeated labels
    # with no blank state in between.
    batch_idx = array_ops.zeros_like(label_states[2:])
    indices = array_ops.stack(
        [batch_idx, label_states[2:], label_states[1:-1]], 1)
    indices = array_ops.tile(
        array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
    batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
    indices += array_ops.expand_dims(batch_idx, 1)
    repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
    values = 1.0 - math_ops.cast(repeats, dtypes.float32)
    batched_shape = [batch_size, num_states, num_states]
    label_to_label = array_ops.scatter_nd(indices, values, batched_shape)

    return array_ops.expand_dims(trans, 0) + label_to_label
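The sequence-independent transitions above are easiest to see on a tiny alphabet. A sketch for two labels (num_label_states = 3, num_states = 6), with the index pairs written out by hand:

import tensorflow as tf

num_states = 6  # label states [0, 1, 2], blank states [3, 4, 5]
indices = tf.constant([[1, 0],                   # start state to first label
                       [1, 3], [2, 4],           # blank to next label
                       [3, 0], [4, 1], [5, 2]])  # label to its blank
values = tf.ones([6])
trans = tf.scatter_nd(indices, values, shape=[num_states, num_states])
trans += tf.eye(num_states)  # self-loops, as in the code above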
  def testEmptyOutputShape2(self):
    indices = array_ops.placeholder(dtypes.int32, shape=None)
    updates = array_ops.placeholder(dtypes.int32, shape=None)
    shape = constant_op.constant([0, 3, 2], dtypes.int32)

    with self.test_session():
      array_ops.scatter_nd(indices, updates, shape).eval(feed_dict={
          indices: np.zeros(
              [2, 2, 2], dtype=np.int32),
          updates: np.zeros(
              [2, 2, 2], dtype=np.int32)
      })
  def testRank3InvalidShape2(self):
    indices = array_ops.zeros([2, 2, 1], dtypes.int32)
    updates = array_ops.zeros([2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    with self.assertRaisesWithPredicateMatch(
        ValueError, "The inner \\d+ dimensions of output\\.shape="):
      array_ops.scatter_nd(indices, updates, shape)

    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    with self.assertRaisesWithPredicateMatch(
        ValueError, "The inner \\d+ dimensions of ref\\.shape="):
      state_ops.scatter_nd_update(ref, indices, updates)
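The shape rule this test exercises: with index depth K (the size of the last dimension of indices), updates must have shape indices.shape[:-1] + shape[K:]. A valid counterpart to the failing call, sketched under that rule:

import tensorflow as tf

indices = tf.zeros([2, 2, 1], tf.int32)     # K = 1: each index picks a [2, 2] slice
updates = tf.zeros([2, 2, 2, 2], tf.int32)  # indices.shape[:-1] + shape[1:]
out = tf.scatter_nd(indices, updates, shape=[2, 2, 2])
# The [2, 2]-shaped updates in the test omit the inner slice dimensions,
# which is why scatter_nd rejects them.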
Example #7
def collapse_repeated(labels, seq_length, name=None):
  """Merge repeated labels into single labels.

  Args:
    labels: Tensor of shape [batch, max value in seq_length]
    seq_length: Tensor of shape [batch], sequence length of each batch element.
    name: A name for this `Op`. Defaults to "collapse_repeated_labels".

  Returns:
    A tuple `(collapsed_labels, new_seq_length)` where

    collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated
    labels collapsed and padded to max_seq_length, e.g.:
    `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]`

    new_seq_length: int tensor of shape [batch] with new sequence lengths.
  """

  with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]):
    labels = ops.convert_to_tensor(labels, name="labels")
    seq_length = ops.convert_to_tensor(seq_length, name="seq_length")

    # Mask labels that don't equal previous label.
    label_mask = array_ops.concat([
        array_ops.ones_like(labels[:, :1], dtypes.bool),
        math_ops.not_equal(labels[:, 1:], labels[:, :-1])
    ],
                                  axis=1)

    # Filter labels that aren't in the original sequence.
    maxlen = _get_dim(labels, 1)
    seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
    label_mask = math_ops.logical_and(label_mask, seq_mask)

    # Count masks for new sequence lengths.
    new_seq_len = math_ops.reduce_sum(
        math_ops.cast(label_mask, dtypes.int32), axis=1)

    # Mask indexes based on sequence length mask.
    new_maxlen = math_ops.reduce_max(new_seq_len)
    idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)

    # Flatten everything and mask out labels to keep and sparse indices.
    flat_labels = array_ops.reshape(labels, [-1])
    flat_label_mask = array_ops.reshape(label_mask, [-1])
    flat_idx_mask = array_ops.reshape(idx_mask, [-1])
    idx = math_ops.range(_get_dim(flat_idx_mask, 0))

    # Scatter to flat shape.
    flat = array_ops.scatter_nd(
        indices=array_ops.expand_dims(
            array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
        updates=array_ops.boolean_mask(flat_labels, flat_label_mask),
        shape=array_ops.shape(flat_idx_mask))

    # Reshape back to square batch.
    batch_size = _get_dim(labels, 0)
    new_shape = [batch_size, new_maxlen]
    return (array_ops.reshape(flat, new_shape),
            math_ops.cast(new_seq_len, seq_length.dtype))
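For reference, a small usage sketch; it assumes the public entry point tf.nn.collapse_repeated, which recent TensorFlow releases expose for this function:

import tensorflow as tf

labels = tf.constant([[1, 1, 2, 2, 1],
                      [1, 2, 3, 4, 5]])
seq_length = tf.constant([5, 5])
collapsed, new_len = tf.nn.collapse_repeated(labels, seq_length)
# collapsed -> [[1, 2, 1, 0, 0], [1, 2, 3, 4, 5]], new_len -> [3, 5]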
  def testScatterNdRepeatedIndicesAdd(self):
    indices = array_ops.zeros([100000, 1], dtypes.int32)
    values = np.random.randn(100000)
    shape = [1]
    with self.test_session():
      val = array_ops.scatter_nd(indices, values, shape).eval()
    self.assertAllClose([np.sum(values)], val)
Example #9
def _state_to_olabel_unique(labels, num_labels, states, unique):
  """Sum state log probs to ilabel log probs using unique label indices."""

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]

  unique_y, unique_idx = unique
  mul_reduce = _sum_states(unique_idx, label_states)

  num_frames = states.shape[0]
  batch_size = states.shape[1]
  num_states = num_label_states - 1
  batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
  batch_state_major = array_ops.reshape(
      batch_state_major, [batch_size * num_states, num_frames])
  batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
  indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
  indices = array_ops.reshape(indices, [-1, 1])
  scatter = array_ops.scatter_nd(
      indices=indices,
      updates=batch_state_major,
      shape=[batch_size * num_labels, num_frames])
  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])
  scatter = array_ops.where(
      math_ops.equal(scatter, 0.0),
      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)),
      scatter)
  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
  label_olabels = label_olabels[:, :, 1:]

  blank_olabels = math_ops.reduce_logsumexp(
      blank_states, axis=2, keepdims=True)

  return array_ops.concat([blank_olabels, label_olabels], axis=-1)
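The batch-offset step above shifts each batch element's label ids by b * num_labels so that one flat scatter_nd covers the whole batch. A reduced sketch with made-up sizes:

import tensorflow as tf

batch_size, num_labels, num_frames = 2, 3, 4
unique_y = tf.constant([[1, 2], [0, 1]])  # per-batch label ids
batch_offset = tf.range(batch_size) * num_labels            # [0, 3]
indices = unique_y + tf.expand_dims(batch_offset, axis=-1)  # [[1, 2], [3, 4]]
indices = tf.reshape(indices, [-1, 1])
updates = tf.ones([4, num_frames])  # stands in for batch_state_major
scatter = tf.scatter_nd(indices, updates,
                        shape=[batch_size * num_labels, num_frames])
# Reshaping to [batch_size, num_labels, num_frames] recovers the batch.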
  def testEmptyOutputShape3(self):
    indices = array_ops.zeros([0], dtypes.int32)
    updates = array_ops.zeros([0], dtypes.int32)
    shape = constant_op.constant([0], dtypes.int32)
    scatter = array_ops.scatter_nd(indices, updates, shape)

    with self.test_session():
      self.assertEqual(scatter.eval().size, 0)
Example #11
  def maybe_sample():
    """Perform scheduled sampling."""
    where_sampling = math_ops.cast(
        array_ops.where(sample_ids > -1), dtypes.int32)
    where_not_sampling = math_ops.cast(
        array_ops.where(sample_ids <= -1), dtypes.int32)
    sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
    inputs_not_sampling = array_ops.gather_nd(
        base_next_inputs, where_not_sampling)
    sampled_next_inputs = self._embedding_fn(sample_ids_sampling)
    base_shape = array_ops.shape(base_next_inputs)
    return (array_ops.scatter_nd(indices=where_sampling,
                                 updates=sampled_next_inputs,
                                 shape=base_shape)
            + array_ops.scatter_nd(indices=where_not_sampling,
                                   updates=inputs_not_sampling,
                                   shape=base_shape))
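The gather/scatter pair above is a general pattern for replacing a masked subset of rows without any in-place update. A self-contained sketch with made-up inputs:

import tensorflow as tf

base = tf.constant([[1., 1.], [2., 2.], [3., 3.]])
replacement = tf.constant([[9., 9.]])  # new rows for the sampled positions
where_sampling = tf.cast(tf.where([True, False, False]), tf.int32)
where_not_sampling = tf.cast(tf.where([False, True, True]), tf.int32)
kept = tf.gather_nd(base, where_not_sampling)
merged = (tf.scatter_nd(where_sampling, replacement, tf.shape(base)) +
          tf.scatter_nd(where_not_sampling, kept, tf.shape(base)))
# merged -> [[9., 9.], [2., 2.], [3., 3.]]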
Example #12
  def _apply_sparse_shared(self, grad_values, grad_indices, var):
    if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0:
      # The dimension is small enough that we can make the variable dense
      # and do a dense update.
      dense_grad = array_ops.scatter_nd(
          array_ops.expand_dims(grad_indices, axis=1), grad_values,
          array_ops.shape(var, out_type=grad_indices.dtype))
      return self._apply_gradient(dense_grad, var)
    return self._apply_gradient(grad_values, var, grad_indices)
  def _runScatterNd(self, indices, updates, shape):
    with self.test_session():
      updates_placeholder = array_ops.placeholder(updates.dtype)
      indices_placeholder = array_ops.placeholder(indices.dtype)
      with self.test_scope():
        output = array_ops.scatter_nd(indices_placeholder, updates_placeholder,
                                      shape)
      feed_dict = {updates_placeholder: updates, indices_placeholder: indices}
      return output.eval(feed_dict=feed_dict)
Example #14
def _GatherNdGrad(op, grad):
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
                                 ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]
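This gradient stays correct when indices repeat, because scatter_nd sums updates that land on the same location, which is exactly the adjoint of gather_nd. A quick demonstration:

import tensorflow as tf

indices = tf.constant([[0], [0], [2]])
updates = tf.constant([1.0, 2.0, 3.0])
print(tf.scatter_nd(indices, updates, shape=[4]))
# -> [3. 0. 3. 0.]: the two updates for index 0 are summed.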
Example #15
      def maybe_sample():
        """Perform scheduled sampling."""
        if self._next_input_layer is None:
          return array_ops.where(sample_ids, outputs, base_next_inputs)

        where_sampling = math_ops.cast(
            array_ops.where(sample_ids), dtypes.int32)
        where_not_sampling = math_ops.cast(
            array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
        outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                  where_not_sampling)
        sampled_next_inputs = self._next_input_layer(outputs_sampling)
        base_shape = array_ops.shape(base_next_inputs)
        return (array_ops.scatter_nd(indices=where_sampling,
                                     updates=sampled_next_inputs,
                                     shape=base_shape)
                + array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape))
  def testGradientsRank2SliceUpdate(self):
    indices = constant_op.constant([[1], [0]], dtype=dtypes.int32)
    updates = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64)
    shape = constant_op.constant([2, 2], dtype=dtypes.int32)
    outputs = array_ops.scatter_nd(indices, updates, shape)

    grad_vals = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [updates], [grad_vals])[0]
    expected_grads = np.array([[1, 2], [3, 4]], dtype=np.float64)
    with self.test_session():
      self.assertAllEqual(expected_grads, grads.eval())
  def testRank3ValidShape(self):
    indices = array_ops.zeros([2, 2, 2], dtypes.int32)
    updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    self.assertAllEqual(
        array_ops.scatter_nd(indices, updates, shape).get_shape().as_list(),
        shape)

    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    self.assertAllEqual(
        state_ops.scatter_nd_update(ref, indices,
                                    updates).get_shape().as_list(), shape)
Example #18
  def scheduled_sampling(self, batch_size, sampling_probability, true, estimate):
    with variable_scope.variable_scope("ScheduledEmbedding"):
      # Return -1s where we do not sample, and sample_ids elsewhere
      select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
      select_sample = select_sampler.sample(sample_shape=batch_size)
      sample_ids = array_ops.where(
                  select_sample,
                  tf.range(batch_size),
                  gen_array_ops.fill([batch_size], -1))
      where_sampling = math_ops.cast(
          array_ops.where(sample_ids > -1), tf.int32)
      where_not_sampling = math_ops.cast(
          array_ops.where(sample_ids <= -1), tf.int32)
      _estimate = array_ops.gather_nd(estimate, where_sampling)
      _true = array_ops.gather_nd(true, where_not_sampling)

      base_shape = array_ops.shape(true)
      result1 = array_ops.scatter_nd(indices=where_sampling, updates=_estimate, shape=base_shape)
      result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=_true, shape=base_shape)
      return result1 + result2
Example #19
      def maybe_sample():
        """Perform scheduled sampling."""

        def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
          """Concatenate outputs with auxiliary inputs, if they exist."""
          if self._auxiliary_input_tas is None:
            return outputs_

          next_time = time + 1
          auxiliary_inputs = nest.map_structure(
              lambda ta: ta.read(next_time), self._auxiliary_input_tas)
          if indices is not None:
            auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices)
          return nest.map_structure(
              lambda x, y: array_ops.concat((x, y), -1),
              outputs_, auxiliary_inputs)

        if self._next_input_layer is None:
          return array_ops.where(
              sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
              base_next_inputs)

        where_sampling = math_ops.cast(
            array_ops.where(sample_ids), dtypes.int32)
        where_not_sampling = math_ops.cast(
            array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
        outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                  where_not_sampling)
        sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
            self._next_input_layer(outputs_sampling), where_sampling)

        base_shape = array_ops.shape(base_next_inputs)
        return (array_ops.scatter_nd(indices=where_sampling,
                                     updates=sampled_next_inputs,
                                     shape=base_shape)
                + array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape))
  def testExtraIndicesDimensions(self):
    indices = array_ops.zeros([1, 1, 2], dtypes.int32)
    updates = array_ops.zeros([1, 1], dtypes.int32)
    shape = np.array([2, 2])
    scatter = array_ops.scatter_nd(indices, updates, shape)
    self.assertAllEqual(scatter.get_shape().as_list(), shape)
    expected_result = np.zeros([2, 2], dtype=np.int32)
    with self.test_session():
      self.assertAllEqual(expected_result, scatter.eval())

    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    scatter_update = state_ops.scatter_nd_update(ref, indices, updates)
    self.assertAllEqual(scatter_update.get_shape().as_list(), shape)

    with self.test_session():
      ref.initializer.run()
      self.assertAllEqual(expected_result, scatter_update.eval())
Example #21
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input of
    TopK, and the second being the gradient w.r.t. the indices (all zero).
  """
  in_shape = array_ops.shape(op.inputs[0])
  ind_shape = array_ops.shape(op.outputs[1])

  # int32 is not supported on GPU, hence the up-cast to int64.
  ind_lastdim = array_ops.gather(
      math_ops.cast(ind_shape, dtypes.int64),
      array_ops.size(ind_shape) - 1)
  # Flatten indices to 2D.
  ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))

  in_lastdim = array_ops.gather(
      math_ops.cast(in_shape, dtypes.int64),
      array_ops.size(in_shape) - 1)
  outerdim = array_ops.shape(ind_2d)[0]
  # Compute linear indices (flattened to 1D).
  ind = array_ops.reshape(
      ind_2d + math_ops.cast(
          array_ops.expand_dims(
              math_ops.range(0,
                             math_ops.cast(outerdim, dtypes.int64) * in_lastdim,
                             in_lastdim), -1), dtypes.int32), [-1])

  # Scatter grad into the appropriate locations, fill the rest with zeros,
  # and reshape the result to the original input shape.
  return [
      array_ops.reshape(
          array_ops.scatter_nd(
              array_ops.expand_dims(ind, -1), array_ops.reshape(grad, [-1]),
              [math_ops.reduce_prod(in_shape)]), in_shape),
      array_ops.zeros([], dtype=dtypes.int32)
  ]
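The linearization step is the subtle part: per-row top-k indices are offset into one flat index space so a single 1-D scatter_nd can route the gradient. A sketch on a 2x4 input with k = 2 and made-up values:

import tensorflow as tf

ind_2d = tf.constant([[3, 1], [0, 2]], dtype=tf.int64)  # per-row top-k indices
outerdim, in_lastdim = 2, 4
row_offsets = tf.expand_dims(
    tf.range(0, outerdim * in_lastdim, in_lastdim, dtype=tf.int64), -1)
ind = tf.reshape(ind_2d + row_offsets, [-1])  # [3, 1, 4, 6]
grad = tf.constant([10., 20., 30., 40.])
flat = tf.scatter_nd(tf.expand_dims(ind, -1), grad, [8])
# tf.reshape(flat, [2, 4]) -> [[0., 20., 0., 10.], [30., 0., 40., 0.]]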
    def ScatterUpdateGrads(op, grad):
        var, indices, updates = op.inputs

        updates_grad = array_ops.gather(grad, indices)

        # dynamic stitch approach (this seems to be a bit slower)
        # grad_range = math_ops.range(grad.get_shape()[0].value)
        # var_grad = data_flow_ops.dynamic_stitch(
        #     [grad_range, indices],
        #     [grad, array_ops.zeros(updates.get_shape())])

        if isinstance(grad, ops.IndexedSlices):
            # note: we could use this approach for everything, but the
            # temporary variable approach seems to be slightly faster (but we
            # can't use that on IndexedSlices)
            var_grad = grad - array_ops.scatter_nd(
                array_ops.expand_dims(indices, 1), updates_grad,
                var.get_shape())
        else:
            shape = tuple(grad.get_shape().as_list())
            dtype = grad.dtype.base_dtype

            with variable_scope.variable_scope(
                    "gradient_vars", reuse=variable_scope.AUTO_REUSE):
                var_grad = variable_scope.get_variable(
                    "tmp" + "_%s" * (len(grad.get_shape()) + 1) % (
                        shape + (dtype.name,)),
                    shape=shape, dtype=dtype, trainable=False,
                    collections=["gradient_vars"])

            var_grad = state_ops.assign(var_grad, grad)
            var_grad = state_ops.scatter_update(
                var_grad, indices, array_ops.zeros_like(updates))

            # we need to force a copy so that any future assignments to the
            # variable will not affect the value we return here
            # TODO: check if this is still necessary in TensorFlow 2.0
            var_grad = var_grad + 0

        return var_grad, None, updates_grad
  def testUndefinedOutputShape(self):
    indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
    updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
    shape = array_ops.placeholder(dtypes.int32, shape=[None])
    array_ops.scatter_nd(indices, updates, shape)
Example #24
def _GatherNdGrad(op, grad):
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]
Example #25
def scheduled_sampling_vocab_dist(hps,
                                  sampling_probability,
                                  output,
                                  embedding,
                                  inp,
                                  alpha=0):
    # borrowed ideas from https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper

    def soft_argmax(alpha, output):
        # alpha_exp = tf.exp(alpha * output) # (batch_size, vocab_size)
        # one_hot_scores = alpha_exp / tf.reshape(tf.reduce_sum(alpha_exp, axis=1),[-1,1]) #(batch_size, vocab_size)
        one_hot_scores = tf.nn.softmax(alpha * output)
        return one_hot_scores

    def soft_top_k(alpha, output, K):
        copy = tf.identity(output)
        p = []
        arg_top_k = []
        for k in range(K):
            sargmax = soft_argmax(alpha, copy)
            copy = (1 - sargmax) * copy
            p.append(tf.reduce_sum(sargmax * output, axis=1))
            arg_top_k.append(sargmax)

        return tf.stack(p, axis=1), tf.stack(arg_top_k)

    with variable_scope.variable_scope("ScheduledEmbedding"):
        # Return -1s where we did not sample, and sample_ids elsewhere
        select_sampler = bernoulli.Bernoulli(probs=sampling_probability,
                                             dtype=tf.bool)
        select_sample = select_sampler.sample(sample_shape=hps.batch_size)
        sample_id_sampler = categorical.Categorical(
            probs=output
        )  # equivalent to argmax{ Multinomial(output, total_count=1) }, our greedy search selection
        sample_ids = array_ops.where(select_sample,
                                     sample_id_sampler.sample(seed=123),
                                     gen_array_ops.fill([hps.batch_size], -1))

        where_sampling = math_ops.cast(array_ops.where(sample_ids > -1),
                                       tf.int32)
        where_not_sampling = math_ops.cast(array_ops.where(sample_ids <= -1),
                                           tf.int32)

        if hps.greedy_scheduled_sampling:
            sample_ids = tf.argmax(output, axis=1, output_type=tf.int32)

        sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(inp, where_not_sampling)

        if hps.E2EBackProp:
            if hps.hard_argmax:
                greedy_search_prob, greedy_search_sample = tf.nn.top_k(
                    output, k=hps.k)  # (batch_size, k)
                greedy_search_prob_normalized = greedy_search_prob / tf.reshape(
                    tf.reduce_sum(greedy_search_prob, axis=1), [-1, 1])
                greedy_embedding = tf.nn.embedding_lookup(
                    embedding, greedy_search_sample)
                normalized_embedding = tf.multiply(
                    tf.reshape(greedy_search_prob_normalized,
                               [hps.batch_size, hps.k, 1]), greedy_embedding)
                e2e_embedding = tf.reduce_mean(normalized_embedding, axis=1)
            else:
                e = []
                greedy_search_prob, greedy_search_sample = soft_top_k(
                    alpha, output,
                    K=hps.k)  # (batch_size, k), (k, batch_size, vocab_size)
                greedy_search_prob_normalized = greedy_search_prob / tf.reshape(
                    tf.reduce_sum(greedy_search_prob, axis=1), [-1, 1])

                for _ in range(hps.k):
                    a_k = greedy_search_sample[_]
                    e_k = tf.matmul(
                        tf.reshape(greedy_search_prob_normalized[:, _],
                                   [-1, 1]) * a_k, embedding)
                    e.append(e_k)
                e2e_embedding = tf.reduce_sum(e,
                                              axis=0)  # (batch_size, emb_dim)
            sampled_next_inputs = array_ops.gather_nd(e2e_embedding,
                                                      where_sampling)
        else:
            if hps.hard_argmax:
                sampled_next_inputs = tf.nn.embedding_lookup(
                    embedding, sample_ids_sampling)
            else:  # using soft argmax (greedy) proposed in: https://arxiv.org/abs/1704.06970
                # alpha_exp = tf.exp(alpha * (output_not_extended + G)) # (batch_size, vocab_size)
                # one_hot_scores = alpha_exp / tf.reduce_sum(alpha_exp, axis=1) #(batch_size, vocab_size)
                one_hot_scores = soft_argmax(
                    alpha, output)  # (batch_size, vocab_size)
                soft_argmax_embedding = tf.matmul(
                    one_hot_scores, embedding)  # (batch_size, emb_size)
                sampled_next_inputs = array_ops.gather_nd(
                    soft_argmax_embedding, where_sampling)

        base_shape = array_ops.shape(inp)
        result1 = array_ops.scatter_nd(indices=where_sampling,
                                       updates=sampled_next_inputs,
                                       shape=base_shape)
        result2 = array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape)
        return result1 + result2
  def scatter_nd(self, indices, updates, shape, input_=None):
    del input_  # input_ is not used in scatter_nd
    return array_ops.scatter_nd(indices, updates, shape)
Example #27
def _GatherNdGrad(op, grad):
    ref = op.inputs[0]
    indices = op.inputs[1]
    ref_shape = array_ops.shape(ref, out_type=indices.dtype)
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
    return [ref_grad, None]
Example #28
    def _get_update_op(self,
                       score_drop,
                       score_grow,
                       mask,
                       weights,
                       reinit_when_same=False):
        """Prunes+grows connections, all tensors same shape."""
        old_dtype = mask.dtype
        mask_casted = math_ops.cast(mask, dtypes.float32)
        n_total = array_ops.size(score_drop)
        n_ones = math_ops.cast(math_ops.reduce_sum(mask_casted),
                               dtype=dtypes.int32)
        n_prune = math_ops.cast(
            math_ops.cast(n_ones, dtype=dtypes.float32) * self.drop_fraction,
            dtypes.int32)
        n_keep = n_ones - n_prune

        # Sort the entire array since the k needs to be constant for TPU.
        _, sorted_indices = nn_ops.top_k(array_ops.reshape(score_drop, [-1]),
                                         k=n_total)
        sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
        # We will have zeros after having `n_keep` many ones.
        new_values = array_ops.where(
            math_ops.range(n_total) < n_keep,
            array_ops.ones_like(sorted_indices, dtype=mask_casted.dtype),
            array_ops.zeros_like(sorted_indices, dtype=mask_casted.dtype))
        mask1 = array_ops.scatter_nd(sorted_indices_ex, new_values,
                                     new_values.shape)
        # Flatten the scores
        score_grow = array_ops.reshape(score_grow, [-1])
        # Set scores of the enabled connections (ones) to min(s) - 1, so that they
        # have the lowest scores.
        score_grow_lifted = array_ops.where(
            math_ops.equal(mask1, 1),
            array_ops.ones_like(mask1) * (math_ops.reduce_min(score_grow) - 1),
            score_grow)
        _, sorted_indices = nn_ops.top_k(score_grow_lifted, k=n_total)
        sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
        new_values = array_ops.where(
            math_ops.range(n_total) < n_prune,
            array_ops.ones_like(sorted_indices, dtype=mask_casted.dtype),
            array_ops.zeros_like(sorted_indices, dtype=mask_casted.dtype))
        mask2 = array_ops.scatter_nd(sorted_indices_ex, new_values,
                                     new_values.shape)
        # Ensure masks are disjoint.
        assert_op = control_flow_ops.Assert(
            math_ops.equal(math_ops.reduce_sum(mask1 * mask2), 0.),
            [mask1, mask2])

        with ops.control_dependencies([assert_op]):
            # Let's set the weights of the grown connections.
            mask2_reshaped = array_ops.reshape(mask2, mask.shape)
        # Set the values of the new connections.
        grow_tensor = self.get_grow_tensor(weights, self._grow_init)
        if reinit_when_same:
            # If dropped and grown, we re-initialize.
            new_connections = math_ops.equal(mask2_reshaped, 1)
        else:
            new_connections = math_ops.logical_and(
                math_ops.equal(mask2_reshaped, 1),
                math_ops.equal(mask_casted, 0))
        new_weights = array_ops.where(new_connections, grow_tensor, weights)
        weights_update = state_ops.assign(weights, new_weights)
        # Ensure there is no momentum value for new connections
        reset_op = self.reset_momentum(weights, new_connections)

        with ops.control_dependencies([weights_update, reset_op]):
            mask_combined = array_ops.reshape(mask1 + mask2, mask.shape)
        mask_combined = math_ops.cast(mask_combined, dtype=old_dtype)
        new_mask = state_ops.assign(mask, mask_combined)
        return new_mask
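Both masks above are built the same way: sort every score once (k must be static for TPU), then scatter ones into the first n positions of the sorted order. A toy version of that step:

import tensorflow as tf

scores = tf.constant([0.1, 0.9, 0.4, 0.7])
n_total, n_keep = 4, 2
_, sorted_indices = tf.nn.top_k(scores, k=n_total)
new_values = tf.where(tf.range(n_total) < n_keep,
                      tf.ones_like(sorted_indices, dtype=tf.float32),
                      tf.zeros_like(sorted_indices, dtype=tf.float32))
mask = tf.scatter_nd(tf.expand_dims(sorted_indices, 1), new_values, [n_total])
# mask -> [0., 1., 0., 1.]: keeps the two highest-scoring connections.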
Example #29
def scheduled_sampling(hps,
                       sampling_probability,
                       output,
                       embedding,
                       inp,
                       alpha=0):
    # borrowed ideas from https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper
    vocab_size = embedding.get_shape()[0].value

    def soft_argmax(alpha, output):
        alpha_exp = tf.exp(alpha * output)  # (batch_size, vocab_size)
        one_hot_scores = alpha_exp / tf.reshape(
            tf.reduce_sum(alpha_exp, axis=1),
            [-1, 1])  #(batch_size, vocab_size)
        return one_hot_scores

    def soft_top_k(alpha, output, K):
        copy = tf.identity(output)
        p = []
        arg_top_k = []
        for k in range(K):
            sargmax = soft_argmax(alpha, copy)
            copy = (1 - sargmax) * copy
            p.append(tf.reduce_sum(sargmax * output, axis=1))
            # replace oov with unk if necessary
            mask = tf.equal(tf.reduce_max(sargmax, axis=1),
                            tf.reduce_max(sargmax[:, 0:vocab_size], axis=1))
            sargmax_truncated = tf.where(
                mask, sargmax[:, 0:vocab_size],
                tf.stack([
                    tf.one_hot(0, vocab_size) for _ in range(hps.batch_size)
                ]))

            arg_top_k.append(sargmax_truncated)

        return p, tf.stack(arg_top_k)

    with variable_scope.variable_scope("ScheduledEmbedding"):
        # Return -1s where we did not sample, and sample_ids elsewhere
        select_sampler = bernoulli.Bernoulli(probs=sampling_probability,
                                             dtype=tf.bool)
        select_sample = select_sampler.sample(sample_shape=hps.batch_size)
        sample_id_sampler = categorical.Categorical(
            probs=output
        )  # equivalent to argmax{ Multinomial(output, total_count=1) }, our greedy search selection
        sample_ids = array_ops.where(select_sample,
                                     sample_id_sampler.sample(seed=123),
                                     gen_array_ops.fill([hps.batch_size], -1))

        where_sampling = math_ops.cast(array_ops.where(sample_ids > -1),
                                       tf.int32)
        where_not_sampling = math_ops.cast(array_ops.where(sample_ids <= -1),
                                           tf.int32)

        if hps.greedy_scheduled_sampling:
            sample_ids = tf.argmax(output, axis=1, output_type=tf.int32)

        sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)

        cond = tf.less(sample_ids_sampling, vocab_size)  # replace oov with unk
        sample_ids_sampling = tf.cast(cond, tf.int32) * sample_ids_sampling
        inputs_not_sampling = array_ops.gather_nd(inp, where_not_sampling)

        if hps.E2EBackProp:
            if hps.hard_argmax:
                greedy_search_prob, greedy_search_sample = tf.nn.top_k(
                    output, k=hps.k)  # (batch_size, k)
                greedy_search_prob_normalized = greedy_search_prob / tf.reshape(
                    tf.reduce_sum(greedy_search_prob, axis=1), [-1, 1])

                cond = tf.less(greedy_search_sample,
                               vocab_size)  # replace oov with unk
                greedy_search_sample = tf.cast(cond,
                                               tf.int32) * greedy_search_sample

                greedy_embedding = tf.nn.embedding_lookup(
                    embedding, greedy_search_sample)
                normalized_embedding = tf.multiply(
                    tf.reshape(greedy_search_prob_normalized,
                               [hps.batch_size, hps.k, 1]), greedy_embedding)
                e2e_embedding = tf.reduce_mean(normalized_embedding, axis=1)
            else:
                e = []
                greedy_search_prob, greedy_search_sample = soft_top_k(
                    alpha, output, K=hps.k)  # (batch_size, k), (k, vocab_size)
                greedy_search_prob_normalized = greedy_search_prob / tf.reshape(
                    tf.reduce_sum(greedy_search_prob, axis=1), [-1, 1])

                for _ in range(hps.k):
                    a_k = greedy_search_sample[_]
                    e_k = tf.matmul(
                        tf.reshape(greedy_search_prob_normalized[:, _],
                                   [-1, 1]) * a_k, embedding)
                    e.append(e_k)
                e2e_embedding = tf.reduce_sum(e,
                                              axis=0)  # (batch_size, emb_dim)
            sampled_next_inputs = array_ops.gather_nd(e2e_embedding,
                                                      where_sampling)
        else:
            if hps.hard_argmax:
                sampled_next_inputs = tf.nn.embedding_lookup(
                    embedding, sample_ids_sampling)
            else:  # using soft argmax (greedy) proposed in: https://arxiv.org/abs/1704.06970
                if not hps.greedy_scheduled_sampling:
                    # Gumbel reparametrization trick: https://arxiv.org/abs/1704.06970
                    U = tf.random_uniform(
                        (hps.batch_size, vocab_size), 10e-12,
                        (1 - 10e-12))  # add a small number to avoid log(0)
                    G = -tf.log(-tf.log(U))
                else:
                    G = tf.zeros((hps.batch_size, vocab_size))
                #alpha_exp = tf.exp(alpha * (output_not_extended + G)) # (batch_size, vocab_size)
                #one_hot_scores = alpha_exp / tf.reduce_sum(alpha_exp, axis=1) #(batch_size, vocab_size)
                one_hot_scores = soft_argmax(
                    alpha, (output + G))  #(batch_size, vocab_size)
                sampled_next_inputs = tf.matmul(
                    one_hot_scores, embedding)  #(batch_size, emb_size)

        base_shape = array_ops.shape(inp)
        result1 = array_ops.scatter_nd(indices=where_sampling,
                                       updates=sampled_next_inputs,
                                       shape=base_shape)
        result2 = array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape)
        return result1 + result2
Example #30
    def _groupwise_dnn_v2(features, labels, mode, params, config):
        """Defines the dnn for groupwise scoring functions."""
        with ops.name_scope('transform'):
            context_features, per_example_features = _call_transform_fn(
                features, mode)

        def _score_fn(context_features, group_features, reuse):
            with variable_scope.variable_scope('group_score', reuse=reuse):
                return group_score_fn(context_features, group_features, mode,
                                      params, config)

        # Scatter/Gather per-example scores through groupwise comparison. Each
        # instance in a mini-batch will form a number of groups. Each group of
        # examples is scored by `_score_fn`, and scores for individual examples
        # are accumulated over groups.
        with ops.name_scope('groupwise_dnn_v2'):
            with ops.name_scope('infer_sizes'):
                if labels is not None:
                    batch_size, list_size = array_ops.unstack(
                        array_ops.shape(labels))
                    is_valid = utils.is_label_valid(labels)
                else:
                    # Infer batch_size and list_size from a feature.
                    example_tensor_shape = array_ops.shape(
                        next(six.itervalues(per_example_features)))
                    batch_size = example_tensor_shape[0]
                    list_size = example_tensor_shape[1]
                    is_valid = utils.is_label_valid(
                        array_ops.ones([batch_size, list_size]))
            if batch_size is None or list_size is None:
                raise ValueError('Invalid batch_size=%s or list_size=%s' %
                                 (batch_size, list_size))

            # For each example feature, assume the shape is [batch_size, list_size,
            # feature_size], the groups are formed along the 2nd dim. Each group has a
            # 'group_size' number of indices in [0, list_size). Based on these
            # indices, we can gather the example feature into a sub-tensor for each
            # group. The total number of groups we have for a mini-batch is batch_size
            # * num_groups. Inside each group, we have a 'group_size' number of
            # examples.
            indices, mask = _form_group_indices_nd(
                is_valid,
                group_size,
                shuffle=(mode != model_fn.ModeKeys.PREDICT))
            num_groups = array_ops.shape(mask)[1]

            with ops.name_scope('group_features'):
                # For context features, we have shape [batch_size * num_groups, ...].
                large_batch_context_features = {}
                for name, value in six.iteritems(context_features):
                    # [batch_size, 1, ...].
                    value = array_ops.expand_dims(value, axis=1)
                    # [batch_size, num_groups, ...].
                    value = array_ops.gather(value,
                                             array_ops.zeros([num_groups],
                                                             dtypes.int32),
                                             axis=1)
                    # [batch_size * num_groups, ...]
                    large_batch_context_features[
                        name] = utils.reshape_first_ndims(
                            value, 2, [batch_size * num_groups])

                # For example feature, we have shape [batch_size * num_groups,
                # group_size, ...].
                large_batch_group_features = {}
                for name, value in six.iteritems(per_example_features):
                    # [batch_size, num_groups, group_size, ...].
                    value = array_ops.gather_nd(value, indices)
                    # [batch_size * num_groups, group_size, ...].
                    large_batch_group_features[
                        name] = utils.reshape_first_ndims(
                            value, 3, [batch_size * num_groups, group_size])

            # Do the inference and get scores for the large batch.
            # [batch_size * num_groups, group_size].
            scores = _score_fn(large_batch_context_features,
                               large_batch_group_features,
                               reuse=False)

            with ops.name_scope('accumulate_scores'):
                scores = array_ops.reshape(
                    scores, [batch_size, num_groups, group_size])
                # Reset invalid scores to 0 based on mask.
                scores = array_ops.where(
                    array_ops.gather(array_ops.expand_dims(mask, 2),
                                     array_ops.zeros([group_size],
                                                     dtypes.int32),
                                     axis=2), scores,
                    array_ops.zeros_like(scores))
                # [batch_size, list_size].
                list_scores = array_ops.scatter_nd(indices, scores,
                                                   [batch_size, list_size])
                # Use average.
                list_scores /= math_ops.to_float(group_size)

        if mode == model_fn.ModeKeys.PREDICT:
            return list_scores
        else:
            features.update(context_features)
            features.update(per_example_features)
            return list_scores
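Because each (batch, position) pair can appear in several groups, the scatter_nd in 'accumulate_scores' sums per-group scores before the division by group_size averages them. A reduced illustration:

import tensorflow as tf

# Two groups of size 2 in a single-example batch; position 1 is in both.
indices = tf.constant([[[0, 0], [0, 1]],
                       [[0, 1], [0, 2]]])
scores = tf.constant([[1.0, 2.0], [3.0, 4.0]])
list_scores = tf.scatter_nd(indices, scores, shape=[1, 3])
# -> [[1., 5., 4.]]: position 1 accumulates scores from both groups.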
Example #31
 def scatter_nd(self, indices, updates, shape, input_=None):
   del input_  # input_ is not used in scatter_nd
   return array_ops.scatter_nd(indices, updates, shape)
Example #32
  def _replicate_rows(tensor, multiple):
    tensor_shape = tensor.shape.as_list()
    expanded_shape = [tensor_shape[0] * multiple, tensor_shape[1]]
    indices = constant_op.constant(_generate_indices(tensor_shape[0], multiple))
    return array_ops.scatter_nd(indices, _tile_rows(tensor, multiple),
                                expanded_shape)
  def testUndefinedUpdatesShape(self):
    indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
    updates = array_ops.placeholder(dtypes.int32, shape=None)
    shape = constant_op.constant([2, 2, 2], dtypes.int32)
    array_ops.scatter_nd(indices, updates, shape)
Example #34
def collapse_repeated(labels, seq_length, name=None):
    """Merge repeated labels into single labels.

  Args:
    labels: Tensor of shape (batch, max value in seq_length)
    seq_length: Tensor of shape (batch), sequence length of each batch element.
    name: A name for this `Op`. Defaults to "collapse_repeated_labels".

  Returns:
    tuple of Tensor of shape (batch, max_seq_length) with repeated labels
    collapsed and padded to max_seq_length, e.g.:
        [[A, A, B, B, A],
         [A, B, C, D, E]] => [[A, B, A, 0, 0],
                              [A, B, C, D, E]]
    and int tensor of shape [batch] with new sequence lengths.
  """

    with ops.name_scope(name, "collapse_repeated_labels",
                        [labels, seq_length]):
        labels = ops.convert_to_tensor(labels, name="labels")
        seq_length = ops.convert_to_tensor(seq_length, name="seq_length")

        # Mask labels that don't equal previous label.
        label_mask = array_ops.concat([
            array_ops.ones_like(labels[:, :1], dtypes.bool),
            math_ops.not_equal(labels[:, 1:], labels[:, :-1])
        ],
                                      axis=1)

        # Filter labels that aren't in the original sequence.
        maxlen = _get_dim(labels, 1)
        seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
        label_mask = math_ops.logical_and(label_mask, seq_mask)

        # Count masks for new sequence lengths.
        new_seq_len = math_ops.reduce_sum(math_ops.cast(
            label_mask, dtypes.int32),
                                          axis=1)

        # Mask indexes based on sequence length mask.
        new_maxlen = math_ops.reduce_max(new_seq_len)
        idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)

        # Flatten everything and mask out labels to keep and sparse indices.
        flat_labels = array_ops.reshape(labels, [-1])
        flat_label_mask = array_ops.reshape(label_mask, [-1])
        flat_idx_mask = array_ops.reshape(idx_mask, [-1])
        idx = math_ops.range(_get_dim(flat_idx_mask, 0))

        # Scatter to flat shape.
        flat = array_ops.scatter_nd(indices=array_ops.expand_dims(
            array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
                                    updates=array_ops.boolean_mask(
                                        flat_labels, flat_label_mask),
                                    shape=array_ops.shape(flat_idx_mask))

        # Reshape back to square batch.
        batch_size = _get_dim(labels, 0)
        new_shape = [batch_size, new_maxlen]
        return (array_ops.reshape(flat, new_shape),
                math_ops.cast(new_seq_len, seq_length.dtype))