def __init__(self, conf, name="multi_head_attention"): """ Inits the module. Args: name: The module name. """ super(MultiHeadAttentionResidual, self).__init__(name=name) self.training = True self.conf = conf self.training = True with self._enter_variable_scope(): self._query_layer = snt.Linear( output_size=self.conf.head_nbr * self.conf.query_dim, use_bias=False, initializers={'w': utils.initializer(conf.embedding_dim)}, name="query_computer") self._value_layer = snt.Linear( output_size=self.conf.head_nbr * self.conf.value_dim, use_bias=False, initializers={'w': utils.initializer(conf.embedding_dim)}, name="value_computer") self._key_layer = snt.Linear( output_size=self.conf.head_nbr * self.conf.key_dim, use_bias=False, initializers={'w': utils.initializer(conf.embedding_dim)}, name="key_computer") self._graph_mha = modules.SelfAttention("graph_self_attention") self._glimpse_l = snt.Linear( name='glimpse_linear', output_size=self.conf.embedding_dim, use_bias=False, initializers={'w': utils.initializer(conf.embedding_dim)})
def test_self_attention(self):
    # Just one feature per node.
    values_np = np.arange(sum(self.N_NODE)) + 1.
    # Multiple heads, one with positive values, one with negative values.
    values_np = np.stack([values_np, values_np * -1.], axis=-1)
    # Multiple features per node, per head, at different scales.
    values_np = np.stack([values_np, values_np * 0.1], axis=-1)
    values = tf.constant(values_np, dtype=tf.float32)
    keys_np = [
        [[0.3, 0.4]] * 2,  # Irrelevant (only sender to one node)
        [[0.1, 0.5]] * 2,  # Not used (is not a sender)
        [[1, 0], [0, 1]],
        [[0, 1], [1, 0]],
        [[1, 1], [1, 1]],
        [[0.4, 0.3]] * 2,  # Not used (is not a sender)
        [[0.3, 0.2]] * 2]  # Not used (is not a sender)
    keys = tf.constant(keys_np, dtype=tf.float32)
    queries_np = [
        [[0.2, 0.7]] * 2,  # Not used (is not a receiver)
        [[0.3, 0.2]] * 2,  # Irrelevant (only receives from one node)
        [[0.2, 0.8]] * 2,  # Not used (is not a receiver)
        [[0.2, 0.4]] * 2,  # Not used (is not a receiver)
        [[0.3, 0.9]] * 2,  # Not used (is not a receiver)
        [[0, np.log(2)], [np.log(3), 0]],
        [[np.log(2), 0], [0, np.log(3)]]]
    queries = tf.constant(queries_np, dtype=tf.float32)
    attention_graph = graphs.GraphsTuple(
        nodes=None,
        edges=None,
        globals=None,
        receivers=tf.constant(self.RECEIVERS, dtype=tf.int32),
        senders=tf.constant(self.SENDERS, dtype=tf.int32),
        n_node=tf.constant(self.N_NODE, dtype=tf.int32),
        n_edge=tf.constant(self.N_EDGE, dtype=tf.int32))
    self_attention = modules.SelfAttention()
    output_graph = self_attention(values, keys, queries, attention_graph)
    mixed_nodes = output_graph.nodes
    with self.test_session() as sess:
        mixed_nodes_output = sess.run(mixed_nodes)
    expected_mixed_nodes = [
        [[0., 0.], [0., 0.]],       # Does not receive any edges
        [[1., 0.1], [-1., -0.1]],   # Only receives from n0.
        [[0., 0.], [0., 0.]],       # Does not receive any edges
        [[0., 0.], [0., 0.]],       # Does not receive any edges
        [[0., 0.], [0., 0.]],       # Does not receive any edges
        [[11/3, 11/3 * 0.1],        # Head one, receives from n2(1/3) n3(2/3)
         [-15/4, -15/4 * 0.1]],     # Head two, receives from n2(1/4) n3(3/4)
        [[20/5, 20/5 * 0.1],        # Head one, receives from n2(2/5) n3(1/5) n4(2/5)
         [-28/7, -28/7 * 0.1]],     # Head two, receives from n2(3/7) n3(1/7) n4(3/7)
    ]
    self.assertAllClose(expected_mixed_nodes, mixed_nodes_output)
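A standalone check (not part of the original test) of where the 11/3 entry comes from: for each receiver, `SelfAttention` forms logits as query-key dot products over the incoming edges, softmax-normalises them per receiver and head, and mixes the senders' values. Node 5, head one, receives from n2 (value 3) and n3 (value 4):

import numpy as np

query = np.array([0., np.log(2.)])               # node 5, head one
logits = np.array([query @ np.array([1., 0.]),   # key of sender n2 -> 0.0
                   query @ np.array([0., 1.])])  # key of sender n3 -> log 2
weights = np.exp(logits) / np.exp(logits).sum()  # softmax -> [1/3, 2/3]
mixed = weights @ np.array([3., 4.])             # values of n2 and n3
assert np.isclose(mixed, 11. / 3.)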
def __init__(self, FLAGS, init=None, activ=None, name="GraphFormer"): super().__init__(FLAGS=FLAGS, init=init, activ=activ, name=name) self.hidden_size = self.FLAGS['tf_hidden_size'] with self._enter_variable_scope(): self._input_dense = snt.Linear(self.hidden_size, use_bias=False, name='input', initializers=self.init) self._q_dense = snt.Linear(self.hidden_size, use_bias=False, name='q', initializers=self.init) self._k_dense = snt.Linear(self.hidden_size, use_bias=False, name='k', initializers=self.init) self._v_dense = snt.Linear(self.hidden_size, use_bias=False, name='v', initializers=self.init) self._output_dense = snt.Linear(self.hidden_size, use_bias=False, name='output', initializers=self.init) self._sa = modules.SelfAttention() self._sa_laynorm = snt.LayerNorm() self._ff = snt.nets.MLP([self.hidden_size, self.hidden_size], activation=self.activ, initializers=self.init, activate_final=True) self._ff_laynorm = snt.LayerNorm() self._doub_dense = snt.Linear(2 * self.hidden_size, use_bias=False, name='doub', initializers=self.init)
def _build_graph(self, graphs_tuple, regularizer):
    """Builds graph network.

    Args:
        graphs_tuple: A GraphsTuple instance.
        regularizer: Regularizer to be used in linear layers.

    Returns:
        output_graphs_tuple: An updated GraphsTuple instance.
    """
    node_values = graphs_tuple.nodes

    # Check configurations.
    num_heads = self.options.n_head
    key_dims = self.options.key_dims
    value_dims = node_values.shape[-1].value
    assert key_dims % num_heads == 0
    assert value_dims % num_heads == 0
    key_size = key_dims // num_heads
    value_size = value_dims // num_heads

    # Compute the key/query tensors shared across layers.
    if self.options.n_layer:
        with tf.variable_scope('self_attention'):
            node_queries = slim.fully_connected(node_values,
                                                num_outputs=key_dims,
                                                activation_fn=None,
                                                biases_initializer=None,
                                                scope='node_queries')
            node_keys = slim.fully_connected(node_values,
                                             num_outputs=key_dims,
                                             activation_fn=None,
                                             biases_initializer=None,
                                             scope='node_keys')

            # Initial RNN states.
            rnn_nodes = snt.GRU(hidden_size=value_dims)
            node_states = rnn_nodes.initial_state(
                batch_size=tf.shape(graphs_tuple.nodes)[0])

            # Stack layers.
            self_attention = modules.SelfAttention()
            graphs_tuple = graphs_tuple.replace(nodes=None, edges=None)
            for _ in range(self.options.n_layer):
                _, node_states = rnn_nodes(node_values, node_states)
                # Call graph_nets SelfAttention model.
                graphs_tuple = self_attention(
                    tf.reshape(node_values, [-1, num_heads, value_size]),
                    tf.reshape(node_keys, [-1, num_heads, key_size]),
                    tf.reshape(node_queries, [-1, num_heads, key_size]),
                    graphs_tuple)
                node_values = tf.reshape(graphs_tuple.nodes, [-1, value_dims])
                graphs_tuple = graphs_tuple.replace(nodes=None, edges=None)

            # Pack results, FC layer to project states.
            _, node_states = rnn_nodes(node_values, node_states)
            node_states = slim.fully_connected(node_states,
                                               num_outputs=value_dims,
                                               activation_fn=None,
                                               scope='node_states')
            graphs_tuple = graphs_tuple.replace(nodes=node_states, edges=None)
    return graphs_tuple
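A small standalone sketch (hypothetical shapes, not from the model) of the head-splitting contract the reshapes above rely on: graph_nets' `SelfAttention` consumes values, keys, and queries of shape [total_num_nodes, num_heads, feature_size], which is why `key_dims` and `value_dims` must be divisible by `n_head`.

import tensorflow as tf

total_nodes, num_heads, key_size, value_size = 10, 4, 8, 16
node_keys = tf.zeros([total_nodes, num_heads * key_size])      # [10, 32]
node_values = tf.zeros([total_nodes, num_heads * value_size])  # [10, 64]
# Split the flat feature axis into one slice per attention head.
keys_per_head = tf.reshape(node_keys, [-1, num_heads, key_size])        # [10, 4, 8]
values_per_head = tf.reshape(node_values, [-1, num_heads, value_size])  # [10, 4, 16]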