Example #1
    def __init__(self, conf, name="multi_head_attention"):
        """Inits the module.

        Args:
            conf: Configuration holding head_nbr, query_dim, key_dim,
                value_dim and embedding_dim.
            name: The module name.
        """
        super(MultiHeadAttentionResidual, self).__init__(name=name)
        self.conf = conf
        self.training = True
        with self._enter_variable_scope():
            self._query_layer = snt.Linear(
                output_size=self.conf.head_nbr * self.conf.query_dim,
                use_bias=False,
                initializers={'w': utils.initializer(conf.embedding_dim)},
                name="query_computer")
            self._value_layer = snt.Linear(
                output_size=self.conf.head_nbr * self.conf.value_dim,
                use_bias=False,
                initializers={'w': utils.initializer(conf.embedding_dim)},
                name="value_computer")
            self._key_layer = snt.Linear(
                output_size=self.conf.head_nbr * self.conf.key_dim,
                use_bias=False,
                initializers={'w': utils.initializer(conf.embedding_dim)},
                name="key_computer")
            self._graph_mha = modules.SelfAttention("graph_self_attention")
            self._glimpse_l = snt.Linear(
                name='glimpse_linear',
                output_size=self.conf.embedding_dim,
                use_bias=False,
                initializers={'w': utils.initializer(conf.embedding_dim)})
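The constructor above only builds the projection layers. Below is a minimal sketch of a forward pass that wires them into graph_nets' SelfAttention; the method name _build, the per-head reshapes and the glimpse/residual wiring are assumptions, not part of the original example (the same tf/snt imports as the surrounding code are assumed, and query_dim is assumed equal to key_dim so the dot-product attention is well defined).

    def _build(self, node_features, graph):
        """Hypothetical forward pass (sketch only, not from the original source)."""
        n = tf.shape(node_features)[0]
        # Project node features into per-head queries, keys and values.
        queries = tf.reshape(self._query_layer(node_features),
                             [n, self.conf.head_nbr, self.conf.query_dim])
        keys = tf.reshape(self._key_layer(node_features),
                          [n, self.conf.head_nbr, self.conf.key_dim])
        values = tf.reshape(self._value_layer(node_features),
                            [n, self.conf.head_nbr, self.conf.value_dim])
        # graph_nets SelfAttention expects [total_num_nodes, num_heads, dim]
        # tensors plus a GraphsTuple describing the connectivity.
        attended = self._graph_mha(values, keys, queries, graph)
        # Merge heads, project back to embedding_dim and add the residual
        # (node_features is assumed to already have embedding_dim features).
        glimpse = self._glimpse_l(tf.reshape(
            attended.nodes, [n, self.conf.head_nbr * self.conf.value_dim]))
        return node_features + glimpse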
Example #2
  def test_self_attention(self):
    # Just one feature per node.
    values_np = np.arange(sum(self.N_NODE)) + 1.
    # Multiple heads: one with positive values, one with negative values.
    values_np = np.stack([values_np, values_np*-1.], axis=-1)
    # Multiple features per node, per head, at different scales.
    values_np = np.stack([values_np, values_np*0.1], axis=-1)
    values = tf.constant(values_np, dtype=tf.float32)

    keys_np = [
        [[0.3, 0.4]]*2,  # Irrelevant (only sender to one node)
        [[0.1, 0.5]]*2,  # Not used (is not a sender)
        [[1, 0], [0, 1]],
        [[0, 1], [1, 0]],
        [[1, 1], [1, 1]],
        [[0.4, 0.3]]*2,  # Not used (is not a sender)
        [[0.3, 0.2]]*2]  # Not used (is not a sender)
    keys = tf.constant(keys_np, dtype=tf.float32)

    queries_np = [
        [[0.2, 0.7]]*2,  # Not used (is not a receiver)
        [[0.3, 0.2]]*2,  # Irrelevant (only receives from one node)
        [[0.2, 0.8]]*2,  # Not used (is not a receiver)
        [[0.2, 0.4]]*2,  # Not used (is not a receiver)
        [[0.3, 0.9]]*2,  # Not used (is not a receiver)
        [[0, np.log(2)], [np.log(3), 0]],
        [[np.log(2), 0], [0, np.log(3)]]]
    queries = tf.constant(queries_np, dtype=tf.float32)

    attention_graph = graphs.GraphsTuple(
        nodes=None,
        edges=None,
        globals=None,
        receivers=tf.constant(self.RECEIVERS, dtype=tf.int32),
        senders=tf.constant(self.SENDERS, dtype=tf.int32),
        n_node=tf.constant(self.N_NODE, dtype=tf.int32),
        n_edge=tf.constant(self.N_EDGE, dtype=tf.int32),)

    self_attention = modules.SelfAttention()
    output_graph = self_attention(values, keys, queries, attention_graph)
    mixed_nodes = output_graph.nodes

    with self.test_session() as sess:
      mixed_nodes_output = sess.run(mixed_nodes)

    expected_mixed_nodes = [
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[1., 0.1], [-1., -0.1]],  # Only receives from n0.
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[11/3, 11/3*0.1],  # Head one, receives from n2(1/3) n3(2/3)
         [-15/4, -15/4*0.1]],  # Head two, receives from n2(1/4) n3(3/4)
        [[20/5, 20/5*0.1],   # Head one, receives from n2(2/5) n3(1/5) n4(2/5)
         [-28/7, -28/7*0.1]],  # Head two, receives from n2(3/7) n3(1/7) n4(3/7)
    ]

    self.assertAllClose(expected_mixed_nodes, mixed_nodes_output)
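For reference, the expected value 11/3 for node 5 in head one can be reproduced by hand: its head-one query is [0, log 2], the head-one keys of its senders n2 and n3 are [1, 0] and [0, 1], so the attention logits are 0 and log 2, the softmax weights are 1/3 and 2/3, and the mixed first feature is 3 * 1/3 + 4 * 2/3 = 11/3. A small plain-numpy check of that computation (independent of the test above):

import numpy as np

# Head-one attention for node 5: logits are dot products of its query with
# the keys of its senders n2 and n3.
query = np.array([0., np.log(2.)])
keys = np.array([[1., 0.],   # n2
                 [0., 1.]])  # n3
logits = keys @ query                            # [0, log 2]
weights = np.exp(logits) / np.exp(logits).sum()  # [1/3, 2/3]
sender_values = np.array([3., 4.])               # first feature of n2, n3
print(weights @ sender_values)                   # 3.666... == 11/3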
Example #3
    def __init__(self, FLAGS, init=None, activ=None, name="GraphFormer"):
        super().__init__(FLAGS=FLAGS, init=init, activ=activ, name=name)
        self.hidden_size = self.FLAGS['tf_hidden_size']
        with self._enter_variable_scope():
            self._input_dense = snt.Linear(self.hidden_size,
                                           use_bias=False,
                                           name='input',
                                           initializers=self.init)
            self._q_dense = snt.Linear(self.hidden_size,
                                       use_bias=False,
                                       name='q',
                                       initializers=self.init)
            self._k_dense = snt.Linear(self.hidden_size,
                                       use_bias=False,
                                       name='k',
                                       initializers=self.init)
            self._v_dense = snt.Linear(self.hidden_size,
                                       use_bias=False,
                                       name='v',
                                       initializers=self.init)
            self._output_dense = snt.Linear(self.hidden_size,
                                            use_bias=False,
                                            name='output',
                                            initializers=self.init)

            self._sa = modules.SelfAttention()
            self._sa_laynorm = snt.LayerNorm()
            self._ff = snt.nets.MLP([self.hidden_size, self.hidden_size],
                                    activation=self.activ,
                                    initializers=self.init,
                                    activate_final=True)
            self._ff_laynorm = snt.LayerNorm()
            self._doub_dense = snt.Linear(2 * self.hidden_size,
                                          use_bias=False,
                                          name='doub',
                                          initializers=self.init)
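Only the constructor is shown here as well. A plausible _build for this transformer-style block is sketched below; the single-head reshape, the residual/LayerNorm wiring and the omission of _doub_dense are assumptions rather than part of the original example.

    def _build(self, graph):
        """Hypothetical forward pass (sketch only)."""
        h = self._input_dense(graph.nodes)
        # graph_nets SelfAttention expects [total_num_nodes, num_heads, dim];
        # a single attention head is assumed here.
        q = tf.expand_dims(self._q_dense(h), axis=1)
        k = tf.expand_dims(self._k_dense(h), axis=1)
        v = tf.expand_dims(self._v_dense(h), axis=1)
        attended = self._sa(v, k, q, graph)
        sa_out = self._output_dense(tf.squeeze(attended.nodes, axis=1))
        h = self._sa_laynorm(h + sa_out)        # attention sub-block + LayerNorm
        h = self._ff_laynorm(h + self._ff(h))   # feed-forward sub-block + LayerNorm
        return graph.replace(nodes=h)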
Example #4
    def _build_graph(self, graphs_tuple, regularizer):
        """Builds the graph network.

        Args:
            graphs_tuple: A GraphsTuple instance.
            regularizer: Regularizer to be used in the linear layers.

        Returns:
            output_graphs_tuple: An updated GraphsTuple instance.
        """
        node_values = graphs_tuple.nodes

        # Check configurations.
        num_heads = self.options.n_head
        key_dims = self.options.key_dims
        value_dims = node_values.shape[-1].value
        assert key_dims % num_heads == 0
        assert value_dims % num_heads == 0

        key_size = key_dims // num_heads
        value_size = value_dims // num_heads

        # Compute the key/query tensors shared across layers.
        if self.options.n_layer:
            with tf.variable_scope('self_attention'):
                node_queries = slim.fully_connected(node_values,
                                                    num_outputs=key_dims,
                                                    activation_fn=None,
                                                    biases_initializer=None,
                                                    scope='node_queries')
                node_keys = slim.fully_connected(node_values,
                                                 num_outputs=key_dims,
                                                 activation_fn=None,
                                                 biases_initializer=None,
                                                 scope='node_keys')

        # Initial RNN states.
        rnn_nodes = snt.GRU(hidden_size=value_dims)
        node_states = rnn_nodes.initial_state(
            batch_size=tf.shape(graphs_tuple.nodes)[0])

        # Stack layers.
        self_attention = modules.SelfAttention()
        graphs_tuple = graphs_tuple.replace(nodes=None, edges=None)

        for _ in range(self.options.n_layer):
            _, node_states = rnn_nodes(node_values, node_states)

            # Call graph_nets SelfAttention model.
            graphs_tuple = self_attention(
                tf.reshape(node_values, [-1, num_heads, value_size]),
                tf.reshape(node_keys, [-1, num_heads, key_size]),
                tf.reshape(node_queries, [-1, num_heads, key_size]),
                graphs_tuple)

            node_values = tf.reshape(graphs_tuple.nodes, [-1, value_dims])
            graphs_tuple = graphs_tuple.replace(nodes=None, edges=None)

        # Pack results, FC layer to project states.
        _, node_states = rnn_nodes(node_values, node_states)

        node_states = slim.fully_connected(node_states,
                                           num_outputs=value_dims,
                                           activation_fn=None,
                                           scope='node_states')

        graphs_tuple = graphs_tuple.replace(nodes=node_states, edges=None)
        return graphs_tuple
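A minimal usage sketch for this last example, assuming model is an instance of the (unnamed) class defining _build_graph, configured with options.n_head=2, options.key_dims=4 and options.n_layer=2 so the divisibility asserts hold for 4-dimensional node features; everything below is illustrative TF1-style code, not from the original source.

import numpy as np
import tensorflow as tf
from graph_nets import graphs

# Toy graph: 3 nodes with 4-dim features, edges 0->1 and 1->2.
graphs_tuple = graphs.GraphsTuple(
    nodes=tf.constant(np.random.rand(3, 4), dtype=tf.float32),
    edges=None,
    globals=None,
    senders=tf.constant([0, 1], dtype=tf.int32),
    receivers=tf.constant([1, 2], dtype=tf.int32),
    n_node=tf.constant([3], dtype=tf.int32),
    n_edge=tf.constant([2], dtype=tf.int32))

output_graphs_tuple = model._build_graph(graphs_tuple, regularizer=None)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output_graphs_tuple.nodes).shape)  # (3, 4)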