Example #1
0
  def _multihead_attention(self, memory):
    """Perform multi-head attention from 'Attention is All You Need'.

    Implementation of the attention mechanism from
    https://arxiv.org/abs/1706.03762. Queries, keys and values for all heads
    come from one fused linear projection of `memory` followed by layer
    normalisation. Every memory slot then attends over every memory slot.

    Args:
      memory: Memory tensor to perform attention on. Expected to be [B, N, D]
        (batch, memory slots, features). The slot count N is read statically
        from axis 1, so it must be known at graph-construction time.

    Returns:
      new_memory: New memory tensor of shape [B, N, H * V], where H is
        `self._num_heads` and V is `self._head_size`.
    """
    key_size = self._key_size
    value_size = self._head_size

    # Per-head width of the fused projection: `key_size` features each for the
    # query and the key, plus `value_size` features for the value.
    qkv_size = 2 * key_size + value_size
    total_size = qkv_size * self._num_heads  # Denote as F.
    qkv = basic.BatchApply(basic.Linear(total_size))(memory)
    qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)

    mem_slots = memory.get_shape().as_list()[1]  # Denoted as N.

    # [B, N, F] -> [B, N, H, F/H]
    qkv_reshape = basic.BatchReshape([mem_slots, self._num_heads,
                                      qkv_size])(qkv)

    # [B, N, H, F/H] -> [B, H, N, F/H]
    qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
    q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)

    # Scaled dot-product attention divides the logits by sqrt(d_k), where d_k
    # is the query/key width, not the width of the fused q/k/v projection.
    # NOTE: this code previously scaled by qkv_size ** -0.5. Checkpoints
    # trained with that factor will see a different softmax temperature.
    q *= key_size ** -0.5
    dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
    weights = tf.nn.softmax(dot_product)

    output = tf.matmul(weights, v)  # [B, H, N, V]

    # [B, H, N, V] -> [B, N, H, V]
    output_transpose = tf.transpose(output, [0, 2, 1, 3])

    # [B, N, H, V] -> [B, N, H * V]
    new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)
    return new_memory
Example #2
0
    def _affine_grid_warper_inverse(inputs):
      """Assembles network to compute inverse affine transformation.

      Each `inputs` row potentially contains [a, b, tx, c, d, ty]
      corresponding to an affine matrix:

        A = [a, b, tx],
            [c, d, ty]

      Coefficients fixed by `self.constraints` are not read from `inputs`.
      They are filled in with their constant value. The remaining (free)
      coefficients are read, in order, from consecutive columns of `inputs`.

      We want to generate a tensor containing the coefficients of the
      corresponding inverse affine transformation in a constraints-aware
      fashion.
      Calling M:

        M = [a, b]
            [c, d]

      the affine matrix for the inverse transform is:

         A_in = [M^(-1), M^(-1) * [-tx, -ty]^T]

      where

        M^(-1) = (ad - bc)^(-1) * [ d, -b]
                                  [-c,  a]

      Args:
        inputs: Tensor containing a batch of transformation parameters.

      Returns:
        A tensorflow graph performing the inverse affine transformation
        parametrized by the input coefficients.
      """
      batch_size = tf.expand_dims(tf.shape(inputs)[0], 0)
      # Shape [B, 1]: constrained coefficients are broadcast to one per batch row.
      constant_shape = tf.concat([batch_size, tf.convert_to_tensor((1,))], 0)

      # Free (unconstrained) coefficients occupy consecutive input columns.
      index = iter(range(6))
      def get_variable(constraint):
        if constraint is None:
          # Builtin `next()` works on Python 2 and 3 (`.next()` is Py2-only).
          i = next(index)
          return inputs[:, i:i+1]
        else:
          return tf.fill(constant_shape, tf.constant(constraint,
                                                     dtype=inputs.dtype))

      constraints = chain.from_iterable(self.constraints)
      a, b, tx, c, d, ty = (get_variable(constr) for constr in constraints)

      det = a * d - b * c
      a_inv = d / det
      b_inv = -b / det
      c_inv = -c / det
      d_inv = a / det

      m_inv = basic.BatchReshape(
          [2, 2])(tf.concat([a_inv, b_inv, c_inv, d_inv], 1))

      txy = tf.expand_dims(tf.concat([tx, ty], 1), 2)

      # M^(-1) * [tx, ty]^T. The inverse translation is its negation, which is
      # applied when assembling `inverse_gw_inputs` below.
      txy_inv = basic.BatchFlatten()(tf.matmul(m_inv, txy))
      tx_inv = txy_inv[:, 0:1]
      ty_inv = txy_inv[:, 1:2]

      inverse_gw_inputs = tf.concat(
          [a_inv, b_inv, -tx_inv, c_inv, d_inv, -ty_inv], 1)

      agw = AffineGridWarper(self.output_shape,
                             self.source_shape)

      return agw(inverse_gw_inputs)  # pylint: disable=not-callable
