Example #1
def mlp(feature, hparams, name="mlp"):
  """Multi layer perceptron with dropout and relu activation."""
  with tf.variable_scope(name, "mlp", values=[feature]):
    num_mlp_layers = hparams.num_mlp_layers
    mlp_dim = hparams.mlp_dim
    for _ in range(num_mlp_layers):
      feature = common_layers.dense(feature, mlp_dim, activation=tf.nn.relu)
      feature = tf.nn.dropout(feature, keep_prob=1.-hparams.dropout)
    return feature
  def body(self, features):
    hp = self.hparams
    # pylint: disable=eval-used
    if hp.image_input_type == "image":
      image_feat = vqa_layers.image_embedding(
          features["inputs"],
          model_fn=eval(hp.image_model_fn),
          trainable=hp.train_resnet,
          is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
    else:
      image_feat = features["inputs"]

    image_feat = common_layers.flatten4d3d(image_feat)
    image_feat = common_layers.dense(image_feat, hp.hidden_size)
    utils.collect_named_outputs("norms", "image_feat_after_proj",
                                tf.norm(image_feat, axis=-1))

    question = common_layers.flatten4d3d(features["question"])
    utils.collect_named_outputs("norms", "question_embedding",
                                tf.norm(question, axis=-1))
    (encoder_input, encoder_self_attention_bias,
     encoder_decoder_attention_bias) = prepare_image_question_encoder(
         image_feat, question, hp)

    encoder_input = tf.nn.dropout(
        encoder_input, keep_prob=1.-hp.layer_prepostprocess_dropout)

    encoder_output, _ = recurrent_transformer_decoder(
        encoder_input, None, encoder_self_attention_bias, None,
        hp, name="encoder")
    utils.collect_named_outputs(
        "norms", "encoder_output", tf.norm(encoder_output, axis=-1))

    # scale query by sqrt(hidden_size)
    query = tf.get_variable("query", [hp.hidden_size]) * hp.hidden_size**0.5
    query = tf.expand_dims(tf.expand_dims(query, axis=0), axis=0)
    batch_size = common_layers.shape_list(encoder_input)[0]
    query = tf.tile(query, [batch_size, 1, 1])
    query = tf.nn.dropout(
        query, keep_prob=1.-hp.layer_prepostprocess_dropout)

    decoder_output, _ = recurrent_transformer_decoder(
        query, encoder_output, None, encoder_decoder_attention_bias,
        hp, name="decoder")
    utils.collect_named_outputs("norms", "decoder_output",
                                tf.norm(decoder_output, axis=-1))

    norm_tensors = utils.convert_collection_to_dict("norms")
    vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

    # Expand dimension 1.
    return tf.expand_dims(decoder_output, axis=1)
def mlp(feature, hparams, name="mlp"):
  """Multi layer perceptron with dropout and relu activation."""
  with tf.variable_scope(name, "mlp", values=[feature]):
    num_mlp_layers = hparams.num_mlp_layers
    mlp_size = hparams.mlp_size
    for _ in range(num_mlp_layers):
      feature = common_layers.dense(feature, mlp_size, activation=None)
      utils.collect_named_outputs("norms", "mlp_feature",
                                  tf.norm(feature, axis=-1))
      feature = common_layers.layer_norm(feature)
      feature = tf.nn.relu(feature)
      feature = tf.nn.dropout(feature, keep_prob=1.-hparams.dropout)
    return feature
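
A minimal usage sketch for the mlp helper above. It assumes a TF1-style graph (e.g. tensorflow.compat.v1), that mlp and its tensor2tensor dependencies (common_layers, utils) are in scope, and that the hparams fields and sizes shown are hypothetical values chosen for illustration.

import types
import tensorflow.compat.v1 as tf  # TF1-style graph APIs (assumption)

# Hypothetical hyperparameter container carrying only the fields mlp() reads.
hparams = types.SimpleNamespace(num_mlp_layers=2, mlp_size=512, dropout=0.1)

# Illustrative [batch, hidden] input; the last dimension must be statically known
# for the layer_norm call inside mlp().
feature = tf.placeholder(tf.float32, [None, 1024])
output = mlp(feature, hparams)  # -> [batch, 512] after two dense/layer_norm/relu/dropout blocks
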
def _compute_edge_transforms(node_states,
                             depth,
                             num_transforms,
                             name="transform"):
  """Helper function that computes transformation for keys and values.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let T be the total number of transforms (num_transforms).

  Computes the transforms for keys or values for attention.
  * For each node N_j and edge type t, a key K_jt of size K is computed. When an
    edge of type t goes from node N_j to any other node, K_jt is the key that is
    used in the attention process.
  * For each node N_j and edge type t, a value V_jt of size V is computed. When
    an edge of type t goes from node N_j to node N_i, Attention(Q_i, K_jt)
    produces a weight w_ijt. The message sent along this edge is w_ijt * V_jt.

  Args:
    node_states: A tensor of shape [B, N, D]
    depth: An integer (K or V)
    num_transforms: An integer (T)
    name: A name for the function

  Returns:
    x: The attention keys or values for each node and edge type
      (shape [B, N*T, K or V])
  """
  node_shapes = common_layers.shape_list(node_states)
  x = common_layers.dense(
      node_states,
      depth * num_transforms,
      use_bias=False,
      name=name)

  batch = node_shapes[0]  # B.
  length = node_shapes[1]  # N.

  # Make the transform dimension explicit by separating the vectors of size
  # K*T (in k) and V*T (in v) into matrices with shape [T, K] (for k) and
  # [T, V] (for v).
  x = tf.reshape(x, [batch, length, num_transforms, depth])

  # Fold the transform dimension into the node dimension, giving shape
  # [B, N*T, K or V].
  x = tf.reshape(x, [batch, length * num_transforms, depth])

  return x
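
A small shape-check sketch for the helper above, assuming it is in scope together with tensor2tensor's common_layers; the batch, node, depth, and transform sizes are arbitrary illustrations.

import tensorflow.compat.v1 as tf  # TF1-style graph APIs (assumption)

B, N, D = 2, 5, 16   # batch, nodes, node-state size (illustrative)
K, T = 8, 3          # key depth and number of transforms (illustrative)

node_states = tf.random_normal([B, N, D])
keys = _compute_edge_transforms(node_states, depth=K, num_transforms=T, name="k_demo")
print(keys.shape)    # (2, 15, 8), i.e. [B, N*T, K]
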
Example #7
    def body(self, features):
        hp = self.hparams
        # pylint: disable=eval-used
        if hp.image_input_type == "image":
            image_feat = vqa_layers.image_embedding(
                features["inputs"],
                model_fn=eval(hp.image_model_fn),
                trainable=hp.train_resnet,
                is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
        else:
            image_feat = features["inputs"]

        image_feat = common_layers.flatten4d3d(image_feat)
        image_model_d = hp.model_d
        image_feat = common_layers.dense(image_feat, image_model_d)
        utils.collect_named_outputs("norms", "image_feat_after_proj",
                                    tf.norm(image_feat, axis=-1))

        question = common_layers.flatten4d3d(features["question"])
        utils.collect_named_outputs("norms", "question_embedding",
                                    tf.norm(question, axis=-1))
        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias) = prepare_image_question_encoder(
             image_feat, question, hp)
        encoder_input = tf.nn.dropout(encoder_input,
                                      keep_prob=1. -
                                      hp.layer_prepostprocess_dropout)
        encoder_output = image_question_encoder(encoder_input,
                                                encoder_self_attention_bias,
                                                hp)
        utils.collect_named_outputs("norms", "encoder_output",
                                    tf.norm(encoder_output, axis=-1))

        # scale query by sqrt(model_d)
        query = tf.get_variable("query", [hp.model_d]) * hp.model_d**0.5
        query = tf.expand_dims(tf.expand_dims(query, axis=0), axis=0)
        batch_size = common_layers.shape_list(encoder_input)[0]
        query = tf.tile(query, [batch_size, 1, 1])
        query = tf.nn.dropout(query,
                              keep_prob=1. - hp.layer_prepostprocess_dropout)

        decoder_output = decoder(query, encoder_output, None,
                                 encoder_decoder_attention_bias, hp)
        utils.collect_named_outputs("norms", "decoder_output",
                                    tf.norm(decoder_output, axis=-1))

        norm_tensors = utils.convert_collection_to_dict("norms")
        vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

        # Expand dimension 1.
        return tf.expand_dims(decoder_output, axis=1)
def BilinearAttention(c, q, c_mask, filters, name, norm=False):
    q = tf.expand_dims(dense(q,
                             filters,
                             name=name,
                             reuse=tf.AUTO_REUSE,
                             kernel_initializer=initializer(),
                             kernel_regularizer=regularizer),
                       axis=-1)  # [bs, dim, 1]
    if norm:
        q = layer_norm(q)
    cq = tf.squeeze(tf.matmul(c, q),
                    axis=-1)  # [bs, c_len, dim] * [bs, dim, 1] -> [bs, c_len]
    cq = exp_mask(cq, c_mask)
    return cq
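
For reference, the score this helper produces for each context position is the bilinear form score_i = c_i . (W q), with masked positions pushed toward -inf before any softmax. Below is an illustrative plain-TF sketch of that computation; it deliberately does not reuse the project-specific dense, exp_mask, initializer, and regularizer helpers assumed by the snippet above, and all names and sizes are hypothetical.

import tensorflow.compat.v1 as tf  # TF1-style graph APIs (assumption)

bs, c_len, dim = 4, 20, 64                 # hypothetical sizes
c = tf.random_normal([bs, c_len, dim])     # context vectors
q = tf.random_normal([bs, dim])            # one query vector per example
c_mask = tf.ones([bs, c_len])              # 1 = valid position, 0 = padding

W = tf.get_variable("bilinear_w", [dim, dim])      # hypothetical variable name
wq = tf.expand_dims(tf.matmul(q, W), axis=-1)      # [bs, dim, 1]
scores = tf.squeeze(tf.matmul(c, wq), axis=-1)     # [bs, c_len]
scores += (1.0 - c_mask) * -1e30                   # same effect as exp_mask
attn_weights = tf.nn.softmax(scores, axis=-1)
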
Example #9
def universal_transformer_all_steps_so_far(layer_inputs, step, hparams,
                                           ffn_unit, attention_unit):
    """universal_transformer.
  It uses an attention mechanism (flipped vertically)
  over all the states from previous steps to generate the new_state.
  Args:
    layer_inputs:
      - state: state
      - memory: contains states from all the previous steps.
    step: an integer indicating the number of steps taken so far
    hparams: model hyper-parameters.
    ffn_unit: feed-forward unit
    attention_unit: multi-head attention unit
  Returns:
    layer_output:
        new_state: new state
        memory: contains states from all the previous steps.
  """
    _, inputs, memory = layer_inputs
    all_states = memory
    # get the states up to the current step (non-zero part of the memory)
    states_so_far = all_states[:step, :, :, :]

    states_so_far_weights = tf.nn.softmax(common_layers.dense(
        states_so_far, (hparams.hidden_size if hparams.dwa_elements else 1),
        activation=None,
        use_bias=True),
                                          axis=-1)

    # # get summary of the step weights
    # step_weightes = tf.unstack(states_so_far_weights, axis=0, name="step_weightes")
    # for step_i, step_w  in enumerate(step_weightes):
    #   tf.contrib.summary.scalar("step_%d_weight:"%step_i,
    #                             tf.reduce_mean(step_w))

    # prepare the state as the weighted summary of the states so far
    state_to_be_transformed = tf.reduce_sum(
        (states_so_far * states_so_far_weights), axis=0)

    state_to_be_transformed = universal_transformer_util.step_preprocess(
        state_to_be_transformed, step, hparams)

    new_state = ffn_unit(attention_unit(state_to_be_transformed))

    # add the new state to the memory
    memory = universal_transformer_util.fill_memory_slot(
        memory, new_state, step + 1)

    return new_state, inputs, memory
  def body(self, features):
    hp = self.hparams
    model_fn = resnet_v1_152
    if hp.image_model_fn != "resnet_v1_152":
      model_fn = eval(hp.image_model_fn)  # pylint: disable=eval-used
    if hp.image_input_type == "image":
      image_feat = vqa_layers.image_embedding(
          features["inputs"],
          model_fn=model_fn,
          trainable=hp.train_resnet,
          is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
    else:
      image_feat = features["inputs"]

    if hp.image_feat_size:
      image_feat = common_layers.dense(image_feat, hp.image_feat_size)

    # apply layer normalization and dropout on image_feature
    utils.collect_named_outputs("norms", "image_feat_before_l2",
                                tf.norm(image_feat, axis=-1))
    image_feat = common_layers.l2_norm(image_feat)
    utils.collect_named_outputs("norms", "image_feat_after_l2",
                                tf.norm(image_feat, axis=-1))

    image_feat = tf.nn.dropout(image_feat, keep_prob=1.-hp.dropout)

    query = question_encoder(features["question"], hp)
    utils.collect_named_outputs("norms", "query",
                                tf.norm(query, axis=-1))

    image_ave = attn(image_feat, query, hp)
    utils.collect_named_outputs("norms", "image_ave",
                                tf.norm(image_ave, axis=-1))

    image_question = tf.concat([image_ave, query], axis=1)
    utils.collect_named_outputs("norms", "image_question",
                                tf.norm(image_question, axis=-1))

    image_question = tf.nn.dropout(image_question, 1. - hp.dropout)

    output = mlp(image_question, hp)
    utils.collect_named_outputs("norms", "output",
                                tf.norm(output, axis=-1))

    norm_tensors = utils.convert_collection_to_dict("norms")
    vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

    # Expand dimension 1 and 2
    return tf.expand_dims(tf.expand_dims(output, axis=1), axis=2)
Example #11
  def body(self, features):
    hp = self.hparams
    # pylint: disable=eval-used
    if hp.image_input_type == "image":
      image_feat = vqa_layers.image_embedding(
          features["inputs"],
          model_fn=eval(hp.image_model_fn),
          trainable=hp.train_resnet,
          is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
    else:
      image_feat = features["inputs"]

    if hp.image_feat_size:
      image_feat = common_layers.dense(image_feat, hp.image_feat_size)

    # apply layer normalization and dropout on image_feature
    utils.collect_named_outputs("norms", "image_feat_before_l2",
                                tf.norm(image_feat, axis=-1))
    image_feat = common_layers.l2_norm(image_feat)
    utils.collect_named_outputs("norms", "image_feat_after_l2",
                                tf.norm(image_feat, axis=-1))

    image_feat = tf.nn.dropout(image_feat, keep_prob=1.-hp.dropout)

    query = question_encoder(features["question"], hp)
    utils.collect_named_outputs("norms", "query",
                                tf.norm(query, axis=-1))

    image_ave = attn(image_feat, query, hp)
    utils.collect_named_outputs("norms", "image_ave",
                                tf.norm(image_ave, axis=-1))

    image_question = tf.concat([image_ave, query], axis=1)
    utils.collect_named_outputs("norms", "image_question",
                                tf.norm(image_question, axis=-1))

    image_question = tf.nn.dropout(image_question, 1. - hp.dropout)

    output = mlp(image_question, hp)
    utils.collect_named_outputs("norms", "output",
                                tf.norm(output, axis=-1))

    norm_tensors = utils.convert_collection_to_dict("norms")
    vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

    # Expand dimension 1 and 2
    return tf.expand_dims(tf.expand_dims(output, axis=1), axis=2)
    def body(self, features):
        target_seq = features["targets_raw"]
        target_seq = tf.identity(target_seq, "target_seq")
        hp = self._hparams

        actions_seq, grid_output_top_list = self.forward_path(features)
        # argmaxed_grid_output = tf.identity(tf.argmax(grid_output_top_list, axis=-1), "argmaxed_grid_output")

        # Match the output to the target modality (shape, dimension)
        output_vocab_size = self._problem_hparams.vocab_size["targets"]
        logits = common_layers.dense(grid_output_top_list, output_vocab_size)

        # Compute loss
        self.compute_loss(target_seq, logits)

        # Flip the logits for evaluation during training
        final = infer_flipped_outputs(logits, hp.list_size)
        return {"targets": final}, self._additional_loss
def _compute_edge_transforms(node_states,
                             depth,
                             num_transforms,
                             ignore_zero=True,
                             name="transform"):
    """Helper function that computes transformation for keys and values.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let T be the total number of transforms (num_transforms).

  Computes the transforms for keys or values for attention.
  * For each node N_j and edge type t, a key K_jt of size K is computed. When an
    edge of type t goes from node N_j to any other node, K_jt is the key that is
    used in the attention process.
  * For each node N_j and edge type t, a value V_jt of size V is computed. When
    an edge of type t goes from node N_j to node N_i, Attention(Q_i, K_jt)
    produces a weight w_ijt. The message sent along this edge is w_ijt * V_jt.

  Args:
    node_states: A tensor of shape [B, N, D]
    depth: An integer (K or V)
    num_transforms: An integer (T)
    ignore_zero: A boolean indicating whether to ignore edge type 0
    name: A name for the function

  Returns:
    x: The attention keys or values for each node and edge type
      (shape [B, N*T, K or V])
  """
    node_shapes = common_layers.shape_list(node_states)
    nonignored_transforms = num_transforms - int(ignore_zero)
    x = common_layers.dense(node_states,
                            depth * nonignored_transforms,
                            use_bias=False,
                            name=name)

    batch = node_shapes[0]  # B.
    length = node_shapes[1]  # N.

    # Make the transform dimension explicit by separating the vectors of size
    # K*nonignored_transforms (in k) and V*nonignored_transforms (in v) into
    # matrices with shape [nonignored_transforms, K] (for k) and
    # [nonignored_transforms, V] (for v).
    #
    # This reshape is only necessary when ignore_zero is True (for the padding
    # step that follows).
    x = tf.reshape(x, [batch, length, nonignored_transforms, depth])

    # If we previously ignored edge type 0, then we need to pad the keys and
    # values to take this additional edge type into account. To do so, we
    # pad the third dimension of k and v (which has size T-1 if ignore_zero is
    # True) to size T with zeroes.
    if ignore_zero:
        x = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]])

    # Fold the transform dimension into the node dimension, giving shape
    # [B, N*T, depth].
    x = tf.reshape(x, [batch, length * num_transforms, depth])

    return x
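
A toy illustration of the ignore_zero padding step above: tf.pad inserts an all-zero slice at transform index 0, so edge type 0 contributes zero keys and values. The shapes are illustrative only.

import tensorflow.compat.v1 as tf  # TF1-style graph APIs (assumption)

# [B=1, N=2, T-1=2, depth=3] gains an all-zero slice along the transform axis.
x = tf.ones([1, 2, 2, 3])
padded = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]])
print(padded.shape)  # (1, 2, 3, 3); padded[:, :, 0, :] is all zeros
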
Example #14
def sparse_message_pass(node_states,
                        adjacency_matrices,
                        num_edge_types,
                        hidden_size,
                        use_bias=True,
                        average_aggregation=False,
                        name="sparse_ggnn"):
    """One message-passing step for a GNN with a sparse adjacency matrix.

  Implements equation 2 (the message passing step) in
  [Li et al. 2015](https://arxiv.org/abs/1511.05493).

  N = The number of nodes in each batch.
  H = The size of the hidden states.
  T = The number of edge types.

  Args:
    node_states: Initial states of each node in the graph. Shape is [N, H].
    adjacency_matrices: Adjacency matrix of directed edges for each edge
      type. Shape is [N, N, T] (sparse tensor).
    num_edge_types: The number of edge types. T.
    hidden_size: The size of the hidden state. H.
    use_bias: Whether to use bias in the hidden layer.
    average_aggregation: How to aggregate the incoming node messages. If
      average_aggregation is true, the messages are averaged. If it is false,
      they are summed.
    name: (optional) The scope within which tf variables should be created.

  Returns:
    The result of one step of Gated Graph Neural Network (GGNN) message passing.
    Shape: [N, H]
  """
    n = tf.shape(node_states)[0]
    t = num_edge_types
    incoming_edges_per_type = tf.sparse_reduce_sum(adjacency_matrices, axis=1)

    # Convert the adjacency matrix into shape [T, N, N] - one [N, N] adjacency
    # matrix for each edge type. Since sparse tensor multiplication only supports
    # two-dimensional tensors, we actually convert the adjacency matrix into a
    # [T * N, N] tensor.
    adjacency_matrices = tf.sparse_transpose(adjacency_matrices, [2, 0, 1])
    adjacency_matrices = tf.sparse_reshape(adjacency_matrices, [t * n, n])

    # Multiply the adjacency matrix by the node states, producing a [T * N, H]
    # tensor. For each (edge type, node) pair, this tensor stores the sum of
    # the hidden states of the node's neighbors over incoming edges of that type.
    messages = tf.sparse_tensor_dense_matmul(adjacency_matrices, node_states)

    # Rearrange this tensor to have shape [N, T * H]. The incoming states of each
    # node's neighbors are summed by edge type and then concatenated together into
    # a single T * H vector.
    messages = tf.reshape(messages, [t, n, hidden_size])
    messages = tf.transpose(messages, [1, 0, 2])
    messages = tf.reshape(messages, [n, t * hidden_size])

    # Run each of those T * H vectors through a linear layer that produces
    # a vector of size H. This process is equivalent to running each H-sized
    # vector through a separate linear layer for each edge type and then adding
    # the results together.
    #
    # Note that, earlier on, we added together all of the states of neighbors
    # that were connected by edges of the same edge type. Since multiplying by a
    # linear layer distributes over addition, this process is equivalent to
    # running each incoming edge through a linear layer separately and then
    # adding everything at the end.
    with tf.variable_scope(name, default_name="sparse_ggnn"):
        final_node_states = common_layers.dense(messages,
                                                hidden_size,
                                                use_bias=False)

        # Multiply the bias for each edge type by the number of incoming edges
        # of that type.
        if use_bias:
            bias = tf.get_variable("bias",
                                   initializer=tf.zeros([t, hidden_size]))
            final_node_states += tf.matmul(incoming_edges_per_type, bias)

        if average_aggregation:
            incoming_edges = tf.reduce_sum(incoming_edges_per_type,
                                           -1,
                                           keepdims=True)
            incoming_edges = tf.tile(incoming_edges, [1, hidden_size])
            final_node_states /= incoming_edges + 1e-7

    return tf.reshape(final_node_states, [n, hidden_size])
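
A minimal invocation sketch for sparse_message_pass, assuming the function above and tensor2tensor's common_layers are in scope and a TF1-style graph is being built; the tiny graph below (4 nodes, 2 edge types, 2 edges) is purely illustrative.

import tensorflow.compat.v1 as tf  # TF1-style graph APIs (assumption)

N, H, T = 4, 8, 2                        # nodes, hidden size, edge types (illustrative)
node_states = tf.random_normal([N, H])

# adjacency[i, j, t] = 1.0 encodes an edge of type t from node j to node i.
# Two directed edges: 0 -> 1 with type 0, and 2 -> 3 with type 1.
adjacency = tf.SparseTensor(indices=[[1, 0, 0], [3, 2, 1]],
                            values=[1.0, 1.0],
                            dense_shape=[N, N, T])

# One GGNN message-passing step; the result has shape [N, H] at run time.
new_states = sparse_message_pass(node_states, adjacency, num_edge_types=T, hidden_size=H)
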
Example #15
def multihead_mpnn_attention(node_states,
                             total_key_depth,
                             total_value_depth,
                             output_depth,
                             num_heads,
                             adjacency_matrix=None,
                             num_edge_types=5,
                             num_transforms=None,
                             use_weighted_sum=False,
                             name="mpnn_attention"):
    """Multihead scaled-dot-product attention with input/output transformations.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let O be the size of the attention output (output_depth).
  Let H be the number of heads (num_heads).
  Let T be the total number of transforms (num_transforms).

  The key and value depths are split across all of the heads. For example, if
  the key depth is 6 and there are three heads, then the key for each head has
  depth 2.

  Args:
    node_states: A Tensor with shape [B, N, D]
    total_key_depth: An integer (K).
    total_value_depth: An integer (V).
    output_depth: An integer (O).
    num_heads: An integer (H).
    adjacency_matrix: A Tensor of ints with shape [B, T, N, N]. If there is an
      edge from node j to node i in batch b, then adjacency_matrix[b, i, j]
      contains the type of that edge as an integer. Otherwise, it contains 0.
    num_edge_types: An integer indicating number of edge types.
    num_transforms: An integer indicating number of transforms (T). If None,
      then num_transforms will be equal to num_edge_types.
    use_weighted_sum: If False, will only use a single transform per edge type.
      Otherwise, use a learned weighted sum of transforms per edge type.
    name: A string.

  Returns:
    The result of the attention transformation. The output shape is [B, N, O].

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    with tf.variable_scope(name,
                           default_name="multihead_mpnn_attention",
                           values=[node_states]):
        # If not explicitly set, default num_transforms to num_edge_types.
        num_transforms = (num_edge_types
                          if num_transforms is None else num_transforms)

        # Create the query for each node's incoming edges.
        # Create the keys/values for each node for each possible outgoing edge type.
        q, k, v = compute_mpnn_qkv(node_states, total_key_depth,
                                   total_value_depth, num_transforms)

        q_shape = tf.shape(q)  # As above, q_shape is [B, N, K].

        # Divides each query/key/value into separate heads. Specifically, the
        # query/key/value for each (batch, node) pair (i.e., the third dimensions
        # of q, k, and v) are broken into H separate pieces. These pieces are used
        # as the separate attention heads. The resulting tensors have shape
        # [B, H, N, ?/H], where ? = K, K*T or V*T as appropriate.
        q = common_attention.split_heads(q, num_heads)  # Shape [B, H, N, K/H].
        k = common_attention.split_heads(k,
                                         num_heads)  # Shape [B, H, N, K*T/H].
        v = common_attention.split_heads(v,
                                         num_heads)  # Shape [B, H, N, V*T/H].
        key_depth_per_head = total_key_depth // num_heads

        # Ensures that the logits don't have too large of a magnitude.
        q *= key_depth_per_head**-0.5

        # Rearrange the dimensions so that the head is first. This will make
        # subsequent steps easier (we loop over the head).
        q = tf.transpose(q, [1, 0, 2, 3])  # Shape [H, B, N, K/H].
        k = tf.transpose(k, [1, 0, 2, 3])  # Shape [H, B, N, K*T/H].
        v = tf.transpose(v, [1, 0, 2, 3])  # Shape [H, B, N, V*T/H].

        # Split the keys and values into separate per-edge-type keys and values.
        k = tf.reshape(k, [
            num_heads, q_shape[0], q_shape[1], num_transforms,
            total_key_depth // num_heads
        ])  # Shape [H, B, N, T, K/H].
        k = tf.transpose(k, [0, 1, 3, 2, 4])  # Shape [H, B, T, N, K/H].

        v = tf.reshape(v, [
            num_heads, q_shape[0], q_shape[1], num_transforms,
            total_value_depth // num_heads
        ])  # Shape [H, B, N, T, V/H].
        v = tf.transpose(v, [0, 1, 3, 2, 4])  # Shape [H, B, T, N, V/H].

        # Perform attention for each head and combine the results into a list.
        # head_outputs stores a list of tensors, each with shape [1, B, N, V/H].
        # The last dimension contains the values computed for each attention head.
        # Each value was determined by computing attention over all of the
        # incoming edges for node n, weighting the incoming values accordingly,
        # and adding those weighted values together.
        head_outputs = []
        for head_id in range(num_heads):
            output = dot_product_mpnn_attention(
                q[head_id],
                k[head_id],
                v[head_id],
                adjacency_matrix,
                num_edge_types,
                num_transforms=num_transforms,
                use_weighted_sum=use_weighted_sum)

            # Store this result in the list of attention results for each head.
            # The call to expand_dims gives output shape [1, B, N, V/H], which will
            # come in handy when we combine the heads together.
            head_outputs.append(tf.expand_dims(output, axis=0))

        # Combine the heads together into one tensor and rearrange the dimensions.
        x = tf.concat(head_outputs, axis=0)  # Shape [H, B, N, V/H].
        x = tf.transpose(x, [1, 0, 2, 3])  # Shape [B, H, N, V/H].

        # Concatenate the values produced by each head together into one vector.
        x = common_attention.combine_heads(x)  # Shape [B, N, V].

        # A fully-connected linear layer to convert from the value vectors of size V
        # to output vectors of length O (the appropriate output length).
        x = common_layers.dense(x,
                                output_depth,
                                use_bias=False,
                                name="output_transform")
        return x
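
The per-head splitting used above can be checked in isolation. A small sketch, assuming tensor2tensor is installed so that common_attention is importable; the sizes are illustrative.

import tensorflow.compat.v1 as tf  # TF1-style graph APIs (assumption)
from tensor2tensor.layers import common_attention  # assumes tensor2tensor is installed

B, N, K, H = 2, 5, 8, 4                             # batch, nodes, key depth, heads
q = tf.random_normal([B, N, K])
q_heads = common_attention.split_heads(q, H)        # [B, H, N, K/H] = (2, 4, 5, 2)
q_back = common_attention.combine_heads(q_heads)    # back to [B, N, K] = (2, 5, 8)
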
Example #16
def multihead_graph_attention(query_antecedent,
                              memory_antecedent,
                              bias,
                              total_key_depth,
                              total_value_depth,
                              output_depth,
                              num_heads,
                              dropout_rate,
                              image_shapes=None,
                              attention_type="edge_vector",
                              name="multihead_graph_attention",
                              save_weights_to=None,
                              make_image_summary=True,
                              dropout_broadcast_dims=None,
                              adjacency_matrix=None,
                              num_edge_types=5,
                              vars_3d=False,
                              **kwargs):
    """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    adjacency_matrix: an optional tensor of shape [batch, len_q, len_q]
      containing edge vectors for attention
    num_edge_types: number of edge types, an int
    vars_3d: use 3-dimensional variables for input/output transformations
    **kwargs (dict): Parameters for the attention function

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, output_depth]

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    vars_3d_num_heads = num_heads if vars_3d else None
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        q, k, v = common_attention.compute_qkv(
            query_antecedent,
            memory_antecedent,
            total_key_depth,
            total_value_depth,
            vars_3d_num_heads=vars_3d_num_heads)
        q = common_attention.split_heads(q, num_heads)
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)

        key_depth_per_head = total_key_depth // num_heads
        if not vars_3d:
            q *= key_depth_per_head**-0.5

        additional_returned_value = None
        if callable(
                attention_type):  # Generic way to extend multihead_attention
            x = attention_type(q, k, v, **kwargs)
            if isinstance(x, tuple):
                x, additional_returned_value = x  # Unpack

        elif attention_type == "edge_vector":
            x = graph_attention(q,
                                k,
                                v,
                                bias,
                                dropout_rate,
                                image_shapes,
                                save_weights_to=save_weights_to,
                                make_image_summary=make_image_summary,
                                dropout_broadcast_dims=dropout_broadcast_dims,
                                adjacency_matrix=adjacency_matrix,
                                num_edge_types=num_edge_types)

        x = common_attention.combine_heads(x)

        # Set last dim specifically.
        x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

        if vars_3d:
            o_var = tf.get_variable(
                "o", [num_heads, total_value_depth // num_heads, output_depth])
            o_var = tf.reshape(o_var, [total_value_depth, output_depth])
            x = tf.tensordot(x, o_var, axes=1)
        else:
            x = common_layers.dense(x,
                                    output_depth,
                                    use_bias=False,
                                    name="output_transform")
        if additional_returned_value is not None:
            return x, additional_returned_value
        return x
Example #17
def compute_mpnn_qkv(node_states, total_key_depth, total_value_depth,
                     num_transforms):
    """Computes query, key and value for edge matrices.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let T be the total number of transforms (num_transforms).

  Computes the queries, keys, and values for attention.
  * For each node N_i in the graph, a query Q_i of size K is computed. This
    query is used to determine the relative weights to give to each of the
    node's incoming edges.
  * For each node N_j and edge type t, a key K_jt of size K is computed. When an
    edge of type t goes from node N_j to any other node, K_jt is the key that is
    in the attention process.
  * For each node N_j and edge type t, a value V_jt of size V is computed. When
    an edge of type t goes from node N_j to node N_i, Attention(Q_i, K_jt)
    produces a weight w_ijt. The message sent along this edge is w_ijt * V_jt.

  Args:
    node_states: A Tensor with shape [B, N, D].
    total_key_depth: an integer (K).
    total_value_depth: an integer (V).
    num_transforms: an integer specifying the number of transforms (T). This is
      typically the number of edge types.
  Returns:
    q: The attention queries for each destination node (shape [B, N, K]).
    k: The attention keys for each node and edge type (shape [B, N*T, K]).
    v: The attention values for each node and edge type (shape [B, N*T, V]).
  """

    # node_states is initially a tensor with shape [B, N, D]. The call to dense
    # creates a D x K kernel that serves as a fully-connected layer.
    #
    # For each possible batch b and node n in the first two dimensions of
    # node_states, the corresponding size-D vector (the third dimension of
    # node_states) is the hidden state for node n in batch b. Each of these size-D
    # vectors is multiplied by the kernel to produce an attention query of size K.
    # The result is a tensor of size [B, N, K] containing the attention queries
    # for each node in each batch.
    q = common_layers.dense(node_states,
                            total_key_depth,
                            use_bias=False,
                            name="q_mpnn")

    # Creates the attention keys in a manner similar to the process of creating
    # the attention queries. One key is created for each type of outgoing edge the
    # corresponding node might have, so k has shape [B, N*T, K] after the reshape
    # inside _compute_edge_transforms.
    k = _compute_edge_transforms(node_states,
                                 total_key_depth,
                                 num_transforms,
                                 name="k_mpnn")
    v = _compute_edge_transforms(node_states,
                                 total_value_depth,
                                 num_transforms,
                                 name="v_mpnn")

    return q, k, v
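
A shape-check sketch for compute_mpnn_qkv, assuming it and _compute_edge_transforms above are in scope together with tensor2tensor's common_layers; all sizes are illustrative.

import tensorflow.compat.v1 as tf  # TF1-style graph APIs (assumption)

B, N, D = 2, 6, 32                       # batch, nodes, node-state size (illustrative)
K, V, T = 16, 16, 3                      # key depth, value depth, number of transforms

node_states = tf.random_normal([B, N, D])
q, k, v = compute_mpnn_qkv(node_states, total_key_depth=K, total_value_depth=V,
                           num_transforms=T)
print(q.shape, k.shape, v.shape)         # (2, 6, 16) (2, 18, 16) (2, 18, 16)
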
def sparse_message_pass(node_states,
                        adjacency_matrices,
                        num_edge_types,
                        hidden_size,
                        use_bias=True,
                        average_aggregation=False,
                        name="sparse_ggnn"):
  """One message-passing step for a GNN with a sparse adjacency matrix.

  Implements equation 2 (the message passing step) in
  [Li et al. 2015](https://arxiv.org/abs/1511.05493).

  N = The number of nodes in each batch.
  H = The size of the hidden states.
  T = The number of edge types.

  Args:
    node_states: Initial states of each node in the graph. Shape is [N, H].
    adjacency_matrices: Adjacency matrix of directed edges for each edge
      type. Shape is [N, N, T] (sparse tensor).
    num_edge_types: The number of edge types. T.
    hidden_size: The size of the hidden state. H.
    use_bias: Whether to use bias in the hidden layer.
    average_aggregation: How to aggregate the incoming node messages. If
      average_aggregation is true, the messages are averaged. If it is false,
      they are summed.
    name: (optional) The scope within which tf variables should be created.

  Returns:
    The result of one step of Gated Graph Neural Network (GGNN) message passing.
    Shape: [N, H]
  """
  n = tf.shape(node_states)[0]
  t = num_edge_types
  incoming_edges_per_type = tf.sparse_reduce_sum(adjacency_matrices, axis=1)

  # Convert the adjacency matrix into shape [T, N, N] - one [N, N] adjacency
  # matrix for each edge type. Since sparse tensor multiplication only supports
  # two-dimensional tensors, we actually convert the adjacency matrix into a
  # [T * N, N] tensor.
  adjacency_matrices = tf.sparse_transpose(adjacency_matrices, [2, 0, 1])
  adjacency_matrices = tf.sparse_reshape(adjacency_matrices, [t * n, n])

  # Multiply the adjacency matrix by the node states, producing a [T * N, H]
  # tensor. For each (edge type, node) pair, this tensor stores the sum of
  # the hidden states of the node's neighbors over incoming edges of that type.
  messages = tf.sparse_tensor_dense_matmul(adjacency_matrices, node_states)

  # Rearrange this tensor to have shape [N, T * H]. The incoming states of each
  # node's neighbors are summed by edge type and then concatenated together into
  # a single T * H vector.
  messages = tf.reshape(messages, [t, n, hidden_size])
  messages = tf.transpose(messages, [1, 0, 2])
  messages = tf.reshape(messages, [n, t * hidden_size])

  # Run each of those T * H vectors through a linear layer that produces
  # a vector of size H. This process is equivalent to running each H-sized
  # vector through a separate linear layer for each edge type and then adding
  # the results together.
  #
  # Note that, earlier on, we added together all of the states of neighbors
  # that were connected by edges of the same edge type. Since multiplying by a
  # linear layer distributes over addition, this process is equivalent to
  # running each incoming edge through a linear layer separately and then
  # adding everything at the end.
  with tf.variable_scope(name, default_name="sparse_ggnn"):
    final_node_states = common_layers.dense(
        messages, hidden_size, use_bias=False)

    # Multiply the bias for each edge type by the number of incoming edges
    # of that type.
    if use_bias:
      bias = tf.get_variable("bias", initializer=tf.zeros([t, hidden_size]))
      final_node_states += tf.matmul(incoming_edges_per_type, bias)

    if average_aggregation:
      incoming_edges = tf.reduce_sum(incoming_edges_per_type, -1, keepdims=True)
      incoming_edges = tf.tile(incoming_edges, [1, hidden_size])
      final_node_states /= incoming_edges + 1e-7

  return final_node_states
Example #19
    def body(self, features):
        hp = self.hparams
        # pylint: disable=eval-used
        if hp.image_input_type == "image":
            image_feat = vqa_layers.image_embedding(
                features["inputs"],
                model_fn=eval(hp.image_model_fn),
                trainable=hp.train_resnet,
                is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
        else:
            image_feat = features["inputs"]

        image_feat = common_layers.flatten4d3d(image_feat)
        image_hidden_size = hp.image_hidden_size or hp.hidden_size
        if hp.image_feat_preprocess_proj:
            image_feat = common_layers.dense(image_feat, image_hidden_size)
            utils.collect_named_outputs("norms", "image_feat_after_proj",
                                        tf.norm(image_feat, axis=-1))
        else:
            assert image_hidden_size == 2048

        image_feat = tf.nn.dropout(image_feat,
                                   keep_prob=1. -
                                   hp.layer_prepostprocess_dropout)

        if hp.image_feat_encode:
            image_feat = image_encoder(image_feat, hp)
            utils.collect_named_outputs("norms", "image_feat_encoded",
                                        tf.norm(image_feat, axis=-1))
        else:
            image_feat = common_layers.layer_norm(image_feat)
            utils.collect_named_outputs("norms", "image_feat_after_layer",
                                        tf.norm(image_feat, axis=-1))

        question = common_layers.flatten4d3d(features["question"])
        utils.collect_named_outputs("norms", "question_embedding",
                                    tf.norm(question, axis=-1))
        question, question_self_attention_bias = prepare_question_encoder(
            question, hp)
        question = tf.nn.dropout(question,
                                 keep_prob=1. -
                                 hp.layer_prepostprocess_dropout)
        query = question_encoder(question, question_self_attention_bias, hp)
        utils.collect_named_outputs("norms", "query_encode",
                                    tf.norm(query, axis=-1))
        query = (query + tf.expand_dims(
            tf.squeeze(question_self_attention_bias, [1, 2]), axis=2))
        query = tf.reduce_max(query, axis=1)
        utils.collect_named_outputs("norms", "query_maxpool",
                                    tf.norm(query, axis=-1))

        # query = common_layers.l2_norm(query)
        # utils.collect_named_outputs("norms", "query_after_l2",
        #                             tf.norm(query, axis=-1))

        image_ave = attn(image_feat, query, hp)
        utils.collect_named_outputs("norms", "image_ave",
                                    tf.norm(image_ave, axis=-1))

        if hp.multimodal_combine == "concat":
            image_question = tf.concat([image_ave, query], axis=1)
        elif hp.multimodal_combine == "sum":
            image_question = image_ave + query
        elif hp.multimodal_combine == "product":
            image_question = image_ave * query

        utils.collect_named_outputs("norms", "image_question",
                                    tf.norm(image_question, axis=-1))

        image_question = tf.nn.dropout(image_question, 1. - hp.dropout)

        output = mlp(image_question, hp)
        utils.collect_named_outputs("norms", "output", tf.norm(output,
                                                               axis=-1))

        norm_tensors = utils.convert_collection_to_dict("norms")
        vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

        # Expand dimension 1 and 2
        return tf.expand_dims(tf.expand_dims(output, axis=1), axis=2)
Example #20
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        shared_rel=False,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        max_length=None,
                        vars_3d=False,
                        scale_dotproduct=True,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    shared_rel: boolean to share relative embeddings
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to look
                       at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    max_length: an integer - needed by relative attention
    vars_3d: use 3-dimensional variables for input/output transformations
    scale_dotproduct: whether to normalize the attention product.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns an additional loss parameters (ex: load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  vars_3d_num_heads = num_heads if vars_3d else 0
  with tf.variable_scope(name, default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):

    if cache is None or memory_antecedent is None:
      q, k, v = common_attention.compute_qkv(
          query_antecedent, memory_antecedent,
          total_key_depth, total_value_depth, q_filter_width,
          kv_filter_width, q_padding, kv_padding,
          vars_3d_num_heads=vars_3d_num_heads)
    if cache is not None:
      if attention_type != "dot_product":
        # TODO(petershaw): Support caching when using relative position
        # representations, i.e. "dot_product_relative" attention.
        raise NotImplementedError(
            "Caching is not guaranteed to work with attention types other than"
            " dot_product.")
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")

      if memory_antecedent is not None:
        # Encoder-Decoder Attention Cache
        q = common_attention.compute_attention_component(
            query_antecedent, total_key_depth,
            q_filter_width, q_padding, "q",
            vars_3d_num_heads=vars_3d_num_heads)
        k = cache["k_encdec"]
        v = cache["v_encdec"]
      else:
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        decode_loop_step = kwargs.get("decode_loop_step")
        if decode_loop_step is None:
          k = cache["k"] = tf.concat([cache["k"], k], axis=2)
          v = cache["v"] = tf.concat([cache["v"], v], axis=2)
        else:
          # Inplace update is required for inference on TPU.
          # Inplace_ops only supports inplace_update on the first dimension.
          # The performance of current implementation is better than updating
          # the tensor by adding the result of matmul(one_hot,
          # update_in_current_step)
          tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
          tmp_k = inplace_ops.alias_inplace_update(
              tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
          k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
          tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
          tmp_v = inplace_ops.alias_inplace_update(
              tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
          v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

    q = common_attention.split_heads(q, num_heads)
    if cache is None:
      k = common_attention.split_heads(k, num_heads)
      v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      if scale_dotproduct:
        q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack
    elif attention_type == "dot_product":
      x = common_attention.dot_product_attention(
          q, k, v, bias, dropout_rate, image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "dot_product_relative":
      x = common_attention.dot_product_attention_relative(
          q,
          k,
          v,
          bias,
          max_relative_position,
          dropout_rate,
          image_shapes,
          make_image_summary=make_image_summary)
    elif attention_type == "dot_product_relative_v2":
      x = common_attention.dot_product_self_attention_relative_v2(
          q,
          k,
          v,
          bias,
          max_length,
          dropout_rate,
          image_shapes,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "local_within_block_mask_right":
      x = common_attention.masked_within_block_local_attention_1d(
          q, k, v, block_length=block_length)
    elif attention_type == "rel_local_mask_right":
      x = common_attention.masked_rel_local_attention_1d(
          q, k, v, block_length=block_length,
          make_image_summary=make_image_summary,
          dropout_rate=dropout_rate,
          share_rel_embed=shared_rel)
    elif attention_type == "local_mask_right":
      x = common_attention.masked_local_attention_1d(
          q,
          k,
          v,
          block_length=block_length,
          make_image_summary=make_image_summary)
    elif attention_type == "local_unmasked":
      x = common_attention.local_attention_1d(
          q, k, v, block_length=block_length, filter_width=block_width)
    elif attention_type == "masked_dilated_1d":
      x = common_attention.masked_dilated_self_attention_1d(
          q, k, v, block_length, block_width,
          gap_size, num_memory_blocks)
    else:
      assert attention_type == "unmasked_dilated_1d"
      x = common_attention.dilated_self_attention_1d(
          q, k, v, block_length, block_width,
          gap_size, num_memory_blocks)
    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if vars_3d:
      o_var = tf.get_variable(
          "o", [num_heads, total_value_depth // num_heads, output_depth])
      o_var = tf.cast(o_var, x.dtype)
      o_var = tf.reshape(o_var, [total_value_depth, output_depth])
      x = tf.tensordot(x, o_var, axes=1)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
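Besides the string-valued attention types, the dispatch above accepts any callable with the signature (query, key, value, **kwargs); q, k, and v arrive already split into heads ([batch, num_heads, length, depth // num_heads]) and the result is passed through combine_heads. A minimal sketch of that extension point, using a toy stand-in attention (illustrative only, not part of the original source; `hidden_size` and the antecedent tensors are assumed):

import tensorflow as tf

def mean_pool_attention(q, k, v, **kwargs):
  # Ignore q and k: every position simply receives the mean of the
  # (already head-split) values. A stand-in, not a real attention scheme.
  del q, k, kwargs
  return tf.zeros_like(v) + tf.reduce_mean(v, axis=2, keepdims=True)

# Hypothetical call; multihead_attention is the function shown above.
# out = multihead_attention(
#     query_antecedent, None, bias=None,
#     total_key_depth=hidden_size, total_value_depth=hidden_size,
#     output_depth=hidden_size, num_heads=8, dropout_rate=0.0,
#     attention_type=mean_pool_attention)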
Example #21
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        shared_rel=False,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        max_length=None,
                        vars_3d=False,
                        scale_dotproduct=True,
                        **kwargs):
    """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    shared_rel: boolean to share relative embeddings
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
               no padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
                no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to look
                       at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    max_length: an integer - needed by relative attention
    vars_3d: use 3-dimensional variables for input/output transformations
    scale_dotproduct: whether to normalize the attention product.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns additional loss parameters (e.g. a load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    vars_3d_num_heads = num_heads if vars_3d else 0
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        if cache is None or memory_antecedent is None:
            q, k, v = common_attention.compute_qkv(
                query_antecedent,
                memory_antecedent,
                total_key_depth,
                total_value_depth,
                q_filter_width,
                kv_filter_width,
                q_padding,
                kv_padding,
                vars_3d_num_heads=vars_3d_num_heads)
        if cache is not None:
            if attention_type != "dot_product":
                # TODO(petershaw): Support caching when using relative position
                # representations, i.e. "dot_product_relative" attention.
                raise NotImplementedError(
                    "Caching is not guaranteed to work with attention types other than"
                    " dot_product.")
            if bias is None:
                raise ValueError(
                    "Bias required for caching. See function docstring "
                    "for details.")

            if memory_antecedent is not None:
                # Encoder-Decoder Attention Cache
                q = common_attention.compute_attention_component(
                    query_antecedent,
                    total_key_depth,
                    q_filter_width,
                    q_padding,
                    "q",
                    vars_3d_num_heads=vars_3d_num_heads)
                k = cache["k_encdec"]
                v = cache["v_encdec"]
            else:
                k = common_attention.split_heads(k, num_heads)
                v = common_attention.split_heads(v, num_heads)
                decode_loop_step = kwargs.get("decode_loop_step")
                if decode_loop_step is None:
                    k = cache["k"] = tf.concat([cache["k"], k], axis=2)
                    v = cache["v"] = tf.concat([cache["v"], v], axis=2)
                else:
                    # Inplace update is required for inference on TPU.
                    # Inplace_ops only supports inplace_update on the first dimension.
                    # The performance of the current implementation is better than updating
                    # the tensor by adding the result of matmul(one_hot,
                    # update_in_current_step)
                    tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
                    tmp_k = inplace_ops.alias_inplace_update(
                        tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
                    k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
                    tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
                    tmp_v = inplace_ops.alias_inplace_update(
                        tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
                    v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

        q = common_attention.split_heads(q, num_heads)
        if cache is None:
            k = common_attention.split_heads(k, num_heads)
            v = common_attention.split_heads(v, num_heads)

        key_depth_per_head = total_key_depth // num_heads
        if not vars_3d:
            if scale_dotproduct:
                q *= key_depth_per_head**-0.5

        additional_returned_value = None
        if callable(
                attention_type):  # Generic way to extend multihead_attention
            x = attention_type(q, k, v, **kwargs)
            if isinstance(x, tuple):
                x, additional_returned_value = x  # Unpack
        elif attention_type == "dot_product":
            x = common_attention.dot_product_attention(
                q,
                k,
                v,
                bias,
                dropout_rate,
                image_shapes,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=dropout_broadcast_dims)
        elif attention_type == "dot_product_relative":
            x = common_attention.dot_product_attention_relative(
                q,
                k,
                v,
                bias,
                max_relative_position,
                dropout_rate,
                image_shapes,
                make_image_summary=make_image_summary)
        elif attention_type == "dot_product_relative_v2":
            x = common_attention.dot_product_self_attention_relative_v2(
                q,
                k,
                v,
                bias,
                max_length,
                dropout_rate,
                image_shapes,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=dropout_broadcast_dims)
        elif attention_type == "local_within_block_mask_right":
            x = common_attention.masked_within_block_local_attention_1d(
                q, k, v, block_length=block_length)
        elif attention_type == "rel_local_mask_right":
            x = common_attention.masked_rel_local_attention_1d(
                q,
                k,
                v,
                block_length=block_length,
                make_image_summary=make_image_summary,
                dropout_rate=dropout_rate,
                share_rel_embed=shared_rel)
        elif attention_type == "local_mask_right":
            x = common_attention.masked_local_attention_1d(
                q,
                k,
                v,
                block_length=block_length,
                make_image_summary=make_image_summary)
        elif attention_type == "local_unmasked":
            x = common_attention.local_attention_1d(q,
                                                    k,
                                                    v,
                                                    block_length=block_length,
                                                    filter_width=block_width)
        elif attention_type == "masked_dilated_1d":
            x = common_attention.masked_dilated_self_attention_1d(
                q, k, v, block_length, block_width, gap_size,
                num_memory_blocks)
        else:
            assert attention_type == "unmasked_dilated_1d"
            x = common_attention.dilated_self_attention_1d(
                q, k, v, block_length, block_width, gap_size,
                num_memory_blocks)
        x = common_attention.combine_heads(x)

        # Set last dim specifically.
        x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

        if vars_3d:
            o_var = tf.get_variable(
                "o", [num_heads, total_value_depth // num_heads, output_depth])
            o_var = tf.cast(o_var, x.dtype)
            o_var = tf.reshape(o_var, [total_value_depth, output_depth])
            x = tf.tensordot(x, o_var, axes=1)
        else:
            x = common_layers.dense(x,
                                    output_depth,
                                    use_bias=False,
                                    name="output_transform")
        if additional_returned_value is not None:
            return x, additional_returned_value
        return x
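A minimal sketch of the incremental-decoding use described in the Caching section of the docstring above: only the newest query position is fed in, and keys/values accumulate in `cache`. The cache layout used here ([batch, num_heads, length, depth // num_heads]) is what the axis-2 concatenation after split_heads in the code implies; all names and sizes below are illustrative.

import tensorflow as tf

batch_size, num_heads, hidden_size = 2, 8, 512
cache = {
    "k": tf.zeros([batch_size, num_heads, 0, hidden_size // num_heads]),
    "v": tf.zeros([batch_size, num_heads, 0, hidden_size // num_heads]),
}
step_input = tf.zeros([batch_size, 1, hidden_size])  # newest position only
# With only past positions in the cache there is nothing to mask, so a
# zero bias is used here purely for illustration.
bias = tf.zeros([1, 1, 1, 1])
out = multihead_attention(
    step_input, None, bias,
    total_key_depth=hidden_size, total_value_depth=hidden_size,
    output_depth=hidden_size, num_heads=num_heads, dropout_rate=0.0,
    cache=cache)  # out: [batch_size, 1, hidden_size]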
def multihead_mpnn_attention(node_states,
                             total_key_depth,
                             total_value_depth,
                             output_depth,
                             num_heads,
                             adjacency_matrix=None,
                             num_edge_types=5,
                             num_transforms=None,
                             use_weighted_sum=False,
                             name="mpnn_attention"):
  """Multihead scaled-dot-product attention with input/output transformations.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let O be the size of the attention output (output_depth).
  Let H be the number of heads (num_heads).
  Let T be the total number of transforms (num_transforms).

  The key and value depths are split across all of the heads. For example, if
  the key depth is 6 and there are three heads, then the key for each head has
  depth 2.

  Args:
    node_states: A Tensor with shape [B, N, D]
    total_key_depth: An integer (K).
    total_value_depth: An integer (V).
    output_depth: An integer (O).
    num_heads: An integer (H).
    adjacency_matrix: A Tensor of ints with shape [B, T, N, N]. If there is an
      edge from node j to node i in batch b, then adjacency_matrix[b, i, j]
      contains the type of that edge as an integer. Otherwise, it contains 0.
    num_edge_types: An integer indicating number of edge types.
    num_transforms: An integer indicating number of transforms (T). If None,
      then num_transforms will be equal to num_edge_types.
    use_weighted_sum: If False, will only use a single transform per edge type.
      Otherwise, use a learned weighted sum of transforms per edge type.
    name: A string.

  Returns:
    The result of the attention transformation. The output shape is [B, N, O].

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  with tf.variable_scope(
      name, default_name="multihead_mpnn_attention", values=[node_states]):
    # If not explicitly set, use num_transforms set to num_edge_types.
    num_transforms = (
        num_edge_types if num_transforms is None else num_transforms)

    # Create the query for each node's incoming edges.
    # Create the keys/values for each node for each possible outgoing edge type.
    q, k, v = compute_mpnn_qkv(
        node_states,
        total_key_depth,
        total_value_depth,
        num_transforms)

    q_shape = tf.shape(q)  # As above, q_shape is [B, N, K].

    # Divides each query/key/value into separate heads. Specifically, the
    # query/key/value for each (batch, node) pair (i.e., the third dimensions
    # of q, k, and v) are broken into H separate pieces. These pieces are used
    # as the separate attention heads. The resulting tensors have shape
    # [B, H, N, ?/H], where ? = K, K*T or V*T as appropriate.
    q = common_attention.split_heads(q, num_heads)  # Shape [B, H, N, K/H].
    k = common_attention.split_heads(k, num_heads)  # Shape [B, H, N, K*T/H].
    v = common_attention.split_heads(v, num_heads)  # Shape [B, H, N, V*T/H].
    key_depth_per_head = total_key_depth // num_heads

    # Ensures that the logits don't have too large of a magnitude.
    q *= key_depth_per_head**-0.5

    # Rearrange the dimensions so that the head is first. This will make
    # subsequent steps easier (we loop over the head).
    q = tf.transpose(q, [1, 0, 2, 3])  # Shape [H, B, N, K/H].
    k = tf.transpose(k, [1, 0, 2, 3])  # Shape [H, B, N, K*T/H].
    v = tf.transpose(v, [1, 0, 2, 3])  # Shape [H, B, N, V*T/H].

    # Split the keys and values into separate per-edge-type keys and values.
    k = tf.reshape(k, [
        num_heads, q_shape[0], q_shape[1], num_transforms,
        total_key_depth // num_heads
    ])  # Shape [H, B, N, T, K/H].
    k = tf.transpose(k, [0, 1, 3, 2, 4])  # Shape [H, B, T, N, K/H].

    v = tf.reshape(v, [
        num_heads, q_shape[0], q_shape[1], num_transforms,
        total_value_depth // num_heads
    ])  # Shape [H, B, N, T, V/H].
    v = tf.transpose(v, [0, 1, 3, 2, 4])  # Shape [H, B, T, N, V/H].

    # Perform attention for each head and combine the results into a list.
    # head_outputs stores a list of tensors, each with shape [1, B, N, V/H].
    # The last dimension contains the values computed for each attention head.
    # Each value was determined by computing attention over all of the
    # incoming edges for node n, weighting the incoming values accordingly,
    # and adding those weighted values together.
    head_outputs = []
    for head_id in range(num_heads):
      output = dot_product_mpnn_attention(
          q[head_id],
          k[head_id],
          v[head_id],
          adjacency_matrix,
          num_edge_types,
          num_transforms=num_transforms,
          use_weighted_sum=use_weighted_sum)

      # Store this result in the list of attention results for each head.
      # The call to expand_dims gives output shape [1, B, N, V/H], which will
      # come in handy when we combine the heads together.
      head_outputs.append(tf.expand_dims(output, axis=0))

    # Combine the heads together into one tensor and rearrange the dimensions.
    x = tf.concat(head_outputs, axis=0)  # Shape [H, B, N, V/H].
    x = tf.transpose(x, [1, 0, 2, 3])  # Shape [B, H, N, V/H].

    # Concatenate the values produced by each head together into one vector.
    x = common_attention.combine_heads(x)  # Shape [B, N, V].

    # A fully-connected linear layer to convert from the value vectors of size V
    # to output vectors of length O (the appropriate output length).
    x = common_layers.dense(
        x, output_depth, use_bias=False, name="output_transform")
    return x
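A short sketch of the reshape/transpose bookkeeping performed above, applied to a toy key tensor so the intermediate shapes are easy to follow (sizes are arbitrary; common_attention is the same module used above):

import tensorflow as tf

B, N, H, T, K = 2, 3, 2, 4, 8                  # toy sizes
k = tf.zeros([B, N, K * T])                    # per-node keys, all edge types
k = common_attention.split_heads(k, H)         # [B, H, N, K*T/H]
k = tf.transpose(k, [1, 0, 2, 3])              # [H, B, N, K*T/H]
k = tf.reshape(k, [H, B, N, T, K // H])        # [H, B, N, T, K/H]
k = tf.transpose(k, [0, 1, 3, 2, 4])           # [H, B, T, N, K/H]
# k[head] is now a per-edge-type key bank of shape [B, T, N, K/H], ready
# for dot_product_mpnn_attention as in the per-head loop above.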
def multihead_graph_attention(query_antecedent,
                              memory_antecedent,
                              bias,
                              total_key_depth,
                              total_value_depth,
                              output_depth,
                              num_heads,
                              dropout_rate,
                              image_shapes=None,
                              attention_type="edge_vector",
                              name="multihead_graph_attention",
                              save_weights_to=None,
                              make_image_summary=True,
                              dropout_broadcast_dims=None,
                              adjacency_matrix=None,
                              num_edge_types=5,
                              vars_3d=False,
                              **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    adjacency_matrix: an optional tensor of shape [batch, len_q, len_q]
      containing edge vectors for attention
    num_edge_types: number of edge types, an int
    vars_3d: use 3-dimensional variables for input/output transformations
    **kwargs (dict): Parameters for the attention function

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, output_depth]

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  vars_3d_num_heads = num_heads if vars_3d else 0
  with tf.variable_scope(
      name,
      default_name="multihead_attention",
      values=[query_antecedent, memory_antecedent]):

    q, k, v = common_attention.compute_qkv(
        query_antecedent,
        memory_antecedent,
        total_key_depth,
        total_value_depth,
        vars_3d_num_heads=vars_3d_num_heads)
    q = common_attention.split_heads(q, num_heads)
    k = common_attention.split_heads(k, num_heads)
    v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack

    elif attention_type == "edge_vector":
      x = graph_attention(
          q,
          k,
          v,
          bias,
          dropout_rate,
          image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims,
          adjacency_matrix=adjacency_matrix,
          num_edge_types=num_edge_types)

    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if vars_3d:
      o_var = tf.get_variable(
          "o", [num_heads, total_value_depth // num_heads, output_depth])
      o_var = tf.reshape(o_var, [total_value_depth, output_depth])
      x = tf.tensordot(x, o_var, axes=1)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
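A minimal usage sketch for the graph variant, assuming self-attention over integer-typed edges as described in the docstring (shapes and values are placeholders; graph_attention and common_attention are the modules used above):

import tensorflow as tf

batch, length, hidden = 2, 5, 64
nodes = tf.zeros([batch, length, hidden])
# adjacency[b, i, j] is the integer edge type between positions j and i.
adjacency = tf.zeros([batch, length, length], dtype=tf.int32)
bias = common_attention.attention_bias_ignore_padding(
    tf.zeros([batch, length]))  # no padded positions in this toy example
out = multihead_graph_attention(
    nodes, None, bias,
    total_key_depth=hidden, total_value_depth=hidden, output_depth=hidden,
    num_heads=4, dropout_rate=0.0,
    adjacency_matrix=adjacency, num_edge_types=3)
# out: [batch, length, hidden]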
def compute_mpnn_qkv(node_states,
                     total_key_depth,
                     total_value_depth,
                     num_edge_types,
                     ignore_zero=True):
    """Computes query, key and value for edge matrices.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let T be the total number of edge types (num_edge_types).

  Computes the queries, keys, and values for attention.
  * For each node N_i in the graph, a query Q_i of size K is computed. This
    query is used to determine the relative weights to give to each of the
    node's incoming edges.
  * For each node N_j and edge type t, a key K_jt of size K is computed. When an
    edge of type t goes from node N_j to any other node, K_jt is the key that is
    used in the attention process.
  * For each node N_j and edge type t, a value V_jt of size V is computed. When
    an edge of type t goes from node N_j to node N_i, Attention(Q_i, K_jt)
    produces a weight w_ijt. The message sent along this edge is w_ijt * V_jt.

  Args:
    node_states: A Tensor with shape [B, N, D].
    total_key_depth: an integer (K).
    total_value_depth: an integer (V).
    num_edge_types: an integer specifying the number of edge types (T).
    ignore_zero: If true, then edge type 0 will not be considered. Equivalent
      to having a linear transformation of all 0's for edge type 0. All queries,
      keys, and values for edge type 0 will be all 0's.
  Returns:
    q: The attention queries for each destination node (shape [B, N, K]).
    k: The attention keys for each node and edge type (shape [B, N*T, K]).
    v: The attention values for each node and edge type (shape [B, N*T, V]).
  """

    # node_states is initially a tensor with shape [B, N, D]. The call to dense
    # creates a D x K kernel that serves as a fully-connected layer.
    #
    # For each possible batch b and node n in the first two dimensions of
    # node_states, the corresponding size-D vector (the third dimension of
    # node_states) is the hidden state for node n in batch b. Each of these size-D
    # vectors is multiplied by the kernel to produce an attention query of size K.
    # The result is a tensor of size [B, N, K] containing the attention queries
    # for each node in each batch.
    q = common_layers.dense(node_states,
                            total_key_depth,
                            use_bias=False,
                            name="q_mpnn")

    q_shape = common_layers.shape_list(q)  # As above, q_shape = [B, N, K].

    # T (or T-1 if ignore_zero).
    nonignored_edge_types = num_edge_types - int(ignore_zero)

    # Creates the attention keys in a manner similar to the process of creating
    # the attention queries. One key is created for each type of outgoing edge the
    # corresponding node might have, meaning k will have shape [B, N, K*T].
    k = common_layers.dense(node_states,
                            total_key_depth * nonignored_edge_types,
                            use_bias=False,
                            name="k_mpnn")

    # The values over which self-attention is performed. They are created in
    # a manner largely identical to that of the keys.
    v = common_layers.dense(node_states,
                            total_value_depth * nonignored_edge_types,
                            use_bias=False,
                            name="v_mpnn")

    batch = q_shape[0]  # B.
    length = q_shape[1]  # N.

    # Making the fourth dimension explicit by separating the vectors of size
    # K*T (in k) and V*T (in v) into per-node matrices of shape [T, K]
    # (in k) and [T, V] (in v).
    #
    # This reshape is only necessary when ignore_zero is True (for the padding
    # step that follows).
    k = tf.reshape(k, [batch, length, nonignored_edge_types, total_key_depth])
    v = tf.reshape(
        v, [q_shape[0], q_shape[1], nonignored_edge_types, total_value_depth])

    # If we previously ignored edge type 0, then we need to pad the keys and
    # values to take this additional edge type into account. To do so, we
    # pad the third dimension of k and v (which has size T-1 if ignore_zero is
    # True) to size T with zeroes.
    if ignore_zero:
        k = tf.pad(k, [[0, 0], [0, 0], [1, 0], [0, 0]])
        v = tf.pad(v, [[0, 0], [0, 0], [1, 0], [0, 0]])

    # Flatten out the fourth dimension.
    k = tf.reshape(k,
                   [q_shape[0], q_shape[1] * num_edge_types, total_key_depth])
    v = tf.reshape(
        v, [q_shape[0], q_shape[1] * num_edge_types, total_value_depth])
    return q, k, v
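A tiny sketch of what the ignore_zero padding above does: a zero slice is inserted at position 0 of the edge-type axis, so every key and value for edge type 0 is all zeros, and the subsequent reshape flattens the node and edge-type axes together (toy sizes, illustrative only):

import tensorflow as tf

k = tf.ones([1, 2, 2, 3])                        # [B, N, T-1, K] with T=3
k = tf.pad(k, [[0, 0], [0, 0], [1, 0], [0, 0]])  # [B, N, T, K]; k[:, :, 0] == 0
k = tf.reshape(k, [1, 2 * 3, 3])                 # [B, N*T, K]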
Example #25
def compute_attention_component(antecedent,
                                total_depth,
                                filter_width=1,
                                padding="VALID",
                                name="c",
                                vars_3d_num_heads=0,
                                sparsity_technique=None,
                                threshold=3.0,
                                training=True,
                                clip_alpha=None,
                                initial_sparsity=None,
                                split_heads=False,
                                num_heads=None):
    """Computes attention compoenent (query, key or value).

  Args:
    antecedent: a Tensor with shape [batch, length, channels]
    total_depth: an integer
    filter_width: An integer specifying how wide you want the attention
      component to be.
    padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    name: a string specifying scope name.
    vars_3d_num_heads: an optional integer (if we want to use 3d variables)
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket &
      scratch experiments.
    split_heads: Whether to prune each head separately.
    num_heads: The number of heads in the attention module.

  Returns:
    c : [batch, length, depth] tensor
  """
    # We don't support 3d attention variables or filter_width > 1 with sparsity
    # techniques
    assert not sparsity_technique or (not vars_3d_num_heads
                                      and filter_width == 1)

    if vars_3d_num_heads > 0:
        assert filter_width == 1
        input_depth = antecedent.get_shape().as_list()[-1]
        depth_per_head = total_depth // vars_3d_num_heads
        initializer_stddev = input_depth**-0.5
        if "q" in name:
            initializer_stddev *= depth_per_head**-0.5
        var = tf.get_variable(
            name,
            [input_depth, vars_3d_num_heads, total_depth // vars_3d_num_heads],
            initializer=tf.random_normal_initializer(
                stddev=initializer_stddev))
        var = tf.cast(var, antecedent.dtype)
        var = tf.reshape(var, [input_depth, total_depth])
        return tf.tensordot(antecedent, var, axes=1)
    if filter_width == 1:
        if sparsity_technique:
            if split_heads:
                # Prune each heads weights separately so that they are free
                # to have different weight magnitude distributions.
                if num_heads is None:
                    raise ValueError(
                        "`num_heads` must be set for split head pruning.")
                if total_depth % num_heads != 0:
                    raise ValueError(
                        "`total_depth` must be divisible by `num_heads`.")
                input_depth = antecedent.get_shape().as_list()[-1]
                depth_per_head = int(total_depth / num_heads)
                masked_head_weights = []
                for head_id in range(num_heads):
                    head_name = name + "_shard_{}".format(head_id)
                    with tf.variable_scope(head_name) as vs:
                        head_weights = tf.get_variable(
                            "kernel", [input_depth, depth_per_head])
                        masked_head_weights.append(
                            pruning.apply_mask(head_weights, vs))
                component_weights = tf.concat(masked_head_weights, axis=1)

                # compute the full component result
                return tf.tensordot(antecedent, component_weights, axes=1)
            else:
                return common_sparse.dense(
                    antecedent,
                    total_depth,
                    use_bias=False,
                    sparsity_technique=sparsity_technique,
                    threshold=threshold,
                    training=training,
                    clip_alpha=clip_alpha,
                    name=name,
                    initial_sparsity=initial_sparsity)
        else:
            return common_layers.dense(antecedent,
                                       total_depth,
                                       use_bias=False,
                                       name=name)
    else:
        return common_layers.conv1d(antecedent,
                                    total_depth,
                                    filter_width,
                                    padding=padding,
                                    name=name)
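The split_heads branch above keeps one kernel (and one pruning mask) per head and concatenates them along the output axis. The concatenated projection is mathematically identical to projecting each head separately and concatenating the results, which the following NumPy sketch checks (toy sizes, illustrative only):

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(4, 6).astype(np.float32)            # [length, input_depth]
w_heads = [rng.randn(6, 2).astype(np.float32) for _ in range(3)]  # 3 heads
fused = x @ np.concatenate(w_heads, axis=1)        # one fused projection
per_head = np.concatenate([x @ w for w in w_heads], axis=1)
assert np.allclose(fused, per_head)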
Example #26
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        attention_type="dot_product",
                        image_shapes=None,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        vars_3d=False,
                        sparsity_technique=None,
                        threshold=3.0,
                        training=True,
                        clip_alpha=None,
                        initial_sparsity=None,
                        split_heads=False,
                        **kwargs):
    """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
               no padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
                no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    vars_3d: use 3-dimensional variables for input/output transformations
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket &
      scratch experiments.
    split_heads: Whether to prune each head separately.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns additional loss parameters (e.g. a load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    if vars_3d:
        raise ValueError("3d attention variables not supported.")
    if attention_type != "dot_product":
        raise ValueError(
            "Sparse multihead attention only supports dot_product attention.")

    vars_3d_num_heads = 0
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        if cache is None or memory_antecedent is None:
            q, k, v = compute_qkv(query_antecedent,
                                  memory_antecedent,
                                  total_key_depth,
                                  total_value_depth,
                                  q_filter_width,
                                  kv_filter_width,
                                  q_padding,
                                  kv_padding,
                                  vars_3d_num_heads=vars_3d_num_heads,
                                  sparsity_technique=sparsity_technique,
                                  threshold=threshold,
                                  training=training,
                                  clip_alpha=clip_alpha,
                                  initial_sparsity=initial_sparsity,
                                  split_heads=split_heads,
                                  num_heads=num_heads)
        if cache is not None:
            if bias is None:
                raise ValueError(
                    "Bias required for caching. See function docstring "
                    "for details.")

            if memory_antecedent is not None:
                # Encoder-Decoder Attention Cache
                q = compute_attention_component(
                    query_antecedent,
                    total_key_depth,
                    q_filter_width,
                    q_padding,
                    "q",
                    vars_3d_num_heads=vars_3d_num_heads,
                    sparsity_technique=sparsity_technique,
                    threshold=threshold,
                    training=training,
                    clip_alpha=clip_alpha,
                    initial_sparsity=initial_sparsity,
                    split_heads=split_heads,
                    num_heads=num_heads)
                k = cache["k_encdec"]
                v = cache["v_encdec"]
            else:
                k = common_attention.split_heads(k, num_heads)
                v = common_attention.split_heads(v, num_heads)
                decode_loop_step = kwargs.get("decode_loop_step")
                if decode_loop_step is None:
                    k = cache["k"] = tf.concat([cache["k"], k], axis=2)
                    v = cache["v"] = tf.concat([cache["v"], v], axis=2)
                else:
                    # Inplace update is required for inference on TPU.
                    # Inplace_ops only supports inplace_update on the first dimension.
                    # The performance of the current implementation is better than updating
                    # the tensor by adding the result of matmul(one_hot,
                    # update_in_current_step)
                    tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
                    tmp_k = inplace_ops.alias_inplace_update(
                        tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
                    k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
                    tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
                    tmp_v = inplace_ops.alias_inplace_update(
                        tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
                    v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

        q = common_attention.split_heads(q, num_heads)
        if cache is None:
            k = common_attention.split_heads(k, num_heads)
            v = common_attention.split_heads(v, num_heads)

        key_depth_per_head = total_key_depth // num_heads
        if not vars_3d:
            q *= key_depth_per_head**-0.5

        # compute the attention
        x = common_attention.dot_product_attention(
            q,
            k,
            v,
            bias,
            dropout_rate,
            image_shapes,
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary,
            dropout_broadcast_dims=dropout_broadcast_dims)
        x = common_attention.combine_heads(x)

        # Set last dim specifically.
        x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

        if sparsity_technique:
            x = common_sparse.dense(x,
                                    output_depth,
                                    use_bias=False,
                                    sparsity_technique=sparsity_technique,
                                    threshold=threshold,
                                    training=training,
                                    clip_alpha=clip_alpha,
                                    name="output_transform",
                                    initial_sparsity=initial_sparsity)
        else:
            x = common_layers.dense(x,
                                    output_depth,
                                    use_bias=False,
                                    name="output_transform")
        return x
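A minimal sketch of the dense path of this variant (sparsity_technique left at None), for encoder-style self-attention without caching; all sizes and tensors are placeholders:

import tensorflow as tf

batch, length, hidden = 2, 7, 128
x = tf.zeros([batch, length, hidden])
bias = common_attention.attention_bias_ignore_padding(
    tf.zeros([batch, length]))  # no padded positions in this toy example
y = multihead_attention(
    x, None, bias,
    total_key_depth=hidden, total_value_depth=hidden, output_depth=hidden,
    num_heads=8, dropout_rate=0.1)
# y: [batch, length, hidden]. Passing sparsity_technique (plus threshold,
# clip_alpha, etc.) routes the projections through common_sparse.dense instead.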
def multihead_mpnn_attention(node_states,
                             total_key_depth,
                             total_value_depth,
                             output_depth,
                             num_heads,
                             adjacency_matrix=None,
                             num_edge_types=5,
                             ignore_zero=True,
                             name="mpnn_attention"):
    """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    node_states: A tensor of shape [batch, length, depth]
    total_key_depth: An integer for key dimension
    total_value_depth: An integer for value dimensions
    output_depth: An integer for output dimensions
    num_heads: An integer
    adjacency_matrix: A tensor of ints of shape [batch, length, length]
    num_edge_types: An integer indicating number of edge bins
    ignore_zero: A flag that says that edge type 0 should be ignored
    name: A string

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, output_depth]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns additional loss parameters (e.g. a load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    with tf.variable_scope(name,
                           default_name="multihead_mpnn_attention",
                           values=[node_states]):
        q, k, v = compute_mpnn_qkv(node_states,
                                   total_key_depth,
                                   total_value_depth,
                                   num_edge_types,
                                   ignore_zero=ignore_zero)
        # reshaping k and v for head splitting
        q_shape = tf.shape(q)
        q = common_attention.split_heads(q, num_heads)
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        key_depth_per_head = total_key_depth // num_heads
        q *= key_depth_per_head**-0.5
        # make the heads dimension leading. We will loop over heads.
        q = tf.transpose(q, [1, 0, 2, 3])
        k = tf.transpose(k, [1, 0, 2, 3])
        v = tf.transpose(v, [1, 0, 2, 3])
        # putting edge as the dimension after batch for k and v
        # k and v will be [heads, batch, num_edge_types, length, depth]
        k = tf.reshape(k, [
            num_heads, q_shape[0], q_shape[1], num_edge_types,
            total_key_depth // num_heads
        ])
        k = tf.transpose(k, [0, 1, 3, 2, 4])

        v = tf.reshape(v, [
            num_heads, q_shape[0], q_shape[1], num_edge_types,
            total_value_depth // num_heads
        ])
        v = tf.transpose(v, [0, 1, 3, 2, 4])

        # doing attention separately for each head
        head_outputs = []
        for head_id in range(num_heads):
            output = dot_product_mpnn_attention(q[head_id], k[head_id],
                                                v[head_id], adjacency_matrix,
                                                num_edge_types)
            head_outputs.append(tf.expand_dims(output, axis=0))
        # making x = [heads, batch, length, total_value_depth//num_heads]
        x = tf.concat(head_outputs, axis=0)
        x = tf.transpose(x, [1, 0, 2, 3])
        # making x [batch, length, depth]
        x = common_attention.combine_heads(x)
        x = common_layers.dense(x,
                                output_depth,
                                use_bias=False,
                                name="output_transform")
        return x
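A short sketch of how the per-head loop above recombines its outputs: each head yields [B, N, V/H], the leading head axis is moved next to the batch axis, and combine_heads concatenates the depth slices (toy sizes, illustrative only):

import tensorflow as tf

B, N, H, V = 2, 3, 2, 8
head_outputs = [tf.zeros([1, B, N, V // H]) for _ in range(H)]
x = tf.concat(head_outputs, axis=0)         # [H, B, N, V/H]
x = tf.transpose(x, [1, 0, 2, 3])           # [B, H, N, V/H]
x = common_attention.combine_heads(x)       # [B, N, V]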
def _compute(inp, depth, filter_width, padding, name):
  if filter_width == 1:
    return common_layers.dense(inp, depth, use_bias=False, name=name)
  else:
    return common_layers.conv1d(
        inp, depth, filter_width, padding=padding, name=name)
def hierarchical_attention_network_encoder(
        encoder_input,
        encoder_self_attention_bias,
        contexts,
        context_self_attention_biases,
        features,
        hparams,
        name="hierarchical_attention_network_encoder",
        save_weights_to=None,
        make_image_summary=True,
        losses=None):
    input_x = encoder_input
    context_xs = {}
    for context_name in contexts:
        context_xs[context_name] = contexts[context_name]
    context_paddings = {}
    context_nonpaddings = {}
    context_pad_removers = {}

    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        input_padding = common_attention.attention_bias_to_padding(
            encoder_self_attention_bias)
        input_nonpadding = 1.0 - input_padding
        for context_name in context_self_attention_biases:
            context_paddings[
                context_name] = common_attention.attention_bias_to_padding(
                    context_self_attention_biases[context_name])
            context_nonpaddings[
                context_name] = 1.0 - context_paddings[context_name]

        input_pad_remover = None
        for context_name in context_paddings:
            context_pad_removers[context_name] = None
        if hparams.use_pad_remover and not common_layers.is_xla_compiled():
            input_pad_remover = expert_utils.PadRemover(input_padding)
            for context_name in context_paddings:
                context_pad_removers[context_name] = expert_utils.PadRemover(
                    context_paddings[context_name])

        temp_hparam = tf.contrib.training.HParams(
        )  # copy hparams except num_hidden_layers -> num_hidden_layers - 1
        for key, val in hparams.values().items():
            temp_hparam.add_hparam(key, val)
        temp_hparam.set_hparam("num_hidden_layers",
                               hparams.num_hidden_layers - 1)
        encoder_output = transformer_with_contexts_layers.transformer_encoder(
            input_x,
            encoder_self_attention_bias,
            temp_hparam,
            nonpadding=features_to_nonpadding(features, "inputs"),
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary)

        context_encoded_outputs = {}
        for context_name in context_xs:
            context_encoded_outputs[
                context_name] = transformer_with_contexts_layers.transformer_encoder(
                    context_xs[context_name],
                    context_self_attention_biases[context_name],
                    hparams,
                    nonpadding=features_to_nonpadding(features, context_name),
                    save_weights_to=save_weights_to,
                    make_image_summary=make_image_summary)

        with tf.variable_scope('word_abstraction', reuse=tf.AUTO_REUSE):
            encoder_word_level_query = common_layers.dense(
                encoder_output, hparams.hidden_size)  # q_w = f_w(h_t)
            encoder_word_level_abstraction = {}
            for context_name in context_encoded_outputs:
                encoder_word_level_abstraction[
                    context_name] = transformer_with_contexts_layers.multihead_attention(
                        common_layers.layer_preprocess(
                            encoder_word_level_query, hparams),
                        context_encoded_outputs[context_name],
                        context_self_attention_biases[context_name],
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                        max_relative_position=hparams.max_relative_position,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        vars_3d=hparams.get("attention_variables_3d"))  # s^j,

            sentence_information = tf.concat([
                encoder_word_level_abstraction[context_name]
                for context_name in encoder_word_level_abstraction
            ],
                                             axis=1)

        with tf.variable_scope('sentence_abstraction', reuse=tf.AUTO_REUSE):
            encoder_sentence_level_query = common_layers.dense(
                encoder_output, hparams.hidden_size)  # q_s = f_s(h_t)
            context_padding = common_attention.embedding_to_padding(
                sentence_information)
            ignore_padding = common_attention.attention_bias_ignore_padding(
                context_padding)
            contextual_information = transformer_with_contexts_layers.multihead_attention(
                common_layers.layer_preprocess(encoder_sentence_level_query,
                                               hparams),
                sentence_information,
                ignore_padding,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                attention_type=hparams.self_attention_type,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                max_relative_position=hparams.max_relative_position,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d")
            )  # MultiHead(q_s, s^j), [batch, encoder_length, hidden_dim]

            contextual_information = common_layers.dense_relu_dense(
                contextual_information, hparams.filter_size,
                hparams.hidden_size)

        with tf.variable_scope('context_gating', reuse=tf.AUTO_REUSE):
            gate_lambda = tf.nn.sigmoid(
                common_layers.dense(contextual_information,
                                    hparams.hidden_size) +
                common_layers.dense(encoder_output, hparams.hidden_size))
            encoder_output = gate_lambda * encoder_output + (
                1 - gate_lambda) * contextual_information

    return common_layers.layer_preprocess(encoder_output, hparams)
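
# A minimal numpy sketch (hypothetical shapes, no bias terms) of the context
# gate computed in the `context_gating` scope above; the real code uses
# common_layers.dense on [batch, length, hidden] tensors.
import numpy as np


def context_gate(encoder_output, contextual_information, w_h, w_c):
  """Blend encoder states with contextual states via a learned sigmoid gate."""
  gate = 1.0 / (1.0 + np.exp(-(contextual_information @ w_c +
                               encoder_output @ w_h)))
  # gate close to 1 keeps the original encoder state, close to 0 keeps context.
  return gate * encoder_output + (1.0 - gate) * contextual_information


hidden = 4
h = np.random.randn(2, 5, hidden)    # encoder_output: [batch, length, hidden]
c = np.random.randn(2, 5, hidden)    # contextual_information, same shape
gated = context_gate(h, c, np.random.randn(hidden, hidden),
                     np.random.randn(hidden, hidden))  # [2, 5, 4]
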
def compute_mpnn_qkv(node_states,
                     total_key_depth,
                     total_value_depth,
                     num_transforms):
  """Computes query, key and value for edge matrices.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let T be the total number of transforms (num_transforms).

  Computes the queries, keys, and values for attention.
  * For each node N_i in the graph, a query Q_i of size K is computed. This
    query is used to determine the relative weights to give to each of the
    node's incoming edges.
  * For each node N_j and edge type t, a key K_jt of size K is computed. When an
    edge of type t goes from node N_j to any other node, K_jt is the key used in
    the attention process for that edge.
  * For each node N_j and edge type t, a value V_jt of size V is computed. When
    an edge of type t goes from node N_j to node N_i, Attention(Q_i, K_jt)
    produces a weight w_ijt. The message sent along this edge is w_ijt * V_jt.

  Args:
    node_states: A Tensor with shape [B, N, D].
    total_key_depth: an integer (K).
    total_value_depth: an integer (V).
    num_transforms: an integer specifying the number of transforms (T). This is
      typically the number of edge types.
  Returns:
    q: The attention queries for each destination node (shape [B, N, K]).
    k: The attention keys for each node and edge type (shape [B, N*T, K]).
    v: The attention values for each node and edge type (shape [B, N*T, V]).
  """

  # node_states is initially a tensor with shape [B, N, D]. The call to dense
  # creates a D x K kernel that serves as a fully-connected layer.
  #
  # For each possible batch b and node n in the first two dimensions of
  # node_states, the corresponding size-D vector (the third dimension of
  # node_states) is the hidden state for node n in batch b. Each of these size-D
  # vectors is multiplied by the kernel to produce an attention query of size K.
  # The result is a tensor of size [B, N, K] containing the attention queries
  # for each node in each batch.
  q = common_layers.dense(
      node_states, total_key_depth, use_bias=False, name="q_mpnn")

  # Creates the attention keys in a manner similar to the process of creating
  # the attention queries. One key is created for each type of outgoing edge the
  # corresponding node might have, giving k the shape [B, N*T, K] documented
  # above (one size-K key per node per edge type).
  k = _compute_edge_transforms(node_states,
                               total_key_depth,
                               num_transforms,
                               name="k_mpnn")
  v = _compute_edge_transforms(node_states,
                               total_value_depth,
                               num_transforms,
                               name="v_mpnn")

  return q, k, v
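
# A minimal numpy sketch (hypothetical) of how the q, k, v produced above might
# be combined downstream: logits between node i's query and every (node j,
# edge type t) key, masked so that only real edges contribute, then a
# softmax-weighted sum of the corresponding values.
import numpy as np


def mpnn_attention(q, k, v, edge_mask):
  """q: [B, N, K]; k: [B, N*T, K]; v: [B, N*T, V]; edge_mask: [B, N, N*T]."""
  logits = q @ np.transpose(k, (0, 2, 1)) / np.sqrt(q.shape[-1])  # [B, N, N*T]
  logits = np.where(edge_mask > 0, logits, -1e9)  # ignore non-existent edges
  weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)  # softmax over (node, type)
  return weights @ v  # messages aggregated per destination node: [B, N, V]


B, N, T, K, V = 2, 3, 2, 4, 5
q = np.random.randn(B, N, K)
k = np.random.randn(B, N * T, K)
v = np.random.randn(B, N * T, V)
edge_mask = np.random.randint(0, 2, size=(B, N, N * T))
messages = mpnn_attention(q, k, v, edge_mask)  # [2, 3, 5]
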
  def body(self, features):
    hp = self.hparams
    # pylint: disable=eval-used
    if hp.image_input_type == "image":
      image_feat = vqa_layers.image_embedding(
          features["inputs"],
          model_fn=eval(hp.image_model_fn),
          trainable=hp.train_resnet,
          is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
    else:
      image_feat = features["inputs"]

    image_feat = common_layers.flatten4d3d(image_feat)
    image_hidden_size = hp.image_hidden_size or hp.hidden_size
    if hp.image_feat_preprocess_proj:
      image_feat = common_layers.dense(image_feat, image_hidden_size)
      utils.collect_named_outputs("norms", "image_feat_after_proj",
                                  tf.norm(image_feat, axis=-1))
    else:
      assert image_hidden_size == 2048

    image_feat = tf.nn.dropout(
        image_feat, keep_prob=1.-hp.layer_prepostprocess_dropout)

    if hp.image_feat_encode:
      image_feat = image_encoder(image_feat, hp)
      utils.collect_named_outputs("norms", "image_feat_encoded",
                                  tf.norm(image_feat, axis=-1))
    else:
      image_feat = common_layers.layer_norm(image_feat)
      utils.collect_named_outputs("norms", "image_feat_after_layer",
                                  tf.norm(image_feat, axis=-1))

    question = common_layers.flatten4d3d(features["question"])
    utils.collect_named_outputs("norms", "question_embedding",
                                tf.norm(question, axis=-1))
    question, question_self_attention_bias = prepare_question_encoder(
        question, hp)
    question = tf.nn.dropout(
        question, keep_prob=1.-hp.layer_prepostprocess_dropout)
    query = question_encoder(question, question_self_attention_bias, hp)
    utils.collect_named_outputs(
        "norms", "query_encode", tf.norm(query, axis=-1))
    query = (query + tf.expand_dims(
        tf.squeeze(question_self_attention_bias, [1, 2]), axis=2))
    query = tf.reduce_max(query, axis=1)
    utils.collect_named_outputs(
        "norms", "query_maxpool", tf.norm(query, axis=-1))

    # query = common_layers.l2_norm(query)
    # utils.collect_named_outputs("norms", "query_after_l2",
    #                             tf.norm(query, axis=-1))

    image_ave = attn(image_feat, query, hp)
    utils.collect_named_outputs("norms", "image_ave",
                                tf.norm(image_ave, axis=-1))

    if hp.multimodal_combine == "concat":
      image_question = tf.concat([image_ave, query], axis=1)
    elif hp.multimodal_combine == "sum":
      image_question = image_ave + query
    elif hp.multimodal_combine == "product":
      image_question = image_ave * query

    utils.collect_named_outputs("norms", "image_question",
                                tf.norm(image_question, axis=-1))

    image_question = tf.nn.dropout(image_question, 1. - hp.dropout)

    output = mlp(image_question, hp)
    utils.collect_named_outputs("norms", "output",
                                tf.norm(output, axis=-1))

    norm_tensors = utils.convert_collection_to_dict("norms")
    vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

    # Expand dimensions 1 and 2
    return tf.expand_dims(tf.expand_dims(output, axis=1), axis=2)
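
# attn(image_feat, query, hp) is defined elsewhere in this file; a plausible
# (hypothetical) reading is single-query softmax attention pooling over image
# regions, sketched here with numpy rather than the real implementation.
import numpy as np


def attention_pool(image_feat, query):
  """image_feat: [batch, num_regions, hidden]; query: [batch, hidden]."""
  logits = np.einsum("brh,bh->br", image_feat, query)  # score each region
  weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)  # softmax over regions
  return np.einsum("br,brh->bh", weights, image_feat)  # pooled: [batch, hidden]


image_feat = np.random.randn(2, 6, 8)
query = np.random.randn(2, 8)
image_ave = attention_pool(image_feat, query)  # [2, 8]
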
Example #32
0
def _ffn_layer_multi_inputs(inputs_list,
                            hparams,
                            ffn_layer_type="dense",
                            name="ffn",
                            kernel_initializer=None,
                            bias_initializer=None,
                            activation=None,
                            pad_remover=None,
                            preprocess=False,
                            postprocess=False):
    """Implements a Feed-forward layer with multiple inputs, pad-removing, etc.
  Args:
    inputs_list: list of input tensors
    hparams: hyper-parameters
    ffn_layer_type: dense / dense_dropconnect/ dense_relu_dense
    name: name
    kernel_initializer: kernel initializer
    bias_initializer: bias initializer
    activation: activation function
    pad_remover: pad remover
    preprocess: if preprocess the input
    postprocess: if postprocess the output
  Returns:
    a tensor
  Raises:
    ValueError: Unknown ffn_layer type.
  """

    # need at least one input
    num_inputs = len(inputs_list)
    assert num_inputs > 0

    if preprocess and num_inputs == 1:
        inputs_list[0] = common_layers.layer_preprocess(
            inputs_list[0], hparams)

    if postprocess:
        original_inputs = inputs_list[0]

    # the output size is the hidden size of the main input
    main_input = inputs_list[0]
    original_shape = common_layers.shape_list(main_input)
    assert hparams.hidden_size == original_shape[-1]

    # all inputs must have the same shape as the main input
    for inputs in inputs_list:
        main_input.get_shape().assert_is_compatible_with(inputs.get_shape())

    def remove_pads(x):
        x_shape = common_layers.shape_list(x)
        # Collapse `x` across examples, and remove padding positions.
        x = tf.reshape(x, tf.concat([[-1], x_shape[2:]], axis=0))
        x = tf.expand_dims(pad_remover.remove(x), axis=0)
        return x

    if pad_remover:
        for i, inputs in enumerate(inputs_list):
            inputs_list[i] = remove_pads(inputs)

    ffn_inputs = (inputs_list[0] if len(inputs_list) == 1 else tf.concat(
        inputs_list, axis=-1))

    if ffn_layer_type == "dense":
        output = common_layers.dense(ffn_inputs,
                                     hparams.hidden_size,
                                     name=name,
                                     activation=activation,
                                     use_bias=True,
                                     kernel_initializer=kernel_initializer,
                                     bias_initializer=bias_initializer)

    elif ffn_layer_type == "dense_dropconnect":
        output = common_layers.dense_dropconnect(
            ffn_inputs,
            hparams.hidden_size,
            name=name,
            dropconnect_dropout=hparams.dropconnect_dropout,
            output_activation=activation)
        postprocess = False  # no dropout on the output unit

    elif ffn_layer_type == "dense_relu_dense":
        output = common_layers.dense_relu_dense(
            ffn_inputs,
            hparams.filter_size,
            hparams.hidden_size,
            name=name,
            dropout=hparams.relu_dropout,
            output_activation=activation,
        )

    else:
        raise ValueError("Unknown ffn_layer type: %s" % ffn_layer_type)

    if pad_remover:
        # Restore `output` to the shape of the original (padded) inputs.
        output = tf.reshape(pad_remover.restore(tf.squeeze(output, axis=0)),
                            original_shape)

    if postprocess:
        if num_inputs == 1:
            output = common_layers.layer_postprocess(original_inputs, output,
                                                     hparams)
        else:  # only dropout (no residual)
            hp = copy.copy(hparams)
            hp.layer_postprocess_sequence = hp.layer_postprocess_sequence.replace(
                "a", "")
            output = common_layers.layer_postprocess(original_inputs, output,
                                                     hp)

    return output
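
# A minimal numpy sketch (hypothetical) of the pad_remover idea used above:
# flatten the [batch, length, hidden] positions, run the dense layer only on
# non-padded positions, then scatter the results back so padding costs nothing.
import numpy as np


def ffn_without_padding(x, nonpadding, w):
  """x: [batch, length, hidden]; nonpadding: [batch, length] of 0/1; w: [hidden, out]."""
  batch, length, hidden = x.shape
  flat = x.reshape(batch * length, hidden)
  keep = nonpadding.reshape(batch * length).astype(bool)
  out = np.zeros((batch * length, w.shape[-1]))
  out[keep] = flat[keep] @ w  # dense layer applied only to real tokens
  return out.reshape(batch, length, -1)  # padded rows stay zero after restore


x = np.random.randn(2, 4, 3)
nonpadding = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])
w = np.random.randn(3, 5)
y = ffn_without_padding(x, nonpadding, w)  # [2, 4, 5]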