def get_embedding(self, num_embeddings, embedding_dim, padding_idx=None):
    """Build sinusoidal embeddings.

    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    half_dim = embedding_dim // 2
    emb = layers.log(float(10000)) / (half_dim - 1)
    emb = layers.exp(layers.arange(
        start=0, end=half_dim, dtype='float32') * -emb)
    # [num_embeddings, embedding_dim // 2]
    emb = layers.unsqueeze(layers.arange(-num_embeddings // 2,
                                         num_embeddings // 2,
                                         dtype='float32'), axis=1) * \
        layers.unsqueeze(emb, axis=0)
    emb = layers.concat([layers.sin(emb), layers.cos(emb)], dim=1)
    # [num_embeddings, embedding_dim]
    if embedding_dim % 2 == 1:
        # zero-pad the last channel when embedding_dim is odd
        emb = layers.concat(
            [emb, layers.zeros(shape=(num_embeddings, 1))], dim=1)
    if padding_idx is not None:
        emb[padding_idx, :] = 0
    self.origin_shift = num_embeddings // 2
    return emb
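# Hedged NumPy sketch of the table built by get_embedding above (illustration
# only; the sizes below are made up and not part of the original code): the
# first half of the channels holds sin(pos * freq), the second half cos(pos * freq).
import numpy as np

num_embeddings, embedding_dim = 6, 8
half_dim = embedding_dim // 2
freq = np.exp(np.arange(half_dim) * -(np.log(10000.0) / (half_dim - 1)))
pos = np.arange(-num_embeddings // 2, num_embeddings // 2).astype("float32")[:, None]
table = np.concatenate([np.sin(pos * freq), np.cos(pos * freq)], axis=1)
print(table.shape)  # (6, 8) == (num_embeddings, embedding_dim)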
def forward(self, tensor_list: NestedTensor):
    x = tensor_list.tensors
    h, w = x.shape[-2:]
    i = L.arange(0, w)
    j = L.arange(0, h)
    x_emb = self.col_embed(i)  # [w, num_pos_feats]
    y_emb = self.row_embed(j)  # [h, num_pos_feats]
    x_emb = L.expand(L.unsqueeze(x_emb, 0), (h, 1, 1))  # [h, w, num_pos_feats]
    y_emb = L.expand(L.unsqueeze(y_emb, 1), (1, w, 1))  # [h, w, num_pos_feats]
    pos = L.concat([x_emb, y_emb], -1)       # [h, w, num_pos_feats * 2]
    pos = L.transpose(pos, perm=(2, 0, 1))   # [num_pos_feats * 2, h, w]
    pos = L.unsqueeze(pos, 0)                # [1, num_pos_feats * 2, h, w]
    pos = L.expand(
        pos, (x.shape[0], 1, 1, 1))          # [batch_size, num_pos_feats * 2, h, w]
    return pos
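# Hedged NumPy sketch of the broadcast in forward() above (shapes only; the
# random values and sizes are made up for illustration): the per-column and
# per-row embeddings are tiled over the h x w grid and concatenated per pixel.
import numpy as np

h, w, num_pos_feats = 3, 4, 2
x_emb = np.random.rand(w, num_pos_feats)   # one embedding per column
y_emb = np.random.rand(h, num_pos_feats)   # one embedding per row
pos = np.concatenate([
    np.broadcast_to(x_emb[None, :, :], (h, w, num_pos_feats)),
    np.broadcast_to(y_emb[:, None, :], (h, w, num_pos_feats)),
], axis=-1)                                # [h, w, num_pos_feats * 2]
print(pos.transpose(2, 0, 1)[None].shape)  # (1, 4, 3, 4) == [1, num_pos_feats * 2, h, w]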
def seq_gather(seq, idxs):
    """seq has shape [None, seq_len, s_size] and idxs has shape [None, 1].
    For the i-th sequence in seq, select the idxs[i]-th vector; the output
    has shape [None, s_size].
    """
    idxs = layers.cast(idxs, dtype="int32")
    batch_idxs = layers.arange(0, seq.shape[0], dtype="int32")
    batch_idxs = layers.unsqueeze(batch_idxs, 1)
    idxs = layers.concat([batch_idxs, idxs], 1)
    return layers.gather_nd(seq, idxs)
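# Hedged NumPy sketch of the gather above (the arrays are made up for
# illustration): for each row i, pick the idxs[i]-th vector of the i-th
# sequence, which is what the (batch_idx, idx) pairs fed to gather_nd express.
import numpy as np

seq = np.arange(24).reshape(2, 3, 4)   # [batch=2, seq_len=3, s_size=4]
idxs = np.array([[1], [2]])            # position 1 of sample 0, position 2 of sample 1
batch_idxs = np.arange(seq.shape[0])
out = seq[batch_idxs, idxs[:, 0]]      # shape (2, 4)
print(out)                             # [[ 4  5  6  7] [20 21 22 23]]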
def index_sample(x, index):
    """Select values from x according to index.

    Args:
        x: input matrix
        index: index matrix
    Returns:
        output

    >>> x
    [
        [1, 2, 3],
        [4, 5, 6]
    ]
    >>> index
    [
        [1, 2],
        [0, 1]
    ]
    >>> index_sample(x, index)
    [
        [2, 3],
        [4, 5]
    ]
    """
    x_s = x.shape
    dim = len(index.shape) - 1
    assert x_s[:dim] == index.shape[:dim]
    r_x = layers.reshape(x, shape=(-1, *x_s[dim:]))
    index = layers.reshape(index, shape=(len(r_x), -1, 1))
    # generate an arange index with the same shape as index
    arr_index = layers.arange(start=0, end=len(index), dtype=index.dtype)
    arr_index = layers.unsqueeze(arr_index, axes=[1, 2])
    arr_index = layers.expand_as(arr_index, index)
    # generate the new (row, column) index
    new_index = layers.concat((arr_index, index), -1)
    new_index = layers.reshape(new_index, (-1, 2))
    # gather the output
    out = layers.gather_nd(r_x, new_index)
    out = layers.reshape(out, (*x_s[:dim], -1))
    return out
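# The same selection expressed with NumPy's take_along_axis (a hedged
# equivalence for intuition, not the Paddle implementation): index[i][j]
# picks a value from row i of x.
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])
index = np.array([[1, 2], [0, 1]])
print(np.take_along_axis(x, index, axis=1))  # [[2 3] [4 5]]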
def _transpose_shift(E):
    """
    Transform
        -3   -2   -1   0   1   2
        -30  -20  -10  00  10  20
        -300 -200 -100 000 100 200
    into
        0  -10  -200
        1   00  -100
        2   10   000
    :param E: batch_size x n_head x max_len x 2max_len
    :return: batch_size x n_head x max_len x max_len
    """
    bsz, n_head, max_len, _ = E.shape
    zero_pad = layers.zeros(shape=(bsz, n_head, max_len, 1))
    # pad one zero column, then fold the last dimension so that every row
    # is shifted by one position relative to the previous one
    E = layers.reshape(x=layers.concat([E, zero_pad], axis=-1),
                       shape=(bsz, n_head, -1, max_len))
    indice = layers.arange(start=0, end=max_len, dtype='int64')
    E = layers.index_select(input=E, index=indice, dim=-2)
    E = layers.transpose(E, perm=[0, 1, 3, 2])
    return E
def reorder_head(layer, idx):
    n, a = layer.n_head, layer.d_key
    # expand the head order idx into a per-channel index over the flattened
    # [n_head * d_key] dimension
    index = L.reshape(
        L.index_select(
            L.reshape(L.arange(0, n * a, dtype='int64'), shape=[n, a]),
            idx, dim=0),
        shape=[-1])

    def reorder_head_matrix(linearLayer, index, dim=1):
        W = L.index_select(linearLayer.weight, index, dim=dim).detach()
        if linearLayer.bias is not None:
            if dim == 0:
                b = L.assign(linearLayer.bias).detach()
            else:
                b = L.assign(
                    L.index_select(linearLayer.bias, index, dim=0)).detach()

        linearLayer.weight.stop_gradient = True
        linearLayer.weight.set_value(W)
        linearLayer.weight.stop_gradient = False
        if linearLayer.bias is not None:
            linearLayer.bias.stop_gradient = True
            linearLayer.bias.set_value(b)
            linearLayer.bias.stop_gradient = False

    reorder_head_matrix(
        layer.q.fn if hasattr(layer.q, 'fn') else layer.q, index)
    reorder_head_matrix(
        layer.k.fn if hasattr(layer.k, 'fn') else layer.k, index)
    reorder_head_matrix(
        layer.v.fn if hasattr(layer.v, 'fn') else layer.v, index)
    reorder_head_matrix(
        layer.o.fn if hasattr(layer.o, 'fn') else layer.o, index, dim=0)
def SinusoidalEmbedding(self, input):
    """Produce sinusoidal positional embeddings of any length.

    Padding symbols are ignored.

    Args:
        input: tensor shaped [bsz, seq_len].
    """
    bsz, seq_len = input.shape
    max_pos = self.padding_idx + seq_len
    if max_pos > self.origin_shift:
        # the cached table is too small; rebuild it with a larger size
        self.weights = self.get_embedding(
            max_pos * 2,
            self.embedding_dim,
            self.padding_idx
        )
    positions = layers.arange(-seq_len, seq_len, dtype='int64') + self.origin_shift
    embed = layers.index_select(input=self.weights, index=positions, dim=0)
    return embed
def position_id(x, r=0):
    # position ids 0..seq_len-1, shaped [1, seq_len]
    pid = layers.arange(0, x.shape[1], dtype="int32")
    pid = layers.unsqueeze(pid, 0)
    # broadcast the anchor position r to the shape of x
    r = layers.cast(layers.ones_like(x), dtype="int32") * r
    # absolute distance of each position to r
    return layers.cast(layers.abs(layers.elementwise_sub(pid, r)), dtype='int64')
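# Hedged NumPy illustration of position_id (values made up): with r > 0 the
# function returns the absolute distance of each position to r, i.e.
# position ids relative to a chosen anchor position.
import numpy as np

seq_len, r = 5, 2
pid = np.arange(seq_len)   # [0 1 2 3 4]
print(np.abs(pid - r))     # [2 1 0 1 2]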
def topk_pool(gw, score, graph_id, ratio):
    """Implementation of topk pooling, where k means pooling ratio.

    Args:
        gw: Graph wrapper object.
        score: The attention score of all nodes, which is used to select
            important nodes.
        graph_id: The graphs that the nodes belong to.
        ratio: The pooling ratio of nodes we want to select.

    Return:
        perm: The index of nodes we choose.
        ratio_length: The selected node numbers of each graph.
    """
    graph_lod = gw.graph_lod
    graph_nodes = gw.num_nodes
    num_graph = gw.num_graph

    num_nodes = L.ones(shape=[graph_nodes], dtype="float32")
    num_nodes = L.lod_reset(num_nodes, graph_lod)
    num_nodes_per_graph = L.sequence_pool(num_nodes, pool_type='sum')
    max_num_nodes = L.reduce_max(num_nodes_per_graph, dim=0)
    max_num_nodes = L.cast(max_num_nodes, dtype="int32")

    index = L.arange(0, gw.num_nodes, dtype="int64")
    offset = L.gather(graph_lod, graph_id, overwrite=False)
    index = (index - offset) + (graph_id * max_num_nodes)
    index.stop_gradient = True

    # padding
    dense_score = L.fill_constant(shape=[num_graph * max_num_nodes],
                                  dtype="float32", value=-999999)
    index = L.reshape(index, shape=[-1])
    dense_score = L.scatter(dense_score, index, updates=score)
    num_graph = L.cast(num_graph, dtype="int32")
    dense_score = L.reshape(dense_score,
                            shape=[num_graph, max_num_nodes])

    # record the sorted index
    _, sort_index = L.argsort(dense_score, axis=-1, descending=True)

    # recover the index range
    graph_lod = graph_lod[:-1]
    graph_lod = L.reshape(graph_lod, shape=[-1, 1])
    graph_lod = L.cast(graph_lod, dtype="int64")
    sort_index = L.elementwise_add(sort_index, graph_lod, axis=-1)
    sort_index = L.reshape(sort_index, shape=[-1, 1])

    # use sequence_slice to choose selected node index
    pad_lod = L.arange(0, (num_graph + 1) * max_num_nodes,
                       step=max_num_nodes, dtype="int32")
    sort_index = L.lod_reset(sort_index, pad_lod)
    ratio_length = L.ceil(num_nodes_per_graph * ratio)
    ratio_length = L.cast(ratio_length, dtype="int64")
    ratio_length = L.reshape(ratio_length, shape=[-1, 1])
    offset = L.zeros(shape=[num_graph, 1], dtype="int64")
    choose_index = L.sequence_slice(input=sort_index,
                                    offset=offset,
                                    length=ratio_length)
    perm = L.reshape(choose_index, shape=[-1])
    return perm, ratio_length