Example #1
0
 def __call__(self, inputs, state, scope=None):
     """Select, for each batch row, the entry of `inputs` indexed by `state`.

     Args:
       inputs: Tensor indexed as [row, column]; presumably
         [batch_size, num_tags] backpointers -- confirm against the caller.
       state: Integer tensor whose axis 1 has size 1 (that axis is squeezed
         away); gives the column to read in each row of `inputs`.
       scope: Unused.

     Returns:
       new_tags, new_tags: The same tensor twice -- the selected entries
         with a trailing axis of size 1 appended.
     """
     column_ids = array_ops.squeeze(state, axis=[1])
     row_ids = math_ops.range(array_ops.shape(inputs)[0])
     # One (row, column) coordinate per batch entry.
     coords = array_ops.stack([row_ids, column_ids], axis=1)
     picked = gen_array_ops.gather_nd(inputs, coords)
     new_tags = array_ops.expand_dims(picked, axis=-1)
     return new_tags, new_tags
Example #2
0
    def __call__(self, inputs, state, scope=None):
        """Build the CrfDecodeBackwardRnnCell (K-best variant).

        For every (batch, k) position, reads the entry of `inputs` in row
        `batch` at the column given by `state[batch, k]`.

        Args:
          inputs: A [batch_size, num_tag * K] matrix of backpointers of the
            next step (in time order).
          state: A [batch_size, K] matrix of tag indices of the next step.
          scope: Unused variable scope of this cell.

        Returns:
          new_tags, new_tags: The same [batch_size, K] tensor twice,
            containing the new tag indices.
        """
        num_rows = array_ops.shape(inputs)[0]
        num_cols = array_ops.shape(state)[1]

        # Row id for every entry of `state`: row b is filled with the value b.
        row_ids = tf.tile(
            array_ops.expand_dims(math_ops.range(num_rows), axis=1),
            [1, num_cols])  # [B, K]

        # Pair each row id with its column index, flattened in row-major order.
        flat_rows = tf.reshape(row_ids, [-1])  # [B * K]
        flat_cols = tf.reshape(state, [-1])  # [B * K]
        coords = array_ops.stack([flat_rows, flat_cols], axis=1)  # [B * K, 2]

        flat_tags = gen_array_ops.gather_nd(inputs, coords)  # [B * K]
        new_tags = array_ops.reshape(flat_tags, [num_rows, -1])  # [B, K]

        return new_tags, new_tags
Example #3
0
def gather_along_second_axis(data, indices):
    """Pick one slice along axis 2 of `data` for each (axis 0, axis 1) position.

    The two leading axes of `data` are merged into a single axis, `indices`
    is flattened to the same length, and `gather_nd` then reads
    `merged[i, indices[i]]` for every merged row `i`.

    Args:
      data: Tensor with a statically known rank. Expected to have rank >= 3,
        because a two-column index is applied after the leading axes merge.
      indices: Integer tensor with shape[0] * shape[1] elements in total
        (where `shape` is the shape of `data`); it is reshaped to 1-D.

    Returns:
      The gathered values reshaped to [shape[0], shape[1], -1].
    """
    rank = len(data.get_shape().as_list())
    dyn_shape = array_ops.shape(data)
    merged_len = dyn_shape[0] * dyn_shape[1]

    flat_indices = array_ops.reshape(indices, [merged_len])
    # Collapse axes 0 and 1; keep every remaining axis as it is.
    merged_shape = [merged_len] + [dyn_shape[axis] for axis in range(2, rank)]
    merged = array_ops.reshape(data, merged_shape)

    row_ids = math_ops.range(0, array_ops.shape(merged)[0])
    coords = array_ops.stack([row_ids, flat_indices], axis=1)
    picked = gen_array_ops.gather_nd(merged, coords)
    return gen_array_ops.reshape(picked, [dyn_shape[0], dyn_shape[1], -1])
