def forward(self, x, states):
    """Forward computation given the input tokens and the previous states.

    Parameters
    ----------
    x
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    states
        The previous states

        - layout = 'NT'
            Shape (num_layers, 2, batch_size, prev_len, C_in)
        - layout = 'TN'
            Shape (num_layers, 2, prev_len, batch_size, C_in)

    Returns
    -------
    new_x
        Output

        - layout = 'NT'
            Shape (batch_size, seq_length, C_out)
        - layout = 'TN'
            Shape (seq_length, batch_size, C_out)
    new_states
        The new states

        - layout = 'NT'
            Shape (num_layers, 2, batch_size, prev_len + seq_length, C_in)
        - layout = 'TN'
            Shape (num_layers, 2, prev_len + seq_length, batch_size, C_in)
    """
    prev_len = npx.shape_array(states)[3] if self._layout == 'NT' else \
               npx.shape_array(states)[2]
    x = self.get_initial_embedding(x, prev_len)
    if self._layout != self._compute_layout:
        x = np.swapaxes(x, 0, 1)
        states = np.swapaxes(states, 2, 3)
    new_states = []
    for layer_idx in range(self._num_layers):
        layer_states = None if states is None else states[layer_idx]
        x, new_layer_states = self._layers[layer_idx](x, layer_states)
        new_states.append(new_layer_states)
    new_states = np.stack(new_states, axis=0)
    x = self._final_ln(x)
    if self._layout != self._compute_layout:
        x = np.swapaxes(x, 0, 1)
        new_states = np.swapaxes(new_states, 2, 3)
    return x, new_states

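# The cached states keep the already-processed context along axis 3 in the 'NT'
# layout (axis 2 in 'TN'), so reading prev_len and converting between the two
# layouts only touches those two axes. A minimal NumPy sketch of that bookkeeping;
# the shapes and names here are illustrative and not tied to any particular model.
import numpy as np

num_layers, batch_size, prev_len, C_in = 2, 4, 6, 8

# Cached states in 'NT' layout: (num_layers, 2, batch_size, prev_len, C_in)
states_nt = np.zeros((num_layers, 2, batch_size, prev_len, C_in))

# prev_len is read from axis 3 for 'NT' (it would be axis 2 for 'TN')
assert states_nt.shape[3] == prev_len

# Converting the cache to 'TN' layout only swaps the batch and length axes
states_tn = np.swapaxes(states_nt, 2, 3)
assert states_tn.shape == (num_layers, 2, prev_len, batch_size, C_in)
assert states_tn.shape[2] == prev_len
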
def forward(self, inputs, token_types, valid_length=None):  # pylint: disable=arguments-differ
    """Generate the representation given the inputs.

    This is used in training or fine-tuning an ALBERT model.

    Parameters
    ----------
    inputs
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    token_types
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)

        If the inputs contain two sequences, we will set different token types
        for the first sentence and the second sentence.
    valid_length
        The valid length of each sequence.
        Shape (batch_size,)

    Returns
    -------
    contextual_embedding
        - layout = 'NT'
            Shape (batch_size, seq_length, units)
        - layout = 'TN'
            Shape (seq_length, batch_size, units)
    pooled_output
        This is optional. Shape (batch_size, units)
    """
    initial_embedding = self.get_initial_embedding(inputs, token_types)
    # Project the embedding from embed_size to units if they differ
    prev_out = initial_embedding
    if self.embed_size != self.units:
        prev_out = self.embed_factorized_proj(prev_out)
    outputs = []
    if self._compute_layout != self._layout:
        # Swap the axes to match the compute_layout
        contextual_embeddings, additional_outputs = self.encoder(
            np.swapaxes(prev_out, 0, 1), valid_length)
        contextual_embeddings = np.swapaxes(contextual_embeddings, 0, 1)
    else:
        contextual_embeddings, additional_outputs = self.encoder(
            prev_out, valid_length)
    outputs.append(contextual_embeddings)
    if self.use_pooler:
        pooled_out = self.apply_pooling(contextual_embeddings)
        outputs.append(pooled_out)
    return tuple(outputs) if len(outputs) > 1 else outputs[0]

def forward(self, inputs, token_types, valid_length):  # pylint: disable=arguments-differ
    """Generate the representation given the inputs.

    This is used in training or fine-tuning a MobileBERT model.

    Parameters
    ----------
    inputs
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    token_types
        If the inputs contain two sequences, we will set different token types
        for the first sentence and the second sentence.

        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    valid_length
        The valid length of each sequence.
        Shape (batch_size,)

    Returns
    -------
    contextual_embedding
        - layout = 'NT'
            Shape (batch_size, seq_length, units)
        - layout = 'TN'
            Shape (seq_length, batch_size, units)
    pooled_output
        This is optional. Shape (batch_size, units)
    """
    embedding = self.get_initial_embedding(inputs, token_types)
    if self._compute_layout != self._layout:
        contextual_embeddings, additional_outputs = self.encoder(
            np.swapaxes(embedding, 0, 1), valid_length)
        contextual_embeddings = np.swapaxes(contextual_embeddings, 0, 1)
    else:
        contextual_embeddings, additional_outputs = self.encoder(embedding, valid_length)
    if self.use_pooler:
        pooled_out = self.apply_pooling(contextual_embeddings)
        return contextual_embeddings, pooled_out
    else:
        return contextual_embeddings

def forward(self, tokens, valid_length):
    embedding = self.get_initial_embedding(tokens)
    if self._layout != self._compute_layout:
        # Swap to the compute layout, run the encoder, then swap back
        contextual_embeddings, additional_outputs = self.encoder(
            np.swapaxes(embedding, 0, 1), valid_length)
        contextual_embeddings = np.swapaxes(contextual_embeddings, 0, 1)
    else:
        contextual_embeddings, additional_outputs = self.encoder(
            embedding, valid_length)
    if self.use_pooler:
        # When the encoder returns all layer outputs, pool from the last layer
        if isinstance(contextual_embeddings, list):
            pooled_out = self.apply_pooling(contextual_embeddings[-1])
        else:
            pooled_out = self.apply_pooling(contextual_embeddings)
        return contextual_embeddings, pooled_out
    else:
        return contextual_embeddings

def forward(self, inputs, token_types, valid_length, masked_positions):
    """Getting the scores of the masked positions.

    Parameters
    ----------
    inputs
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    token_types
        The type of the token. For example, if the inputs contain two sequences,
        we will set different token types for the first sentence and the second
        sentence.

        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    valid_length
        The valid length of each sequence.
        Shape (batch_size,)
    masked_positions
        The masked positions of the sequence.
        Shape (batch_size, num_masked_positions)

    Returns
    -------
    contextual_embedding
        - layout = 'NT'
            Shape (batch_size, seq_length, units)
        - layout = 'TN'
            Shape (seq_length, batch_size, units)
    pooled_out
        Shape (batch_size, units)
    mlm_scores
        Shape (batch_size, num_masked_positions, vocab_size)
    """
    contextual_embeddings, pooled_out = self.backbone_model(inputs, token_types,
                                                            valid_length)
    if self.backbone_model.layout == 'TN':
        mlm_features = select_vectors_by_position(
            np.swapaxes(contextual_embeddings, 0, 1), masked_positions)
    else:
        mlm_features = select_vectors_by_position(contextual_embeddings,
                                                  masked_positions)
    intermediate_output = self.mlm_decoder(mlm_features)
    if self.backbone_model.embed_size != self.backbone_model.units:
        # Split the decoder output: the first embed_size channels are scored against
        # the embedding table and the remaining channels against the extra table.
        scores = self.embedding_table(
            intermediate_output[:, :, :self.backbone_model.embed_size])
        extra_scores = self.extra_table(
            intermediate_output[:, :, self.backbone_model.embed_size:])
        mlm_scores = scores + extra_scores
    else:
        mlm_scores = self.embedding_table(intermediate_output)
    return contextual_embeddings, pooled_out, mlm_scores

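# The MLM heads in this section all gather the contextual features at the masked
# positions (after moving to the batch-major 'NT' layout when needed) before
# decoding them. A minimal NumPy sketch of that gather, using np.take_along_axis
# as a stand-in for select_vectors_by_position; the names and shapes are illustrative.
import numpy as np

batch_size, seq_length, units, num_masked = 2, 5, 4, 3

# Contextual embeddings in 'NT' layout and the positions to predict
contextual_embeddings = np.random.randn(batch_size, seq_length, units)
masked_positions = np.array([[0, 2, 4],
                             [1, 1, 3]])  # Shape (batch_size, num_masked)

# Gather the feature vector at each masked position for every sample
idx = masked_positions[:, :, None]                      # (batch_size, num_masked, 1)
mlm_features = np.take_along_axis(contextual_embeddings, idx, axis=1)
assert mlm_features.shape == (batch_size, num_masked, units)

# Equivalent explicit indexing for a single entry
assert np.allclose(mlm_features[0, 1], contextual_embeddings[0, masked_positions[0, 1]])
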
def forward(self, inputs, token_types, valid_length, masked_positions):
    """Generate the representation given the inputs.

    This is used in training or fine-tuning an ALBERT model.

    Parameters
    ----------
    inputs
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    token_types
        Type of the tokens. If the inputs contain two sequences, we will set
        different token types for the first sentence and the second sentence.

        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    valid_length
        The valid length of each sequence.
        Shape (batch_size,)
    masked_positions
        The masked positions of the sequence.
        Shape (batch_size, num_masked_positions)

    Returns
    -------
    contextual_embedding
        - layout = 'NT'
            Shape (batch_size, seq_length, units)
        - layout = 'TN'
            Shape (seq_length, batch_size, units)
    pooled_out
        Shape (batch_size, units)
    sop_score
        Shape (batch_size, 2)
    mlm_scores
        Shape (batch_size, num_masked_positions, vocab_size)
    """
    contextual_embeddings, pooled_out = self.backbone_model(inputs, token_types,
                                                            valid_length)
    sop_score = self.sop_classifier(pooled_out)
    if self.layout == 'NT':
        mlm_features = select_vectors_by_position(contextual_embeddings,
                                                  masked_positions)
    else:
        mlm_features = select_vectors_by_position(
            np.swapaxes(contextual_embeddings, 0, 1), masked_positions)
    mlm_scores = self.mlm_decoder(mlm_features)
    return contextual_embeddings, pooled_out, sop_score, mlm_scores

def test_t5_model(cfg_key, activation, ctx):
    with ctx:
        cfg = T5Model.get_cfg(cfg_key)
        cfg.defrost()
        cfg.MODEL.vocab_size = 256
        cfg.MODEL.d_model = 128
        cfg.MODEL.d_ff = 512
        cfg.MODEL.num_layers = 2
        cfg.MODEL.num_heads = 4
        cfg.MODEL.activation = activation
        cfg.MODEL.layout = 'NT'
        cfg.freeze()
        cfg_tn = cfg.clone()
        cfg_tn.defrost()
        cfg_tn.MODEL.layout = 'TN'
        cfg_tn.freeze()

        # Test that the 'NT' and 'TN' layouts produce consistent outputs
        t5_model = T5Model.from_cfg(cfg)
        t5_model.initialize()
        t5_model.hybridize()
        t5_model_tn = T5Model.from_cfg(cfg_tn)
        t5_model_tn.share_parameters(t5_model.collect_params())
        t5_model_tn.hybridize()

        batch_size = 8
        src_length = 32
        tgt_length = 18
        src_data = np.random.randint(0, 255, (batch_size, src_length))
        src_valid_length = np.random.randint(src_length // 2, src_length, (batch_size,))
        tgt_data = np.random.randint(0, 255, (batch_size, tgt_length))
        tgt_valid_length = np.random.randint(tgt_length // 4, tgt_length, (batch_size,))
        out = t5_model(src_data, src_valid_length, tgt_data, tgt_valid_length)
        out_tn = t5_model_tn(src_data.T, src_valid_length, tgt_data.T, tgt_valid_length)
        assert np.allclose(np.swapaxes(out, 0, 1), out_tn, 1E-5, 1E-5)

        # Test consistency under various target valid lengths
        for shift in range(1, np.min(tgt_valid_length).item()):
            for partial_out in [
                t5_model(src_data, src_valid_length,
                         tgt_data[:, :-shift], tgt_valid_length - shift),
                t5_model(src_data, src_valid_length,
                         tgt_data, tgt_valid_length - shift)
            ]:
                for i in range(batch_size):
                    vl = tgt_valid_length[i].item() - shift
                    assert np.allclose(partial_out[i, :vl], out[i, :vl], 1E-5, 1E-5)

def forward(self, inputs, token_types, valid_length, masked_positions):
    """Getting the scores of the masked positions.

    Parameters
    ----------
    inputs
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    token_types
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)

        If the inputs contain two sequences, we will set different token types
        for the first sentence and the second sentence.
    valid_length
        The valid length of each sequence.
        Shape (batch_size,)
    masked_positions
        The masked positions of the sequence.
        Shape (batch_size, num_masked_positions)

    Returns
    -------
    contextual_embedding
        - layout = 'NT'
            Shape (batch_size, seq_length, units)
        - layout = 'TN'
            Shape (seq_length, batch_size, units)
    pooled_out
        Shape (batch_size, units)
    mlm_scores
        Shape (batch_size, num_masked_positions, vocab_size)
    """
    contextual_embeddings, pooled_out = self.backbone_model(
        inputs, token_types, valid_length)
    if self.backbone_model.layout == 'NT':
        mlm_features = select_vectors_by_position(contextual_embeddings,
                                                  masked_positions)
    else:
        mlm_features = select_vectors_by_position(
            np.swapaxes(contextual_embeddings, 0, 1), masked_positions)
    mlm_scores = self.mlm_decoder(mlm_features)
    return contextual_embeddings, pooled_out, mlm_scores

def forward(self, inputs, valid_length, masked_positions):
    """Getting the scores of the masked positions.

    Parameters
    ----------
    inputs
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    valid_length
        The valid length of each sequence.
        Shape (batch_size,)
    masked_positions
        The masked positions of the sequence.
        Shape (batch_size, num_masked_positions)

    Returns
    -------
    contextual_embedding
        - layout = 'NT'
            Shape (batch_size, seq_length, units)
        - layout = 'TN'
            Shape (seq_length, batch_size, units)
    pooled_out
        Shape (batch_size, units)
    mlm_scores
        Shape (batch_size, num_masked_positions, vocab_size)
    """
    all_encodings_outputs, pooled_out = self.backbone_model(inputs, valid_length)
    if self.backbone_model._output_all_encodings:
        contextual_embeddings = all_encodings_outputs[-1]
    else:
        contextual_embeddings = all_encodings_outputs
    if self.backbone_model.layout == 'TN':
        contextual_embeddings = np.swapaxes(contextual_embeddings, 0, 1)
    mlm_features = select_vectors_by_position(contextual_embeddings, masked_positions)
    mlm_scores = self.mlm_decoder(mlm_features)
    return all_encodings_outputs, pooled_out, mlm_scores

def forward(self, rel_positions, query=None):
    """Forward function

    Parameters
    ----------
    rel_positions
        The relative shifts. Shape (query_length, mem_length).
        Each element represents the shift between the :math:`i-th` element of query
        and the :math:`j-th` element of memory.
    query
        The query for computing the relative scores. The shape depends on the layout.
        If we use T5 attention, the query will not be used.

    Returns
    -------
    rel_scores
        The relative attention scores.
        Can have shape (batch_size, num_heads, query_length, mem_length)
        or (num_heads, query_length, mem_length)
    """
    if self._method == 'transformer_xl' or self._method == 'shaw':
        assert query is not None, 'Must specify query if method={}'.format(self._method)
        if self._bidirectional:
            if self._max_distance is not None:
                rel_positions = np.clip(rel_positions,
                                        a_min=-self._max_distance,
                                        a_max=self._max_distance)
        else:
            if self._max_distance is not None:
                rel_positions = np.clip(rel_positions,
                                        a_min=0, a_max=self._max_distance)
        # uniq_rel.shape = (#uniq,), rev_index.shape = (L_q, L_m)
        uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True)
        uniq_rel_pos_embed = self._rel_pos_embed(uniq_rel)
        if self._method == 'transformer_xl':
            uniq_rel_pos_embed = self._rel_proj(self._dropout_layer(uniq_rel_pos_embed))
        # Shape (#uniq, K, C_q)
        uniq_rel_pos_embed = npx.reshape(uniq_rel_pos_embed,
                                         (-2, self._num_heads, self._head_query_units))
        # Calculate the dot-product between query and the relative positional embeddings.
        # After the calculation, rel_score.shape = (L_q, #uniq, N, K)
        if self._layout == 'NKT':
            # query_for_rel: (N, K, L_q, C_q)
            if self._use_einsum:
                rel_score = np.einsum('bnid,jnd->ijbn', query, uniq_rel_pos_embed)
            else:
                rel_score = np.transpose(
                    np.matmul(query,
                              np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                    (2, 3, 0, 1)
                )
        elif self._layout == 'NTK':
            # query_for_rel: (N, L_q, K, C_q)
            if self._use_einsum:
                rel_score = np.einsum('bind,jnd->ijbn', query, uniq_rel_pos_embed)
            else:
                rel_score = np.transpose(
                    np.matmul(np.swapaxes(query, 1, 2),
                              np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                    (2, 3, 0, 1)
                )
        elif self._layout == 'TNK':
            # query_for_rel: (L_q, N, K, C_q)
            if self._use_einsum:
                rel_score = np.einsum('ibnd,jnd->ijbn', query, uniq_rel_pos_embed)
            else:
                rel_score = np.transpose(
                    np.matmul(np.transpose(query, (1, 2, 0, 3)),
                              np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                    (2, 3, 0, 1)
                )
        else:
            raise NotImplementedError
        # We use gather_nd to select the elements
        # TODO(sxjscience) Use advanced indexing once available
        rev_index = npx.reshape_like(rev_index, rel_positions).astype(np.int32)
        query_idx = np.expand_dims(
            npx.arange_like(rel_positions, axis=0).astype(np.int32),
            axis=-1) + np.zeros_like(rev_index)
        rel_score = npx.gather_nd(rel_score, np.stack([query_idx, rev_index]))
        rel_score = np.transpose(rel_score, (2, 3, 0, 1))
    elif self._method == 't5':
        # shape is (K, L_q, L_m)
        rel_score = self._rel_pos_embed(rel_positions).transpose((2, 0, 1))
    else:
        raise NotImplementedError
    return rel_score

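# The 'transformer_xl'/'shaw' branch embeds each distinct relative shift only once
# and then uses the inverse index returned by np.unique to scatter the per-shift
# scores back onto the full (query_length, mem_length) grid. A plain-NumPy sketch
# of that trick, with scalar scores standing in for the per-head embeddings.
import numpy as np

query_length, mem_length = 4, 6
# Relative shift between every query position and every memory position
rel_positions = np.arange(mem_length)[None, :] - np.arange(query_length)[:, None]

# Embed each distinct shift once; rev_index maps every (i, j) back to its unique shift
uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True)
rev_index = rev_index.reshape(rel_positions.shape)

# Pretend "embedding": one scalar score per unique shift
uniq_scores = uniq_rel.astype(np.float64) * 0.1

# Recover the full score matrix by indexing with the inverse index
rel_scores = uniq_scores[rev_index]
assert rel_scores.shape == (query_length, mem_length)
assert np.allclose(rel_scores, rel_positions * 0.1)
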
def multi_head_dot_attn(query, key, value,
                        mask=None,
                        edge_scores=None,
                        dropout: float = 0.0,
                        scaled: bool = True,
                        normalized: bool = False,
                        eps: float = 1E-6,
                        query_head_units: Optional[int] = None,
                        layout: str = 'NKT',
                        use_einsum: bool = False):
    """Multihead dot product attention between the query, key, and value.

    scaled is False, normalized is False:
        D(h_q, h_k) = <h_q, h_k>
    scaled is True, normalized is False:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    scaled is False, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    scaled is True, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    If edge_scores is provided, we will calculate the attention as
        scores = D(h_q, h_k) + EdgeScore_{q, k}

    Parameters
    ----------
    query
        Query. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, query_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, query_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads, key_dim)
    key
        Key. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, key_dim)
    value
        Value. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, value_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, value_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, value_dim)
    mask
        Mask between query and memory.
        Shape (batch_size, query_length, mem_length)
    edge_scores
        The edge attention score. Shape can be any shape that is broadcastable to
        (batch_size, num_heads, query_length, mem_length)
    dropout
        Dropout rate
    scaled
        Whether to divide the attention weights by the sqrt of the query dimension.
        This is first proposed in "[NIPS2017] Attention is all you need."::

            score = <h_q, h_k> / sqrt(dim_q)

    normalized
        If turned on, the cosine distance is used, i.e.::

            score = <h_q / ||h_q||, h_k / ||h_k||>

    eps
        The epsilon value used in L2 normalization
    query_head_units
        The units of each query head. Must be specified when ``scaled`` is True.
    layout
        This stands for the layout of the attention cell. The shape of the
        input/output will depend on the layout. Currently, we support 'NKT', 'NTK'
        and 'TNK', in which 'N' means the batch_size, 'K' means the head, and 'T'
        means the length dimension.
    use_einsum
        Whether to use einsum for the computation

    Returns
    -------
    context_vec
        - layout is 'NKT' or 'NTK'
            Shape (batch_size, query_length, num_heads * value_units)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads * value_units)
    additional_info
        scores:
            Shape (batch_size, num_head, query_length, mem_length)
        attn_weight:
            Shape (batch_size, num_head, query_length, mem_length)
    """
    # TODO(sxjscience) Profile layout
    if normalized:
        query = l2_normalize(query, axis=-1, eps=eps)
        key = l2_normalize(key, axis=-1, eps=eps)
    if scaled:
        if query_head_units is None:
            raise NotImplementedError('You will need to specify query_head_units!')
        else:
            scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout == 'NKT':
        # 1. Expand the dimension of the mask:
        #    (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #    Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #    (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bnjc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'NTK':
        # 1. Expand the dimension of the mask:
        #    (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #    Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum('binc,bjnc->bnij', query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
                                   transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #    (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bjnc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'TNK':
        # 1. Expand the dimension of the mask:
        #    (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #    Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        #    This layout can be implemented very efficiently because B and N are
        #    consecutive to each other. To have a clear picture of what's happening,
        #    consider the (i, j)th element of the output:
        #        out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T,
        #    which is just one GEMM call. We can thus implement the whole kernel via a
        #    single call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum('ibnc,jbnc->bnij', query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                   key.transpose((1, 2, 3, 0)))
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #    (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
        #    Again, we can implement it via a single call to batched GEMM with stride.
        #    Shape (B, N, L_query, C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,jbnc->ibnc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    else:
        raise NotImplementedError('layout="{}" is not supported! '
                                  'We only support layout = "NKT", "NTK", and "TNK".'
                                  .format(layout))
    return context_vec, [scores, attn_weights]

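# A plain-NumPy reference of the scaled, masked dot-product attention for the 'NKT'
# layout (no dropout, no edge scores). It mirrors the math documented above but is
# only a sanity-check sketch, not the library implementation.
import numpy as np

def ref_attention_nkt(query, key, value, mask=None):
    """Reference scaled dot-product attention for the 'NKT' layout.

    query: (B, N, L_q, C_Q), key: (B, N, L_m, C_Q), value: (B, N, L_m, C_V)
    mask:  (B, L_q, L_m) with 1 for valid positions, 0 for masked ones.
    """
    scale = np.sqrt(query.shape[-1])
    # Scores: (B, N, L_q, L_m)
    scores = np.einsum('bnic,bnjc->bnij', query, key) / scale
    if mask is not None:
        scores = np.where(mask[:, None, :, :].astype(bool), scores, -1E18)
    # Softmax over the memory axis
    attn_weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attn_weights = attn_weights / attn_weights.sum(axis=-1, keepdims=True)
    # Context: (B, N, L_q, C_V) reordered and flattened to (B, L_q, N * C_V)
    context = np.einsum('bnij,bnjc->binc', attn_weights, value)
    return context.reshape(context.shape[0], context.shape[1], -1), attn_weights

B, N, L_q, L_m, C = 2, 4, 3, 5, 8
q = np.random.randn(B, N, L_q, C)
k = np.random.randn(B, N, L_m, C)
v = np.random.randn(B, N, L_m, C)
ctx_vec, w = ref_attention_nkt(q, k, v)
assert ctx_vec.shape == (B, L_q, N * C)
assert np.allclose(w.sum(axis=-1), 1.0)
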
def forward(self, inputs, token_types, valid_length=None):  # pylint: disable=arguments-differ
    """Generate the representation given the inputs.

    This is used in training or fine-tuning an ELECTRA model.

    Parameters
    ----------
    inputs
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    token_types
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)

        If the inputs contain two sequences, we will set different token types
        for the first sentence and the second sentence.
    valid_length
        The valid length of each sequence.
        Shape (batch_size,)

    Returns
    -------
    contextual_embedding
        - layout = 'NT'
            Shape (batch_size, seq_length, units)
        - layout = 'TN'
            Shape (seq_length, batch_size, units)
    pooled_output
        This is optional. Shape (batch_size, units)
    """
    initial_embedding = self.get_initial_embedding(inputs, token_types)
    # Project the embedding from embed_size to units if they differ
    prev_out = initial_embedding
    if self.embed_size != self.units:
        prev_out = self.embed_factorized_proj(prev_out)
    outputs = []
    if self._compute_layout != self._layout:
        # Swap the axes if the compute_layout and layout mismatch
        contextual_embeddings, additional_outputs = self.encoder(
            np.swapaxes(prev_out, 0, 1), valid_length)
        contextual_embeddings = np.swapaxes(contextual_embeddings, 0, 1)
    else:
        contextual_embeddings, additional_outputs = self.encoder(prev_out, valid_length)
    outputs.append(contextual_embeddings)
    if self.use_pooler:
        # The pooled output is simply the hidden state of the first ([CLS]) token,
        # without any extra pooling layer. This differs slightly from BERT's
        # pooled_out, but the attribute name is kept the same as in the BERT and
        # ALBERT models, whose default is use_pooler=True.
        if self._layout == 'NT':
            pooled_out = contextual_embeddings[:, 0, :]
        else:
            pooled_out = contextual_embeddings[0, :, :]
        outputs.append(pooled_out)
    return tuple(outputs) if len(outputs) > 1 else outputs[0]

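# The pooled output above is just a slice of the contextual embeddings; only the
# axis that holds the sequence dimension changes with the layout. A small NumPy
# check that the two slices pick the same vectors (shapes are illustrative).
import numpy as np

batch_size, seq_length, units = 2, 5, 4
emb_nt = np.random.randn(batch_size, seq_length, units)   # 'NT' layout
emb_tn = np.swapaxes(emb_nt, 0, 1)                        # 'TN' layout

# The pooled output is the hidden state of the first ([CLS]) token
pooled_nt = emb_nt[:, 0, :]
pooled_tn = emb_tn[0, :, :]
assert np.allclose(pooled_nt, pooled_tn)
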
def forward(self, inputs, token_types, valid_length, masked_positions):
    """Generate the representation given the inputs.

    This is used in training or fine-tuning a MobileBERT model.

    Parameters
    ----------
    inputs
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    token_types
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)

        If the inputs contain two sequences, we will set different token types
        for the first sentence and the second sentence.
    valid_length
        The valid length of each sequence.
        Shape (batch_size,)
    masked_positions
        The masked positions of the sequence.
        Shape (batch_size, num_masked_positions)

    Returns
    -------
    contextual_embedding
        - layout = 'NT'
            Shape (batch_size, seq_length, units)
        - layout = 'TN'
            Shape (seq_length, batch_size, units)
    pooled_out
        Shape (batch_size, units)
    nsp_score
        Shape (batch_size, 2)
    mlm_scores
        Shape (batch_size, num_masked_positions, vocab_size)
    """
    contextual_embeddings, pooled_out = self.backbone_model(
        inputs, token_types, valid_length)
    nsp_score = self.nsp_classifier(pooled_out)
    if self.backbone_model.layout == 'NT':
        mlm_features = select_vectors_by_position(contextual_embeddings,
                                                  masked_positions)
    else:
        mlm_features = select_vectors_by_position(
            np.swapaxes(contextual_embeddings, 0, 1), masked_positions)
    intermediate_output = self.mlm_decoder(mlm_features)
    if self.backbone_model.embed_size != self.backbone_model.units:
        # Split the decoder output: the first embed_size channels are scored against
        # the embedding table and the remaining channels against the extra table.
        scores = self.embedding_table(
            intermediate_output[:, :, :self.backbone_model.embed_size])
        extra_scores = self.extra_table(
            intermediate_output[:, :, self.backbone_model.embed_size:])
        mlm_scores = scores + extra_scores
    else:
        mlm_scores = self.embedding_table(intermediate_output)
    return contextual_embeddings, pooled_out, nsp_score, mlm_scores

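# When embed_size != units, the MLM decoder output above is split along the channel
# axis: the first embed_size channels are scored against the word-embedding table and
# the remaining channels against a separate extra table, and the two logits are summed.
# A NumPy sketch of that scoring; the weight matrices here are random stand-ins for
# the learned parameters, and the sizes are illustrative.
import numpy as np

batch_size, num_masked, units, embed_size, vocab_size = 2, 3, 12, 8, 20

intermediate_output = np.random.randn(batch_size, num_masked, units)

# Word-embedding table (vocab_size, embed_size) and an extra projection
# (vocab_size, units - embed_size)
embedding_table = np.random.randn(vocab_size, embed_size)
extra_table = np.random.randn(vocab_size, units - embed_size)

# Score the first embed_size channels against the embedding table ...
scores = intermediate_output[:, :, :embed_size] @ embedding_table.T
# ... and the remaining channels against the extra table, then sum the logits.
extra_scores = intermediate_output[:, :, embed_size:] @ extra_table.T
mlm_scores = scores + extra_scores
assert mlm_scores.shape == (batch_size, num_masked, vocab_size)
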