Example #3
0
  def _build(self, inputs):
    """Assembles the module network and adds it to the graph.

    The internal computation graph is assembled according to the set of
    constraints provided at construction time.

    Args:
      inputs: Tensor containing a batch of transformation parameters. Axis 1
        must hold exactly `self._constraints.num_free_params` entries.

    Returns:
      A batch of warped grids with shape
      [B] + self._output_shape + [num_output_dimensions]: one coordinate
      channel per output dimension, concatenated on the last axis.

    Raises:
      Error: If the input tensor size is not consistent with the constraints
        passed at construction time.
    """
    input_shape = tf.shape(inputs)
    input_dtype = inputs.dtype.as_numpy_dtype
    batch_size = tf.expand_dims(input_shape[0], 0)
    number_of_params = inputs.get_shape()[1]
    if number_of_params != self._constraints.num_free_params:
      raise base.Error('Input size is not consistent with constraint '
                       'definition: {} parameters expected, {} provided.'
                       .format(self._constraints.num_free_params,
                               number_of_params))
    # `self._psi` holds three consecutive groups of `num_output_dimensions`
    # entries, accessed below as psi[i], psi[num_output_dimensions + i] and
    # psi[2 * num_output_dimensions + i].
    num_output_dimensions = len(self._psi) // 3

    def get_input_slice(start, size):
      """Extracts a subset of columns from the input 2D Tensor."""
      return basic.SliceByDim([1], [start], [size])(inputs)

    def tile_over_batch(array):
      """Casts a precomputed numpy array and tiles it to [B, 1, *array.shape]."""
      array = array.astype(input_dtype)
      tiling_params = tf.concat(
          [batch_size, tf.constant(1, shape=(1,)), tf.ones_like(array.shape)],
          0)
      return tf.tile(array.reshape((1, 1) + array.shape), tiling_params)

    warped_grid = []
    var_index_offset = 0
    number_of_points = np.prod(self._output_shape)
    for i in range(num_output_dimensions):
      if self._psi[i] is not None:
        # The i-th output dimension is not fully specified by the constraints,
        # the graph is setup to perform matrix multiplication in batch mode.
        grid_coord = self._psi[i].astype(input_dtype)

        num_active_vars = self._psi[i].shape[0]
        active_vars = get_input_slice(var_index_offset, num_active_vars)
        warped_coord = tf.matmul(active_vars, grid_coord)
        warped_coord = tf.expand_dims(warped_coord, 1)
        var_index_offset += num_active_vars
        offset = self._psi[num_output_dimensions + i]
        if offset is not None:
          # Some entries in the i-th row of the affine matrix were constrained
          # and the corresponding matrix multiplications have been precomputed.
          warped_coord += tile_over_batch(offset)

      else:
        # The i-th output dimension is fully specified by the constraints, and
        # the corresponding matrix multiplications have been precomputed.
        warped_coord = tile_over_batch(self._psi[num_output_dimensions + i])

      warped_coord += self._psi[i + 2 * num_output_dimensions]
      # Need to help TF figuring out shape inference since tiling information
      # is held in Tensors which are not known until run time.
      warped_coord.set_shape([None, 1, number_of_points])
      warped_grid.append(warped_coord)

    # Reshape all the warped coordinates tensors to match the specified output
    # shape and concatenate into a single matrix.
    grid_shape = self._output_shape + (1,)
    warped_grid = [basic.BatchReshape(grid_shape)(grid) for grid in warped_grid]
    # Each grid has rank len(grid_shape) + 1 (batch axis first), so this index
    # is the trailing size-1 axis.
    return tf.concat(warped_grid, len(grid_shape))
    def _build(self,
               inputs,
               query_inputs=None,
               state=None,
               is_training=False,
               dropout_keep_prob=0.5,
               key_value_inputs=None):
        """Calculates multi-head self attention.

        Args:
          inputs: Tensor of shape [batch_size, num_steps, output_dim_size].
            Inputs used as the query, key, and value to the attention layer.
          query_inputs: optional Tensor of shape [batch_size, num_steps,
            output_dim_size]. Query inputs to the attention layer. Set when
            query_inputs is different from the inputs argument.
          state: optional CompressedMemoryState or a Tensor of shape
            [batch_size, memory_size, dim_size] concatenated to the inputs.
            Set when attending to the memory from previous steps.
          is_training: if currently training. Attention dropout is only
            applied when True.
          dropout_keep_prob: probability of *keeping* each attention weight
            when dropout is applied. It is passed as TF1 `keep_prob`, not as a
            drop rate.
          key_value_inputs: optional Tensor of shape [batch_size, num_steps,
            output_dim_size]. It is used as the key and value of the multihead
            attention. Set when the key and value are different from the inputs
            argument. Mutually exclusive with `state`.

        Returns:
          output: the result Tensor of shape
            [batch_size, query_len, num_heads * value_size]. query_len equals
            num_steps unless `query_inputs` is given.
          attention_state: named tuple of AttentionState.

        Raises:
          ValueError: if both `key_value_inputs` and `state` are given, or if
            absolute positional encodings are configured for more than one
            memory.
        """
        if key_value_inputs is not None and state is not None:
            raise ValueError(
                'Only one of the key_value_input and state is needed.')
        embedding_size = self._value_size * self._num_heads

        q_inputs = inputs if query_inputs is None else query_inputs
        # Denoted by L. If query_inputs is None, L = N.
        _, query_size = q_inputs.get_shape().as_list()[:2]

        if key_value_inputs is not None:
            k_inputs = key_value_inputs
            v_inputs = k_inputs
        elif state is not None:
            if isinstance(state, CompressedMemoryState):
                state_memory_list = [
                    state.compressed_memory, state.episodic_memory
                ]
            else:
                state_memory_list = [state]

            # Memories are prepended to the inputs along the time axis.
            k_inputs = tf.concat(state_memory_list + [inputs], 1)
            v_inputs = k_inputs
        else:
            k_inputs = inputs
            v_inputs = inputs

        # Batch size denoted by B
        batch_size = tf.shape(inputs)[0]
        # Chunk_size denoted by N
        chunk_size = inputs.get_shape().as_list()[1]
        # Denoted by N + M
        att_size = k_inputs.get_shape().as_list()[1]

        if self._positional_encodings and not self._use_relative_positions:
            if len(self._positional_encodings) != 1:
                raise ValueError(
                    'Absolute positional encodings only supported for 1 memory. '
                    'Found %i.' % len(self._positional_encodings))
            key_positions, query_positions = self._positional_encodings[0]
            # Tensors are immutable, so `+=` rebinds k_inputs/q_inputs to new
            # tensors. v_inputs keeps the version without positional encodings.
            k_inputs += key_positions
            q_inputs += query_positions

        # [B, H, L, K]
        q = self.multihead_linear(q_inputs, 'query')
        # [B, H, N + M, K]
        k = self.multihead_linear(k_inputs, 'key')
        # [B, H, N + M, V]
        v = self.multihead_linear(v_inputs, 'value')

        # Scaling the dot-product
        if self._scaling:
            q *= self._key_size**-0.5

        # [B, H, L, N + M]
        if self._use_relative_positions:
            r_w_bias = tf.get_variable('r_w_bias',
                                       [1, self._num_heads, 1, self._key_size],
                                       dtype=inputs.dtype)
            content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
            all_relative_logits = []
            # Loop over multiple positional encodings, for the case of multiple
            # memory types.
            for i, positional_encodings in enumerate(
                    self._positional_encodings):
                key_positions, query_positions = positional_encodings
                # NOTE(review): the size check reads the last axis while the
                # crop slices axis 1. Confirm which axis holds positions.
                if key_positions.get_shape().as_list()[-1] != att_size:
                    # Crop to layer mem size
                    key_positions = key_positions[:, -att_size:]
                is_final = i == len(self._positional_encodings) - 1
                suffix = '' if is_final else '_%d' % i
                relative_keys = self.multihead_linear(
                    key_positions, name='relative_keys' + suffix)
                # [B, H, N, D]
                r_r_bias = tf.get_variable(
                    'r_r_bias' + suffix,
                    [1, self._num_heads, 1, self._key_size],
                    dtype=inputs.dtype)
                relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
                relative_logits = tf.matmul(q + r_r_bias,
                                            relative_keys,
                                            transpose_b=True)
                relative_logits = rel_shift(relative_logits)
                if not is_final:  # Include relative positions for input sequence.
                    relative_logits = relative_logits[:, :, :, :-chunk_size]
                all_relative_logits.append(relative_logits)
            all_relative_logits = tf.concat(all_relative_logits, 3)
            logits = content_logits + all_relative_logits
        else:
            # [B, H, N, N + M]
            logits = tf.matmul(q, k, transpose_b=True)
            content_logits = logits

        if self._mask is not None:
            # Keep only the trailing att_size key positions of the mask.
            if self._mask.get_shape().as_list()[-1] != att_size:
                mask = self._mask[:, :, :, -att_size:]
            else:
                mask = self._mask
            logits += mask

        weights = tf.nn.softmax(logits)
        if is_training:
            # Pass by keyword so a keep probability can never be read as a
            # drop rate by a differently-versioned tf.nn.dropout.
            weights = tf.nn.dropout(weights, keep_prob=dropout_keep_prob)
        # [B, L, H, V], where V is value_size
        output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

        # [B, L, H, V] -> [B, L, HV]
        attended_inputs = basic.BatchReshape([query_size, embedding_size
                                              ])(output_transpose)
        # Apply final mlp to mix information between heads.
        output = basic.BatchApply(
            basic.Linear(embedding_size))(attended_inputs)

        attention_state = AttentionState(queries=q,
                                         keys=k,
                                         values=v,
                                         weights=weights,
                                         logits=content_logits,
                                         embeddings=inputs,
                                         read_words=output)
        return output, attention_state
