def forward(self, data: np.ndarray, prev: Optional[np.ndarray]) -> np.ndarray:
    """
    Apply processing sequence to data with optional previous input.

    :param data: Input data. Shape: (batch, length, num_hidden).
    :param prev: Previous data. Shape: (batch, length, num_hidden).
    :return: Processed data. Shape: (batch, length, num_hidden).
    """
    if not self.sequence:
        return data

    if prev is None:
        assert 'r' not in self.sequence, "Residual connection not allowed if no previous value given."

    for step in self.sequence:
        if step == "r":
            data = data + prev
        elif step == "n":
            data = self.layer_norm(data)
        elif step == "d":
            if self.dropout > 0.0:
                data = npx.dropout(data, p=self.dropout)
        else:
            raise ValueError("Unknown step in sequence: %s" % step)

    return data
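# Illustrative sketch (not part of the module above): a minimal plain-NumPy
# version of the same step loop, with a simplified layer-norm stand-in and
# inverted dropout, only to show the order in which residual ("r"),
# layer norm ("n"), and dropout ("d") are applied. Names here are hypothetical.
import numpy as np

def _layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)   # no learned scale/shift

def apply_sequence(data, prev=None, sequence="n", dropout=0.0, seed=0):
    rng = np.random.default_rng(seed)
    for step in sequence:
        if step == "r":
            data = data + prev                              # residual connection
        elif step == "n":
            data = _layer_norm(data)                        # layer normalization
        elif step == "d" and dropout > 0.0:
            keep = rng.random(data.shape) >= dropout
            data = data * keep / (1.0 - dropout)            # inverted dropout
    return data

x = np.ones((2, 4, 8))
out = apply_sequence(x, prev=x, sequence="rnd", dropout=0.1)  # shape (2, 4, 8)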
def forward(self,
            queries: np.ndarray,
            key_values: np.ndarray,
            heads: np.ndarray,
            lengths: Optional[np.ndarray] = None,
            bias: Optional[np.ndarray] = None):
    # (n*h, lq, lk)
    logits = npx.interleaved_matmul_encdec_qk(queries, key_values, heads=heads)

    if bias is not None:
        logits = logits + bias

    if lengths is not None:
        # required shape for lengths: (n*h, lq); required dtype: int32
        probs = npx.softmax(logits, axis=-1, length=lengths, use_length=True)
    else:
        probs = npx.softmax(logits, axis=-1)

    probs = npx.dropout(probs, p=self.dropout) if self.dropout > 0.0 else probs

    # key_values: (lk, n, dv * 2)
    # probs: (n*h, lq, lk)
    # result: (n, lq, dv)
    return npx.interleaved_matmul_encdec_valatt(key_values, probs, heads=heads)
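# Illustrative sketch (not part of the module above): a plain-NumPy view of the
# length-masked softmax used when `lengths` is given, assuming (as the comment
# above states) a lengths array of shape (n*h, lq) holding the number of valid
# keys per row. Positions at or beyond the valid length receive zero probability.
import numpy as np

def length_masked_softmax(logits, lengths):
    # logits: (n*h, lq, lk); lengths: (n*h, lq) counts of valid key positions.
    lk = logits.shape[-1]
    valid = np.arange(lk)[None, None, :] < lengths[..., None]   # (n*h, lq, lk)
    masked = np.where(valid, logits, -np.inf)
    masked = masked - masked.max(axis=-1, keepdims=True)        # numerical stability
    exp = np.exp(masked)
    return exp / exp.sum(axis=-1, keepdims=True)

logits = np.random.randn(6, 5, 7)                 # e.g. n*h = 6, lq = 5, lk = 7
lengths = np.full((6, 5), 4)                      # only the first 4 keys are valid
probs = length_masked_softmax(logits, lengths)    # rows sum to 1 over valid keys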
def forward(self, x: np.ndarray) -> np.ndarray:
    h = self.ff1(x)
    h = self.act(h)
    if self.use_glu:
        h = h * self.linear(x)
    if self.dropout > 0.0:
        h = npx.dropout(h, p=self.dropout)
    y = self.ff2(h)
    return y
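# Illustrative sketch (not part of the module above): the gated (GLU) variant of
# the feed-forward block in plain NumPy, with hypothetical weight names and
# biases/dropout omitted. ff1 corresponds to w1, the gate projection to w3,
# and ff2 to w2.
import numpy as np

def glu_ffn(x, w1, w3, w2, act=lambda v: np.maximum(v, 0.0)):
    h = act(x @ w1) * (x @ w3)   # activation gated elementwise by a second projection
    return h @ w2                # project back to the model dimension

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8))
w1, w3, w2 = (rng.standard_normal(s) for s in [(8, 32), (8, 32), (32, 8)])
y = glu_ffn(x, w1, w3, w2)       # shape (2, 5, 8)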
def forward(self, data, valid_length):
    # positional embedding
    data = self.pos_embedding(data, None)

    if self.config.dropout_prepost > 0.0:
        data = npx.dropout(data=data, p=self.config.dropout_prepost)

    # (batch_size * heads, seq_len)
    att_valid_length = layers.prepare_source_valid_lengths(valid_length, data,
                                                           num_heads=self.config.attention_heads)

    data = np.transpose(data, axes=(1, 0, 2))
    for block in self.layers:
        data = block(data, att_valid_length)

    data = self.final_process(data, None)
    data = np.transpose(data, axes=(1, 0, 2))
    return data, valid_length
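# Illustrative sketch (not part of the module above): one plausible reading of
# what a helper like layers.prepare_source_valid_lengths produces here, namely
# the per-sentence valid lengths repeated once per attention head and query
# position so they match the (batch_size * heads, seq_len) shape expected by
# the length-aware softmax. This is a hypothetical reimplementation, not the
# library's code.
import numpy as np

def prepare_source_valid_lengths(valid_length, query_len, num_heads):
    # valid_length: (batch,) number of real source tokens per sentence.
    per_head = np.repeat(valid_length, num_heads)          # (batch * heads,)
    return np.tile(per_head[:, None], (1, query_len))      # (batch * heads, query_len)

valid_length = np.array([5, 3])
att_valid_length = prepare_source_valid_lengths(valid_length, query_len=7, num_heads=4)
print(att_valid_length.shape)   # (8, 7)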
def forward(self, data, valid_length):  # pylint: disable=arguments-differ
    # We will catch the optional factor weights in kwargs
    average_factors_embeds = []  # type: List[np.ndarray]
    concat_factors_embeds = []  # type: List[np.ndarray]
    sum_factors_embeds = []  # type: List[np.ndarray]
    if self.config.num_factors > 1 and self.config.factor_configs is not None:
        data, *data_factors = (np.squeeze(x, axis=2)
                               for x in np.split(data, self.config.num_factors, axis=2))
        for i, (factor_data, factor_config) in enumerate(zip(data_factors,
                                                             self.config.factor_configs)):
            factor_weight = self.factor_weights[i]
            factor_embedding = npx.embedding(factor_data,
                                             input_dim=factor_config.vocab_size,
                                             weight=factor_weight.data(),
                                             output_dim=factor_config.num_embed)
            if factor_config.combine == C.FACTORS_COMBINE_CONCAT:
                concat_factors_embeds.append(factor_embedding)
            elif factor_config.combine == C.FACTORS_COMBINE_SUM:
                sum_factors_embeds.append(factor_embedding)
            elif factor_config.combine == C.FACTORS_COMBINE_AVERAGE:
                average_factors_embeds.append(factor_embedding)
            else:
                raise ValueError("Unknown combine value for factors: %s" % factor_config.combine)
    else:
        data = np.squeeze(data, axis=2)

    embed = npx.embedding(data,
                          weight=self.weight.data(),
                          input_dim=self.config.vocab_size,
                          output_dim=self.config.num_embed,
                          dtype=self._dtype,
                          sparse_grad=False)

    if self.config.num_factors > 1 and self.config.factor_configs is not None:
        if average_factors_embeds:
            embed = npx.add_n(embed, *average_factors_embeds) / (len(average_factors_embeds) + 1)
        if sum_factors_embeds:
            embed = npx.add_n(embed, *sum_factors_embeds)
        if concat_factors_embeds:
            embed = np.concatenate((embed, *concat_factors_embeds), axis=2)

    if self.config.dropout > 0:
        embed = npx.dropout(data=embed, p=self.config.dropout)

    return embed, np.copy(valid_length)  # See https://github.com/apache/incubator-mxnet/issues/14228
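# Illustrative sketch (not part of the module above): the three factor combine
# modes on toy arrays. Shapes are made up; sum/average require the factor
# embeddings to match the token embedding width, concat does not.
import numpy as np

embed = np.ones((2, 4, 8))             # token embedding: (batch, length, num_embed)
factor_a = np.full((2, 4, 8), 0.5)     # factor embedding with matching width
factor_b = np.ones((2, 4, 2))          # factor embedding with its own width

averaged = (embed + factor_a) / 2                          # FACTORS_COMBINE_AVERAGE
summed = embed + factor_a                                  # FACTORS_COMBINE_SUM
concatenated = np.concatenate((embed, factor_b), axis=2)   # FACTORS_COMBINE_CONCAT
print(averaged.shape, summed.shape, concatenated.shape)    # (2, 4, 8) (2, 4, 8) (2, 4, 10)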
def multi_head_dot_attn(query, key, value,
                        mask=None,
                        edge_scores=None,
                        dropout: float = 0.0,
                        scaled: bool = True, normalized: bool = False,
                        eps: float = 1E-6, query_head_units: Optional[int] = None,
                        layout: str = 'NKT',
                        use_einsum: bool = False):
    """Multihead dot product attention between the query, key, value.

    scaled is False, normalized is False:
        D(h_q, h_k) = <h_q, h_k>
    scaled is True, normalized is False:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    scaled is False, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    scaled is True, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    If edge_scores is provided, we will calculate the attention as
        scores = D(h_q, h_k) + EdgeScore_{q, k}

    Parameters
    ----------
    query
        Query. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, query_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, query_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads, key_dim)

    key
        Key. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, key_dim)

    value
        Value. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, value_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, value_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, value_dim)

    mask
        Mask between query and memory. Shape (batch_size, query_length, mem_length)
    edge_scores
        The edge attention score. Shape can be any shape that is broadcastable to
        (batch_size, num_heads, query_length, mem_length)
    dropout
        Dropout rate
    scaled
        Whether to divide the attention weights by the sqrt of the query dimension.
        This is first proposed in "[NIPS2017] Attention is all you need."::

        .. code-block:: none

            score = <h_q, h_k> / sqrt(dim_q)

    normalized
        If turned on, the cosine distance is used, i.e::

        .. code-block:: none

            score = <h_q / ||h_q||, h_k / ||h_k||>

    eps
        The epsilon value used in L2 normalization
    query_head_units
        The units of each query head. Must be specified when ``scaled`` is True.
    layout
        This stands for the layout of the attention cell. The shape of the input/output will
        depend on the layout. Currently, we support 'NKT', 'NTK' and 'TNK' in which
        'N' means the batch_size, 'K' means the head, and 'T' means the length dimension.
    use_einsum
        Whether to use einsum for the computation

    Returns
    -------
    context_vec
        - layout is 'NKT' or 'NTK'
            Shape (batch_size, query_length, num_heads * value_units)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads * value_units)

    additional_info
        scores:
            Shape (batch_size, num_head, query_length, mem_length)
        attn_weight:
            Shape (batch_size, num_head, query_length, mem_length)
    """
    # TODO(sxjscience) Profile layout
    if normalized:
        query = l2_normalize(query, axis=-1, eps=eps)
        key = l2_normalize(key, axis=-1, eps=eps)
    if scaled:
        if query_head_units is None:
            raise NotImplementedError('You will need to specify query_head_units!')
        else:
            scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout == 'NKT':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #   (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bnjc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'NTK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum('binc,bjnc->bnij', query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
                                   transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #   (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bjnc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'TNK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        #   This layout structure can be implemented very efficiently because B, N are consecutive
        #   to each other. To have a clear picture of what's happening, we may consider the
        #   (i, j)th element of the output:
        #       out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T, which is just one GEMM call.
        #   We can thus implement the whole kernel via a single call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum('ibnc,jbnc->bnij', query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                   key.transpose((1, 2, 3, 0)))
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #   (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
        #   Again, we can implement it via a single call to batched GEMM with stride.
        #   Shape (B, N, L_query, C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,jbnc->ibnc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    else:
        raise NotImplementedError('layout="{}" is not supported! '
                                  'We only support layout = "NKT", "NTK", and "TNK".'
                                  .format(layout))
    return context_vec, [scores, attn_weights]
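# Illustrative sketch (not part of the module above): a compact plain-NumPy
# reference of the 'NKT' branch with no mask, no edge scores, and no dropout,
# mirroring the score, softmax, and context-vector steps above.
import numpy as np

def nkt_attention(query, key, value):
    # query/key: (B, N, L_q/L_mem, C_Q); value: (B, N, L_mem, C_V)
    scale = np.sqrt(query.shape[-1])
    scores = np.einsum('bnic,bnjc->bnij', query, key) / scale        # (B, N, L_q, L_mem)
    scores = scores - scores.max(axis=-1, keepdims=True)             # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)          # softmax over memory
    context = np.einsum('bnij,bnjc->binc', weights, value)           # (B, L_q, N, C_V)
    return context.reshape(context.shape[0], context.shape[1], -1)   # (B, L_q, N * C_V)

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 5, 8))
k = rng.standard_normal((2, 4, 6, 8))
v = rng.standard_normal((2, 4, 6, 16))
out = nkt_attention(q, k, v)   # shape (2, 5, 64)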
def forward(self,
            step_input: np.ndarray,
            states: List[np.ndarray]) -> Tuple[np.ndarray, List[np.ndarray]]:
    mask = None
    if self.inference_only:
        steps, source_valid_length, *other = states
        source_encoded = None  # use constant pre-computed key value projections from the states
        enc_att_kv = other[:self.config.num_layers]
        autoregr_states = other[self.config.num_layers:]
    else:
        if any(layer.needs_mask for layer in self.layers):
            mask = self.autoregressive_bias(step_input)  # mask: (1, length, length)
        steps, source_encoded, source_valid_length, *autoregr_states = states
        enc_att_kv = [None for _ in range(self.config.num_layers)]

    if any(layer.num_state_tensors > 1 for layer in self.layers):
        # separates autoregressive states by layer
        states_iter = iter(autoregr_states)
        autoregr_states = [list(islice(states_iter, 0, layer.num_state_tensors))
                           for layer in self.layers]

    # (batch_size * heads, query_length)
    source_valid_length = layers.prepare_source_valid_lengths(source_valid_length,
                                                              step_input,
                                                              num_heads=self.config.attention_heads)

    # target: (batch_size, length, model_size)
    target = self.pos_embedding(step_input, steps)

    # (length, batch_size, model_size)
    target = np.transpose(target, axes=(1, 0, 2))

    if self.config.dropout_prepost > 0.0:
        target = npx.dropout(data=target, p=self.config.dropout_prepost)

    new_autoregr_states = []
    for layer, layer_autoregr_state, layer_enc_att_kv in zip(self.layers,
                                                             autoregr_states,
                                                             enc_att_kv):
        target, new_layer_autoregr_state = layer(target,
                                                 mask,
                                                 source_encoded,
                                                 source_valid_length,
                                                 layer_autoregr_state,
                                                 layer_enc_att_kv)
        new_autoregr_states += [*new_layer_autoregr_state]

    target = self.final_process(target, None)
    target = np.transpose(target, axes=(1, 0, 2))

    # Inference: increment steps by 1 (discarded in training)
    steps = steps + 1

    if self.inference_only:
        # pass in cached encoder states
        encoder_attention_keys_values = states[2:2 + self.config.num_layers]
        new_states = [steps, states[1]] + encoder_attention_keys_values + new_autoregr_states
    else:
        encoder_outputs = states[1]
        encoder_valid_length = states[2]
        new_states = [steps, encoder_outputs, encoder_valid_length] + new_autoregr_states

    return target, new_states
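# Illustrative sketch (not part of the module above): the islice-based
# regrouping of the flat autoregressive state list into per-layer lists, with
# made-up layer state counts.
from itertools import islice

num_state_tensors = [2, 1, 2]    # e.g. three layers carrying 2, 1, and 2 state tensors
flat_states = ['l0_s0', 'l0_s1', 'l1_s0', 'l2_s0', 'l2_s1']

states_iter = iter(flat_states)
grouped = [list(islice(states_iter, 0, n)) for n in num_state_tensors]
print(grouped)   # [['l0_s0', 'l0_s1'], ['l1_s0'], ['l2_s0', 'l2_s1']]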