Example #4
0
def gather_along_second_axis(data, indices):
  """Pick one slice along axis 2 of `data` for each (axis 0, axis 1) position.

  The two leading axes of `data` are merged into a single axis, `indices` is
  flattened to the same length, and `gather_nd` then reads
  `merged[i, indices[i]]` for every merged row `i`. The result has the two
  leading axes restored and axis 2 removed.

  Args:
    data: Tensor with a statically known rank of at least 3.
    indices: Integer tensor with shape[0] * shape[1] elements in total
      (where `shape` is the shape of `data`); it is reshaped to 1-D.

  Returns:
    Tensor of shape [shape[0], shape[1]] + shape[3:], i.e. rank(data) - 1.

  Raises:
    ValueError: If `data` has rank lower than 3.
  """
  ndims = len(data.get_shape().as_list())
  if ndims < 3:
    # A two-column index into the merged tensor needs at least two axes
    # after the merge; fail early with a readable message.
    raise ValueError(
        "gather_along_second_axis expects data of rank >= 3, got rank %d"
        % ndims)

  shape = array_ops.shape(data)
  merged_len = shape[0] * shape[1]
  indices = array_ops.reshape(indices, [merged_len])

  # Collapse axes 0 and 1; keep every remaining axis as it is.
  merged_shape = [merged_len] + [shape[idx] for idx in range(2, ndims)]
  data = array_ops.reshape(data, merged_shape)

  batch_offset = math_ops.range(0, array_ops.shape(data)[0])
  flat_indices = array_ops.stack([batch_offset, indices], axis=1)
  gathered = gen_array_ops.gather_nd(data, flat_indices)

  # Restore the two leading axes and keep whatever axes followed axis 2.
  # This handles every rank >= 3 instead of only ranks 3 and 4, which
  # previously left the result unassigned for higher ranks.
  out_shape = [shape[0], shape[1]] + [shape[idx] for idx in range(3, ndims)]
  return gen_array_ops.reshape(gathered, out_shape)
Example #5
0
        def __call__(self, inputs, state, scope=None):
            """Build the CrfDecodeBackwardRnnCell.

            Args:
              inputs: [batch_size, num_tags], backpointer of next step (in time order).
              state: [batch_size, 1], next position's tag index.
              scope: Unused variable scope of this cell.

            Returns:
              new_tags, new_tags: The same [batch_size, 1] tensor twice,
                containing the new tag indices.
            """
            tag_ids = array_ops.squeeze(state, axis=[1])  # [B]
            row_ids = math_ops.range(array_ops.shape(inputs)[0])  # [B]
            # One (row, tag) coordinate per batch entry.
            coords = array_ops.stack([row_ids, tag_ids], axis=1)  # [B, 2]
            picked = gen_array_ops.gather_nd(inputs, coords)  # [B]
            new_tags = array_ops.expand_dims(picked, axis=-1)  # [B, 1]

            return new_tags, new_tags
Example #6
0
  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeBackwardRnnCell.

    Reads, for each batch row, the backpointer stored at the tag index held
    in `state`.

    Args:
      inputs: [batch_size, num_tags], backpointer of next step (in time order).
      state: [batch_size, 1], next position's tag index.
      scope: Unused variable scope of this cell.

    Returns:
      new_tags, new_tags: The same [batch_size, 1] tensor twice, containing
        the new tag indices.
    """
    num_rows = array_ops.shape(inputs)[0]
    row_ids = math_ops.range(num_rows)                        # [B]
    tag_ids = array_ops.squeeze(state, axis=[1])              # [B]

    # Pair every row with its tag index and read that backpointer.
    coords = array_ops.stack([row_ids, tag_ids], axis=1)      # [B, 2]
    selected = gen_array_ops.gather_nd(inputs, coords)        # [B]
    new_tags = array_ops.expand_dims(selected, axis=-1)       # [B, 1]

    return new_tags, new_tags