Example #5
0
    def _build(self,
               inputs,
               query_inputs=None,
               state=None,
               is_training=False,
               dropout_keep_prob=0.5):
        """Calculates multi-head self attention.

        Args:
          inputs: Tensor of shape [batch_size, num_steps, output_dim_size].
            Used as the query, key and value unless overridden.
          query_inputs: optional Tensor of shape [batch_size, query_len,
            output_dim_size]. Used as the queries instead of `inputs`.
          state: optional CompressedMemoryState or Tensor of shape
            [batch_size, memory_size, dim_size]. Prepended (along the time
            axis) to `inputs` to form the keys and values.
          is_training: if currently training. Attention dropout is only
            applied when True.
          dropout_keep_prob: probability of *keeping* each attention weight
            when dropout is applied. It is passed as TF1 `keep_prob`, not as a
            drop rate.

        Returns:
          output: Tensor of shape [batch_size, query_len,
            num_heads * value_size]. query_len equals num_steps unless
            `query_inputs` is given.
          attention_state: named tuple of AttentionState.

        Raises:
          ValueError: if absolute positional encodings are given as a
            sequence of more than one (key, query) pair.
        """
        embedding_size = self._value_size * self._num_heads

        q_inputs = inputs if query_inputs is None else query_inputs
        # Denoted by L. If query_inputs is None, L = N.
        _, query_size = q_inputs.get_shape().as_list()[:2]

        if state is not None:
            if isinstance(state, CompressedMemoryState):
                state_memory_list = [
                    state.compressed_memory, state.episodic_memory
                ]
            else:
                state_memory_list = [state]

            # Memories are prepended to the inputs along the time axis.
            k_inputs = tf.concat(state_memory_list + [inputs], 1)
            v_inputs = k_inputs
        else:
            k_inputs = inputs
            v_inputs = inputs

        # Batch size denoted by B
        batch_size = tf.shape(inputs)[0]
        # Chunk_size denoted by N
        chunk_size = inputs.get_shape().as_list()[1]
        # Denoted by N + M
        att_size = k_inputs.get_shape().as_list()[1]

        if self._positional_encodings and not self._use_relative_positions:
            encodings = self._positional_encodings
            # The relative branch below treats `_positional_encodings` as a
            # sequence of (key, query) pairs. Accept that layout here as well
            # as a bare (key, query) pair.
            if isinstance(encodings[0], (list, tuple)):
                if len(encodings) != 1:
                    raise ValueError(
                        'Absolute positional encodings only supported for 1 memory. '
                        'Found %i.' % len(encodings))
                encodings = encodings[0]
            key_positions, query_positions = encodings
            # Tensors are immutable, so `+=` rebinds k_inputs/q_inputs to new
            # tensors. v_inputs keeps the version without positional encodings.
            k_inputs += key_positions
            q_inputs += query_positions

        # [B, H, L, K]
        q = self.multihead_linear(q_inputs, 'query')
        # [B, H, N + M, K]
        k = self.multihead_linear(k_inputs, 'key')
        # [B, H, N + M, V]
        v = self.multihead_linear(v_inputs, 'value')

        # Scaling the dot-product
        if self._scaling:
            q *= self._key_size**-0.5

        # [B, H, L, N + M]
        if self._use_relative_positions:
            r_w_bias = tf.get_variable('r_w_bias',
                                       [1, self._num_heads, 1, self._key_size],
                                       dtype=inputs.dtype)
            content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
            all_relative_logits = []
            # Loop over multiple positional encodings, for the case of multiple
            # memory types.
            for i, positional_encodings in enumerate(
                    self._positional_encodings):
                key_positions, query_positions = positional_encodings
                # NOTE(review): the size check reads the last axis while the
                # crop slices axis 1. Confirm which axis holds positions.
                if key_positions.get_shape().as_list()[-1] != att_size:
                    # Crop to layer mem size
                    key_positions = key_positions[:, -att_size:]
                is_final = i == len(self._positional_encodings) - 1
                suffix = '' if is_final else '_%d' % i
                relative_keys = self.multihead_linear(
                    key_positions, name='relative_keys' + suffix)
                # [B, H, N, D]
                r_r_bias = tf.get_variable(
                    'r_r_bias' + suffix,
                    [1, self._num_heads, 1, self._key_size],
                    dtype=inputs.dtype)
                relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
                relative_logits = tf.matmul(q + r_r_bias,
                                            relative_keys,
                                            transpose_b=True)
                relative_logits = rel_shift(relative_logits)
                if not is_final:  # Include relative positions for input sequence.
                    relative_logits = relative_logits[:, :, :, :-chunk_size]
                all_relative_logits.append(relative_logits)
            all_relative_logits = tf.concat(all_relative_logits, 3)
            logits = content_logits + all_relative_logits
        else:
            # [B, H, N, N + M]
            logits = tf.matmul(q, k, transpose_b=True)
            content_logits = logits

        if self._mask is not None:
            # Keep only the trailing att_size key positions of the mask.
            if self._mask.get_shape().as_list()[-1] != att_size:
                mask = self._mask[:, :, :, -att_size:]
            else:
                mask = self._mask
            logits += mask

        weights = tf.nn.softmax(logits)
        if is_training:
            # Pass by keyword so a keep probability can never be read as a
            # drop rate by a differently-versioned tf.nn.dropout.
            weights = tf.nn.dropout(weights, keep_prob=dropout_keep_prob)
        # [B, L, H, V], where V is value_size
        output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

        # [B, L, H, V] -> [B, L, HV]
        attended_inputs = basic.BatchReshape([query_size, embedding_size
                                              ])(output_transpose)
        # Apply final mlp to mix information between heads.
        output = basic.BatchApply(
            basic.Linear(embedding_size))(attended_inputs)

        attention_state = AttentionState(queries=q,
                                         keys=k,
                                         values=v,
                                         weights=weights,
                                         logits=content_logits,
                                         embeddings=inputs,
                                         read_words=output)
        return output, attention_state