def gumbel_softmax(logits, temperature: float = 1.0, eps: float = 1E-10,
                   hard=True, use_np_gumbel: bool = True):
    r"""Perform the Gumbel-Softmax trick to generate differentiable one-hot vectors
    from the input logits.

    Here, the Gumbel distribution is

        Gumbel(\alpha) = -\log(-\log U) + \log \alpha,

    in which U is the uniform(0, 1) distribution.

    A nice property of Gumbel is:

        \arg\max(\{Gumbel(\alpha_i)\}) \sim multinomial(\alpha_i)

    The Gumbel-Softmax trick uses softmax + the straight-through estimator to
    produce one-hot vectors that represent the sampling result.

    References:

        1. https://en.wikipedia.org/wiki/Gumbel_distribution
        2. [ICLR2017] Categorical Reparameterization with Gumbel-Softmax

    Parameters
    ----------
    logits
        Logits. Shape (..., V)
    temperature
        The temperature that controls the sharpness of the softmax. Lower
        temperatures produce outputs closer to one-hot vectors.
    eps
        The eps for stability of gradient
    hard
        Whether to use the straight-through estimator to produce one-hot vectors.
    use_np_gumbel
        Whether to use the np.random.gumbel operator

    Returns
    -------
    ret
        The returned output. Shape (..., V)
    """
    # TODO(sxjscience) Investigate the impact of np.random.gumbel:
    #  it has no eps and may have problems in calculating the gradient.
    if use_np_gumbel:
        gumbels = np.random.gumbel(np.zeros_like(logits))
    else:
        u = np.random.uniform(np.zeros_like(logits), 1)
        gumbels = -np.log(-np.log(u + eps) + eps)
    y = npx.softmax((gumbels + logits) / temperature, axis=-1)
    if hard:
        y_hard = np.max(y, axis=-1, keepdims=True) == y
        y_hard = npx.stop_gradient(y_hard - y) + y
        return y_hard
    else:
        return y
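# A minimal usage sketch for gumbel_softmax above, assuming MXNet's numpy
# interface is active (`from mxnet import np, npx; npx.set_np()`) and that the
# function is importable from this module:
#
#     logits = np.array([[1.0, 2.0, 0.5], [0.1, 0.1, 3.0]])
#     one_hot = gumbel_softmax(logits, temperature=0.5, hard=True)
#
# With hard=True, each row of `one_hot` is a one-hot vector in the forward
# pass, while gradients flow through the underlying softmax
# (straight-through estimator).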
def forward(self, hyp_lengths, reference_lengths):
    if self.weight == 0.0:
        if isinstance(hyp_lengths, (int, float)):
            return 0.0
        else:
            # subtract to avoid MXNet's warning about not using both arguments;
            # this branch should not be used (and is not) during inference
            return np.zeros_like(hyp_lengths - reference_lengths)
    else:
        # log_bp is always <= 0.0
        if isinstance(hyp_lengths, (int, float)):
            log_bp = min(0.0, 1.0 - reference_lengths / hyp_lengths)
        else:
            log_bp = np.minimum(np.zeros_like(hyp_lengths, dtype='float32'),
                                1.0 - reference_lengths / hyp_lengths)
        return self.weight * log_bp
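# Hedged numeric sketch of the brevity-penalty math above (plain Python, not
# part of the original class): with weight=1.0, a hypothesis shorter than the
# reference receives a negative additive penalty in log space, mirroring
# BLEU's brevity penalty exp(1 - ref/hyp).
hyp_len, ref_len = 8.0, 10.0
log_bp = min(0.0, 1.0 - ref_len / hyp_len)   # = -0.25
penalty = 1.0 * log_bp                       # weight * log_bp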
def forward(self, logits, labels, length_ratio, source_length, target_length):
    """
    :param logits: Model logits. Shape: (batch, length, vocab_size).
    :param labels: Gold targets. Shape: (batch, length).
    :param length_ratio: Length ratios. Shape: (batch,).
    :param source_length: Source lengths. Shape: (batch,).
    :param target_length: Target lengths. Shape: (batch,).
    :return: Sequence scores. Shape: (batch,).
    """
    logprobs = npx.log_softmax(logits, axis=-1, temperature=self.softmax_temperature)

    # Select the log-probability of each gold label.
    # token_scores: (batch_size, target_seq_len)
    token_scores = npx.pick(logprobs, labels, axis=-1)
    if self.score_type == C.SCORING_TYPE_NEGLOGPROB:
        token_scores = token_scores * -1

    # Sum, then apply length penalty. The call to `np.where` masks out invalid values from scores.
    # scores: (batch_size,)
    scores = np.sum(np.where(labels != 0, token_scores, np.zeros_like(token_scores)), axis=1)

    if self.constant_length_ratio is not None and self.constant_length_ratio > 0.0:
        predicted_output_length = source_length * self.constant_length_ratio
    else:
        predicted_output_length = source_length * length_ratio

    scores = self.scorer(scores, target_length, predicted_output_length)
    return scores
def forward(self, scores, target_dists, finished, best_hyp_indices): """ Choose an extension of each hypothesis from its softmax distribution. :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size) :param target_dists: The non-cumulative target distributions (ignored). :param finished: The list of finished hypotheses. :param best_hyp_indices: Best hypothesis indices constant. :return: The row indices, column indices, and values of the sampled words. """ # Map the negative logprobs to probabilities so as to have a distribution target_dists = np.exp(-target_dists) # n == 0 means sample from the full vocabulary. Otherwise, we sample from the top n. if self.n != 0: # select the top n in each row, via a mask masked_items = npx.topk(target_dists, k=self.n, ret_typ='mask', axis=1, is_ascend=False) # set unmasked items to 0 masked_items = np.where(masked_items, target_dists, masked_items) # renormalize target_dists = masked_items / np.sum(masked_items, axis=1, keepdims=True) # Sample from the target distributions over words, then get the corresponding values from the cumulative scores best_word_indices = npx.random.categorical(target_dists, get_prob=False) # Zeroes for finished hypotheses. best_word_indices = np.where(finished, np.zeros_like(best_word_indices), best_word_indices) values = npx.pick(scores, best_word_indices, axis=1, keepdims=True) best_hyp_indices = npx.slice_like(best_hyp_indices, best_word_indices, axes=(0,)) return best_hyp_indices, best_word_indices, values
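# Illustrative sketch (plain NumPy, not the Sockeye code) of the top-n
# renormalized sampling performed above: keep the n largest probabilities in
# each row, zero out the rest, renormalize, and sample one index per row.
import numpy as onp

def topn_sample(probs, n, seed=0):
    rng = onp.random.default_rng(seed)
    out = onp.zeros(probs.shape[0], dtype=onp.int64)
    for i, row in enumerate(probs):
        keep = onp.argsort(row)[-n:]        # indices of the n largest entries
        masked = onp.zeros_like(row)
        masked[keep] = row[keep]
        masked /= masked.sum()              # renormalize to a distribution
        out[i] = rng.choice(row.shape[0], p=masked)
    return out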
def forward(self, positions): """ Parameters ---------- positions : NDArray Shape (..., ) Returns ------- ret : Shape (..., units) """ emb = np.expand_dims(positions.astype(self._dtype), axis=-1) * self.base_mult.data() sin_emb = np.sin(emb) cos_emb = np.cos(emb) if self._units % 2 == 0: return np.concatenate([sin_emb, cos_emb], axis=-1) else: return np.concatenate([ sin_emb, cos_emb, np.expand_dims(np.zeros_like(positions).astype(self._dtype), axis=-1) ], axis=-1)
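# Rough standalone sketch of the even-`units` case above (plain NumPy; the
# real layer stores its frequency vector in `self.base_mult`, which is not
# shown here). The assumed frequencies follow the common 1 / 10000^(2i/units)
# sinusoidal scheme.
import numpy as onp

def sinusoidal_embedding(positions, units=8):
    freqs = 1.0 / onp.power(10000.0, onp.arange(units // 2) * 2.0 / units)
    emb = positions[..., None] * freqs                    # (..., units // 2)
    return onp.concatenate([onp.sin(emb), onp.cos(emb)], axis=-1)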
def forward(self, hidden_states, valid_length, mem_states, mem_valid_length): # 1. relative position embeddings and attention masks position_embeddings = self.relative_position_encoder( self._get_relative_position(hidden_states)) # relative position embedding is not used for cross attention, # so we just obtain the correct shape and fill it with 0 mem_relative_position = np.zeros_like( self._get_relative_position(hidden_states, mem_states)) mem_position_embeddings = np.repeat(np.expand_dims( mem_relative_position, axis=0), self._num_heads, axis=0) self_attn_mask = gen_self_attn_mask(hidden_states, valid_length, dtype=self._dtype, attn_type='causal', layout=self.layout) mem_attn_mask = gen_mem_attn_mask(mem_states, mem_valid_length, hidden_states, valid_length, dtype=self._dtype, layout=self.layout) # 2. decoder blocks and other layers hidden_states = self.dropout(hidden_states) for layer in self.layers: hidden_states = layer(hidden_states, self_attn_mask, position_embeddings, mem_states, mem_attn_mask, mem_position_embeddings) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states
def trunc_gumbel(logits, truncation):
    r"""Sample from the TruncGumbel distribution.

    The Truncated Gumbel distribution TruncGumbel(\alpha, truncation) is the
    distribution of Gumbel(\alpha) conditioned on being no larger than the
    truncation value. To sample from it, we can use the CDF
    (cumulative distribution function) inversion technique.

    References:

        1. [NIPS2014] A* Sampling, https://papers.nips.cc/paper/5449-a-sampling.pdf
        2. https://cmaddis.github.io/gumbel-machinery

    Parameters
    ----------
    logits
        The logits. Shape (...,)
    truncation
        The truncation. Shape (...,)

    Returns
    -------
    samples
        Samples from TruncGumbel(logits, truncation). Shape (...,)
    """
    gumbels = np.random.gumbel(np.zeros_like(logits)) + logits
    return -np.log(np.exp(-gumbels) + np.exp(-truncation))
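# Quick sanity sketch (assumes MXNet's numpy interface is active, i.e.
# `from mxnet import np, npx; npx.set_np()`): samples from trunc_gumbel are
# bounded above by the truncation, since -log(exp(-g) + exp(-b)) <= min(g, b).
#
#     logits = np.zeros((1000,))
#     samples = trunc_gumbel(logits, truncation=np.full((1000,), 0.5))
#     assert bool((samples <= 0.5).all())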
def test_prevent_unk_update_scores():
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.beam_search

    vocab_size = 10
    batch_beam_size = 3
    us = sockeye.beam_search.UpdateScores()
    pad_dist = np.full((batch_beam_size, vocab_size - 1), fill_value=np.inf, dtype='float32')
    eos_dist = np.full((batch_beam_size, vocab_size), fill_value=np.inf, dtype='float32')
    eos_dist[:, C.EOS_ID] = 0
    unk_dist = np.zeros_like(eos_dist)
    unk_dist[:, C.UNK_ID] = np.inf  # pylint: disable=E1137

    lengths = np.array([[0], [1], [0]], dtype='int32')
    max_lengths = np.array([[1], [2], [3]], dtype='int32')  # first one reaches max length
    scores_accumulated = np.ones((3, 1), dtype='float32')
    finished = np.array([[0],   # not finished
                         [1],   # finished
                         [0]],  # not finished
                        dtype='int32')
    inactive = np.zeros_like(finished)
    target_dists = np.random.uniform(0, 1, (3, vocab_size))

    scores, lengths = us(target_dists, finished, inactive, scores_accumulated,
                         lengths, max_lengths, unk_dist, pad_dist, eos_dist)
    lengths = lengths.reshape((-1,))

    # all lengths but the finished one are updated + 1
    assert (lengths == np.array([[1], [1], [1]])).all()
    # 1: reached max length, force eos
    assert (scores[0] == (1. + target_dists[0] + eos_dist)).all()
    # 2: finished, force pad, keep score
    assert (scores[1] == np.array([1.] + pad_dist[1].tolist())).all()
    # 3: score of <unk> should be np.inf
    assert scores[2, C.UNK_ID] == np.inf
    # 3: scores + previous scores
    assert (scores[2] == (1. + target_dists[2] + unk_dist[2])).all()
def test_getitem_autograd(np_array, index): x = np.array(np_array, dtype=np_array.dtype) x.attach_grad() with autograd.record(): y = x[index] y.backward() value = np.ones_like(y) x_grad = np.zeros_like(x) x_grad[index] = value assert same(x_grad.asnumpy(), x.grad.asnumpy())
def forward(self, scores, lengths, reference_lengths): lp = self._lp(lengths) if self._bp is not None: bp = self._bp(lengths, reference_lengths) else: if isinstance(scores, (int, float)): bp = 0.0 else: # avoid warning for unused input bp = np.zeros_like(reference_lengths) if reference_lengths is not None else 0.0 return scores / lp - bp
def init_state_from_encoder( self, encoder_outputs: np.ndarray, encoder_valid_length: Optional[np.ndarray] = None, target_embed: Optional[np.ndarray] = None) -> List[np.ndarray]: """ Returns the initial states given encoder output. States for teacher-forced training are encoder outputs and a valid length mask for encoder outputs. At inference, this method returns the following state tuple: valid length bias, step state, [projected encoder attention keys, projected encoder attention values] * num_layers, [autoregressive state dummies] * num_layers. :param encoder_outputs: Encoder outputs. Shape: (batch, source_length, encoder_dim). :param encoder_valid_length: Valid lengths of encoder outputs. Shape: (batch,). :param target_embed: Target-side embedding layer output. Shape: (batch, target_length, target_embedding_dim). :return: Initial states. """ if target_embed is None: # Inference: initial step = 0. Shape: (batch_size, 1) steps = np.expand_dims(np.zeros_like(encoder_valid_length), axis=1) else: # Training: steps up to target length. Shape: (1, target_length) steps = np.expand_dims(npx.arange_like(target_embed, axis=1), axis=0) if self.inference_only: # Encoder projection caching, therefore we don't pass the encoder_outputs states = [steps, encoder_valid_length] for layer in self.layers: enc_att_kv = layer.enc_attention.ff_kv(encoder_outputs) states.append(np.transpose(enc_att_kv, axes=(1, 0, 2))) else: # NO encoder projection caching states = [ steps, np.transpose(encoder_outputs, axes=(1, 0, 2)), encoder_valid_length ] _batch_size = encoder_outputs.shape[0] _ctx = encoder_outputs.ctx _dtype = encoder_outputs.dtype dummy_autoregr_states = [ np.zeros(layer.get_states_shape(_batch_size), ctx=_ctx, dtype=_dtype) for layer in self.layers for _ in range(layer.num_state_tensors) ] states += dummy_autoregr_states return states
def score_batch(self, batch: data_io.Batch) -> np.ndarray: batch = batch.split_and_load(ctx=self.context) batch_scores = [] # type: List[np.ndarray] for inputs, labels in batch.shards(): source, source_length, target, target_length = inputs outputs = self.model(*inputs) # type: Dict[str, np.ndarray] logits = outputs[C.LOGITS_NAME] # type: np.ndarray label = labels[C.TARGET_LABEL_NAME] length_ratio = outputs.get(C.LENRATIO_NAME, np.zeros_like(source_length)) scores = self.batch_scorer(logits, label, length_ratio, source_length, target_length) batch_scores.append(scores) # shape: (batch_size,). return np.concatenate(batch_scores, axis=0)
def test_crop_backward(test_nd_arr, TestCase): a_np = test_nd_arr.asnumpy() b_np = a_np[(slice(TestCase.y, TestCase.y + TestCase.height), slice(TestCase.x, TestCase.x + TestCase.width), slice(0, 3))] data = mx.sym.Variable('data') crop_sym = mx.sym.image.crop(data, TestCase.x, TestCase.y, TestCase.width, TestCase.height) expected_in_grad = np.zeros_like(np.array(a_np)) expected_in_grad[(slice(TestCase.y, TestCase.y + TestCase.height), slice(TestCase.x, TestCase.x + TestCase.width), slice(0, 3))] = b_np check_symbolic_backward(crop_sym, [a_np], [b_np], [expected_in_grad])
def get_initial_embedding(self, inputs, token_types=None): """Get the initial token embeddings that considers the token type and positional embeddings Parameters ---------- inputs - layout = 'NT' Shape (batch_size, seq_length) - layout = 'TN' Shape (seq_length, batch_size) token_types The type of tokens. If None, it will be initialized as all zero. - layout = 'NT' Shape (batch_size, seq_length) - layout = 'TN' Shape (seq_length, batch_size) Returns ------- embedding The initial embedding that will be fed into the encoder - layout = 'NT' Shape (batch_size, seq_length, C_embed) - layout = 'TN' Shape (seq_length, batch_size, C_embed) """ if self.layout == 'NT': time_axis, batch_axis = 1, 0 else: time_axis, batch_axis = 0, 1 embedding = self.word_embed(inputs) if token_types is None: token_types = np.zeros_like(inputs) type_embedding = self.token_type_embed(token_types) embedding = embedding + type_embedding if self.pos_embed_type is not None: positional_embedding = self.token_pos_embed(npx.arange_like(inputs, axis=time_axis)) positional_embedding = np.expand_dims(positional_embedding, axis=batch_axis) embedding = embedding + positional_embedding # Extra layer normalization plus dropout embedding = self.embed_layer_norm(embedding) embedding = self.embed_dropout(embedding) return embedding
def backward(self, req, out_grad, in_data, out_data, in_grad, aux): if ReluOp.guided_backprop: # Get output and gradients of output y = out_data[0] dy = out_grad[0] # Zero out the negatives in the gradients of the output dy_positives = np.maximum(dy, np.zeros_like(dy)) # What output values were greater than 0? y_ones = y.__gt__(0) # Mask out the values for which at least one of dy or y is negative dx = dy_positives * y_ones self.assign(in_grad[0], req[0], dx) else: # Regular backward for ReLU x = in_data[0] x_gt_zero = x.__gt__(0) dx = out_grad[0] * x_gt_zero self.assign(in_grad[0], req[0], dx)
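# Plain-NumPy sketch of the guided-backprop rule implemented above: gradients
# are passed through only where both the upstream gradient and the forward
# ReLU output are positive (dx = max(dy, 0) * (y > 0)).
import numpy as onp

def guided_relu_grad(dy, y):
    return onp.maximum(dy, 0.0) * (y > 0)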
def select_vectors_by_position(data, positions):
    """Select each batch with the given positions.

    Once advanced indexing can be hybridized, we can revise the implementation.

    out[i, j, ...] = data[i, positions[i, j], ...]

    Parameters
    ----------
    data
        Input tensor of contextualized token embeddings
        Shape (batch_size, seq_length, ...)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_sel_positions).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The selection result.
        Shape (batch_size, num_sel_positions, ...)
    """
    # Here, we use gather_nd to select the output from data:
    # Need to compute
    #   out[i, j, :] = in[i, masked_position[i, j], :]
    # Thus, construct an indices tensor with shape [2, batch_size, num_masked_position], where
    #   indices[0, i, j] = i
    #   indices[1, i, j] = masked_position[i, j]
    # Then, out = gather_nd(in, indices)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx, positions])
    # TODO(sxjscience) We can revise the implementation to advanced indexing
    #  once the bug in MXNet is solved:
    #  https://github.com/apache/incubator-mxnet/issues/18919
    out = npx.gather_nd(data, indices)
    return out
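# Illustrative call (assumes MXNet's numpy interface and that this module's
# select_vectors_by_position is in scope); the semantics are
# out[i, j] = data[i, positions[i, j]]:
#
#     data = np.arange(12).reshape(2, 3, 2)        # (batch=2, seq=3, dim=2)
#     positions = np.array([[0, 2], [1, 1]])       # (batch=2, num_sel=2)
#     out = select_vectors_by_position(data, positions)   # shape (2, 2, 2)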
def add_vectors_by_position(data, increment, positions):
    """Scatter each batch with the given positions.

    data[i, positions[i, j], ...] += increment[i, j, ...]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length, ...)
    increment
        Input tensor of token ids
        Shape (batch_size, num_disp_position, ...)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length, ...)
    """
    # Here, we use index_add to disperse the output from data:
    # Need to compute
    #   out[i, masked_position[i, j], :] = in[i, j, :]
    # Thus, construct an indices tensor with shape [2, batch_size * num_masked_position], where
    #   indices[0, i * num_masked_position + j] = i
    #   indices[1, i * num_masked_position + j] = masked_position[i, j]
    # and reshape the increment to (batch_size * num_masked_position, ...).
    # Then, out = npx.index_add(data, indices, increment)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx.reshape((-1,)), positions.reshape((-1,))])
    out = npx.index_add(data, indices, npx.reshape(increment, (-5, -4)))
    return out
def update_vectors_by_position(data, val, positions):
    """Update each batch with the given positions.

    Considered as a reversed process of "select_vectors_by_position",
    this is an operator similar to "add_vectors_by_position" that updates
    the results instead of adding.

    data[i, positions[i, j], :] = val[i, j, :]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length)
    val
        Input tensor of token ids
        Shape (batch_size, num_disp_position)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length)
    """
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx.reshape((-1,)), positions.reshape((-1,))])
    out = npx.index_update(data, indices, npx.reshape(val, (-5, -4)))
    return out
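# Hedged usage sketch for the scatter helpers above (assumes MXNet's numpy
# interface and an MXNet build that provides npx.index_update / npx.index_add,
# as the implementations require):
#
#     data = np.zeros((2, 4), dtype=np.int32)
#     positions = np.array([[1, 3], [0, 2]], dtype=np.int32)
#     val = np.array([[7, 8], [9, 10]], dtype=np.int32)
#     out = update_vectors_by_position(data, val, positions)
#     # out[0, 1] == 7, out[0, 3] == 8, out[1, 0] == 9, out[1, 2] == 10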
def test_zeros_like(): inp = np.ones((INT_OVERFLOW, 2)) out = np.zeros_like(inp) assert out.shape == inp.shape assert out[0, 0] == 0 and out[-1, -1] == 0
def forward(self, rel_positions, query=None): """Forward function Parameters ---------- rel_positions The relative shifts. Shape (query_length, mem_length). Each element represents the shift between the :math:`i-th` element of query and the :math:`j-th` element of memory. query The query for computing the relative scores. The shape depends on the layout. If we use T5 attention, the query will not be used. Returns ------- rel_scores The relative attention scores Can have shape (batch_size, num_heads, query_length, mem_length) or (num_heads, query_length, mem_length) """ if self._method == 'transformer_xl' or self._method == 'shaw': assert query is not None, 'Must specify query if method={}'.format(self._method) if self._bidirectional: if self._max_distance is not None: rel_positions = np.clip(rel_positions, a_min=-self._max_distance, a_max=self._max_distance) else: if self._max_distance is not None: rel_positions = np.clip(rel_positions, a_min=0, a_max=self._max_distance) # uniq_rel.shape = (#uniq,), rev_index.shape = (L_q, L_m) uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True) uniq_rel_pos_embed = self._rel_pos_embed(uniq_rel) if self._method == 'transformer_xl': uniq_rel_pos_embed = self._rel_proj(self._dropout_layer(uniq_rel_pos_embed)) # Shape (#uniq, K, C_q) uniq_rel_pos_embed = npx.reshape(uniq_rel_pos_embed, (-2, self._num_heads, self._head_query_units)) # Calculate the dot-product between query and the relative positional embeddings. # After the calculation, rel_score.shape = (L_q, #uniq, N, K) if self._layout == 'NKT': # query_for_rel: (N, K, L_q, C_q) if self._use_einsum: rel_score = np.einsum('bnid,jnd->ijbn', query, uniq_rel_pos_embed) else: rel_score = np.transpose( np.matmul(query, np.transpose(uniq_rel_pos_embed, (1, 2, 0))), (2, 3, 0, 1) ) elif self._layout == 'NTK': # query_for_rel: (N, L_q, K, C_q) if self._use_einsum: rel_score = np.einsum('bind,jnd->ijbn', query, uniq_rel_pos_embed) else: rel_score = np.transpose( np.matmul(np.swapaxes(query, 1, 2), np.transpose(uniq_rel_pos_embed, (1, 2, 0))), (2, 3, 0, 1) ) elif self._layout == 'TNK': # query_for_rel: (L_q, N, K, C_q) if self._use_einsum: rel_score = np.einsum('ibnd,jnd->ijbn', query, uniq_rel_pos_embed) else: rel_score = np.transpose( np.matmul(np.transpose(query, (1, 2, 0, 3)), np.transpose(uniq_rel_pos_embed, (1, 2, 0))), (2, 3, 0, 1) ) else: raise NotImplementedError # We use gather_nd to select the elements # TODO(sxjscience) Use advanced indexing once available rev_index = npx.reshape_like(rev_index, rel_positions).astype(np.int32) query_idx = np.expand_dims(npx.arange_like(rel_positions, axis=0).astype(np.int32), axis=-1) + np.zeros_like(rev_index) rel_score = npx.gather_nd(rel_score, np.stack([query_idx, rev_index])) rel_score = np.transpose(rel_score, (2, 3, 0, 1)) elif self._method == 't5': # shape is (K, L_q, L_m) rel_score = self._rel_pos_embed(rel_positions).transpose((2, 0, 1)) else: raise NotImplementedError return rel_score
def get_corrupted_tokens(self, inputs, original_tokens, masked_positions, logits):
    """
    Sample from the generator to create corrupted input.

    Parameters
    ----------
    inputs
        The masked input
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    original_tokens
        The original tokens that appear in the unmasked input sequence
        Shape (batch_size, num_masked_positions).
    masked_positions
        The masked position of the sequence
        Shape (batch_size, num_masked_positions).
    logits
        The logits of each token
        Shape (batch_size, num_masked_positions, vocab_size)

    Returns
    -------
    corrupted_tokens
        Shape (batch_size, num_masked_positions)
    fake_data
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    labels
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    """
    if self._disallow_correct:
        # TODO(sxjscience), Revise the implementation
        disallow = npx.one_hot(masked_positions, depth=self.vocab_size, dtype=self._dtype)
        logits = logits - 1000.0 * disallow
    # gumbel_softmax() samples from the logits with noise drawn from a Gumbel distribution
    prob = gumbel_softmax(
        logits,
        temperature=self._temperature,
        eps=self._gumbel_eps,
        use_np_gumbel=False)
    corrupted_tokens = np.argmax(prob, axis=-1).astype(np.int32)

    if self.disc_backbone.layout == 'TN':
        inputs = inputs.T
    original_data = update_vectors_by_position(inputs, original_tokens, masked_positions)
    fake_data = update_vectors_by_position(inputs, corrupted_tokens, masked_positions)
    updates_mask = add_vectors_by_position(np.zeros_like(inputs),
                                           np.ones_like(masked_positions),
                                           masked_positions)
    # Deal with multiple zeros in masked_positions, which
    # would otherwise produce a non-zero value at the first index ([CLS])
    updates_mask = np.minimum(updates_mask, 1)
    labels = updates_mask * np.not_equal(fake_data, original_data)
    if self.disc_backbone.layout == 'TN':
        return corrupted_tokens, fake_data.T, labels.T
    else:
        return corrupted_tokens, fake_data, labels
def test_mx_pt_eq_sockeye_model(): pytest.importorskip('mxnet') from mxnet import np import sockeye.transformer import sockeye.encoder import sockeye.model # model setup source_vocab_size = target_vocab_size = 32000 num_embed_source = num_embed_target = model_size = 512 max_seq_len_source = max_seq_len_target = 100 num_source_factors = 1 num_target_factors = 1 num_layers = 4 weight_tying = False batch_size = 4 topk_size = 200 config_encoder = sockeye.transformer.TransformerConfig( model_size=model_size, attention_heads=8, feed_forward_num_hidden=256, act_type='relu', num_layers=num_layers, dropout_attention=0, dropout_act=0, dropout_prepost=0, positional_embedding_type=C.FIXED_POSITIONAL_EMBEDDING, preprocess_sequence='n', postprocess_sequence='r', max_seq_len_source=max_seq_len_source, max_seq_len_target=max_seq_len_target, use_lhuc=False) config_encoder_pt = sockeye.transformer_pt.TransformerConfig( model_size=model_size, attention_heads=8, feed_forward_num_hidden=256, act_type='relu', num_layers=num_layers, dropout_attention=0, dropout_act=0, dropout_prepost=0, positional_embedding_type=C.FIXED_POSITIONAL_EMBEDDING, preprocess_sequence='n', postprocess_sequence='r', max_seq_len_source=max_seq_len_source, max_seq_len_target=max_seq_len_target, use_lhuc=False) config_decoder = sockeye.transformer.TransformerConfig( model_size=model_size, attention_heads=8, feed_forward_num_hidden=256, act_type='relu', num_layers=num_layers, dropout_attention=0, dropout_act=0, dropout_prepost=0, positional_embedding_type=C.FIXED_POSITIONAL_EMBEDDING, preprocess_sequence='n', postprocess_sequence='r', max_seq_len_source=max_seq_len_source, max_seq_len_target=max_seq_len_target, depth_key_value=model_size, use_lhuc=False) config_decoder_pt = sockeye.transformer_pt.TransformerConfig( model_size=model_size, attention_heads=8, feed_forward_num_hidden=256, act_type='relu', num_layers=num_layers, dropout_attention=0, dropout_act=0, dropout_prepost=0, positional_embedding_type=C.FIXED_POSITIONAL_EMBEDDING, preprocess_sequence='n', postprocess_sequence='r', max_seq_len_source=max_seq_len_source, max_seq_len_target=max_seq_len_target, depth_key_value=model_size, use_lhuc=False) config_embed_source = sockeye.encoder.EmbeddingConfig( vocab_size=source_vocab_size, num_embed=num_embed_source, dropout=0, factor_configs=None, allow_sparse_grad=False) config_embed_target = sockeye.encoder.EmbeddingConfig( vocab_size=target_vocab_size, num_embed=num_embed_target, dropout=0, factor_configs=None, allow_sparse_grad=False) data_statistics = sockeye.data_io_pt.DataStatistics( num_sents=0, num_discarded=0, num_tokens_source=0, num_tokens_target=0, num_unks_source=0, num_unks_target=0, max_observed_len_source=100, max_observed_len_target=100, size_vocab_source=source_vocab_size, size_vocab_target=target_vocab_size, length_ratio_mean=1.0, length_ratio_std=0.001, buckets=[], num_sents_per_bucket=[], average_len_target_per_bucket=[], length_ratio_stats_per_bucket=None) data_config = sockeye.data_io_pt.DataConfig( data_statistics=data_statistics, max_seq_len_source=max_seq_len_source, max_seq_len_target=max_seq_len_target, num_source_factors=num_source_factors, num_target_factors=num_target_factors) config_length_task = None model_config = sockeye.model.ModelConfig( config_data=data_config, vocab_source_size=source_vocab_size, vocab_target_size=target_vocab_size, config_embed_source=config_embed_source, config_embed_target=config_embed_target, config_encoder=config_encoder, config_decoder=config_decoder, 
config_length_task=config_length_task, weight_tying_type=C.WEIGHT_TYING_NONE, lhuc=False, dtype=C.DTYPE_FP32) model_config_pt = sockeye.model.ModelConfig( config_data=data_config, vocab_source_size=source_vocab_size, vocab_target_size=target_vocab_size, config_embed_source=config_embed_source, config_embed_target=config_embed_target, config_encoder=config_encoder_pt, config_decoder=config_decoder_pt, config_length_task=config_length_task, weight_tying_type=C.WEIGHT_TYING_NONE, lhuc=False, dtype=C.DTYPE_FP32) # inputs source_inputs_mx = np.random.randint( 0, max_seq_len_source, (batch_size, max_seq_len_source, num_source_factors)) source_input_lengths_mx = np.random.randint(0, max_seq_len_source, (batch_size, )) target_inputs_mx = np.random.randint( 0, max_seq_len_target, (batch_size, max_seq_len_target, num_source_factors)) target_input_lengths_mx = np.random.randint(0, max_seq_len_target, (batch_size, )) source_inputs_pt = pt.tensor(source_inputs_mx.asnumpy()) source_input_lengths_pt = pt.tensor(source_input_lengths_mx.asnumpy()) target_inputs_pt = pt.tensor(target_inputs_mx.asnumpy()) target_input_lengths_pt = pt.tensor(target_input_lengths_mx.asnumpy()) step_inputs_mx = np.random.randint(0, target_vocab_size, (batch_size, num_target_factors)) vocab_slice_ids_mx = np.random.randint(0, target_vocab_size, (topk_size, )) step_inputs_pt = pt.tensor(step_inputs_mx.asnumpy()) vocab_slice_ids_pt = pt.tensor(vocab_slice_ids_mx.asnumpy()) b_mx = sockeye.model.SockeyeModel(model_config, inference_only=False, mc_dropout=False, forward_pass_cache_size=0) b_mx.initialize() b_pt = sockeye.model_pt.PyTorchSockeyeModel(model_config_pt, inference_only=False, mc_dropout=False, forward_pass_cache_size=0) assert b_mx.state_structure() == b_pt.state_structure() # test forward() # first run mx block to complete deferred initialization forward_dict_mx = b_mx(source_inputs_mx, source_input_lengths_mx, target_inputs_mx, target_input_lengths_mx) # get weights from mx into pt b_pt.weights_from_mxnet_block(b_mx) forward_dict_pt = b_pt(source_inputs_pt, source_input_lengths_pt, target_inputs_pt, target_input_lengths_pt) assert forward_dict_mx.keys() == forward_dict_pt.keys() logits_mx = forward_dict_mx[C.LOGITS_NAME].asnumpy() logits_pt = forward_dict_pt[C.LOGITS_NAME].detach().numpy() assert np.allclose(logits_mx, logits_pt, atol=1e-05) # test encode() source_encoded_mx, source_encoded_length_mx = b_mx.encode( source_inputs_mx, source_input_lengths_mx) source_encoded_pt, source_encoded_length_pt = b_pt.encode( source_inputs_pt, source_input_lengths_pt) assert np.allclose(source_encoded_mx.asnumpy(), source_encoded_pt.detach().numpy(), atol=1e-05) assert np.allclose(source_encoded_length_mx.asnumpy(), source_encoded_length_pt.detach().numpy(), atol=1e-05) # test encode_and_initialize() init_states_mx, pred_out_length_mx = b_mx.encode_and_initialize( source_inputs_mx, source_input_lengths_mx, constant_length_ratio=0.0) init_states_pt, pred_out_length_pt = b_pt.encode_and_initialize( source_inputs_pt, source_input_lengths_pt, constant_length_ratio=0.0) if config_length_task is None: assert np.allclose(pred_out_length_mx.asnumpy(), np.zeros_like(source_input_lengths_mx).asnumpy()) assert np.allclose( pred_out_length_pt.detach().numpy(), pt.zeros_like(source_input_lengths_pt).detach().numpy()) else: assert pred_out_length_mx.asnumpy() == pred_out_length_pt.detach( ).numpy() assert len(init_states_mx) == len(init_states_pt) state_structure = b_pt.decoder.state_structure() for s_mx, s_pt, structure in zip(init_states_mx, 
init_states_pt, state_structure): if structure != C.MASK_STATE: # MASK state is new in Pytorch and not equivalent assert np.allclose(s_mx.asnumpy(), s_pt.detach().numpy(), atol=1e-05) # test decode_step() b_pt.eval() states_mx = init_states_mx states_pt = init_states_pt step_output_mx, states_mx, factor_outputs_mx = b_mx.decode_step( step_inputs_mx, states_mx, vocab_slice_ids=vocab_slice_ids_mx) step_output_pt, states_pt, factor_outputs_pt = b_pt.decode_step( step_inputs_pt, states_pt, vocab_slice_ids=vocab_slice_ids_pt) assert np.allclose(step_output_mx.asnumpy(), step_output_pt.detach().numpy(), atol=1e-05) assert step_output_mx.asnumpy().shape == step_output_pt.detach().numpy( ).shape == (batch_size, topk_size) assert len(factor_outputs_mx) == len(factor_outputs_pt) # TODO assert factor outputs equality assert len(states_mx) == len(states_pt) for s_mx, s_pt, structure in zip(states_mx, states_pt, state_structure): if structure != C.MASK_STATE: # MASK state is new in Pytorch and not equivalent assert np.allclose(s_mx.asnumpy(), s_pt.detach().numpy(), atol=1e-05) from pprint import pprint pprint(b_mx.collect_params()) for param_tensor in b_pt.state_dict(): print(param_tensor, "\t", b_pt.state_dict()[param_tensor].size()) # save & load parameters with TemporaryDirectory() as work_dir: fname = os.path.join(work_dir, 'params.pt') b_pt.save_parameters(fname) b_pt.load_parameters(fname) forward_dict_pt = b_pt(source_inputs_pt, source_input_lengths_pt, target_inputs_pt, target_input_lengths_pt) assert forward_dict_mx.keys() == forward_dict_pt.keys() logits_mx = forward_dict_mx[C.LOGITS_NAME].asnumpy() logits_pt = forward_dict_pt[C.LOGITS_NAME].detach().numpy() assert np.allclose(logits_mx, logits_pt, atol=1e-05)
def forward(self, is_train, req, in_data, out_data, aux): x = in_data[0] y = np.maximum(x, np.zeros_like(x)) self.assign(out_data[0], req[0], y)
def dynamic_masking(self, input_ids, valid_lengths):
    # TODO(zheyuye), two additional flags `disallow_from_mask` and `already_masked`
    #  that control the masking status for each position in the sequence.
    """
    Generate masking positions on-the-fly instead of during preprocessing

    Parameters
    ----------
    input_ids
        The batchified input_ids with shape (batch_size, max_seq_length)
    valid_lengths
        The batchified valid_lengths with shape (batch_size, )

    Returns
    -------
    masked_input_ids
        The masked input sequence in which 15% of the tokens are masked with [MASK]
        shape (batch_size, max_seq_length)
    length_masks
        The masking matrix for the whole sequence that indicates which positions
        are greater than valid_length.
        shape (batch_size, max_seq_length)
    unmasked_tokens
        The original tokens that appear in the unmasked input sequence
        shape (batch_size, num_masked_positions)
    masked_positions
        The masking positions in mx.np.ndarray
        shape (batch_size, num_masked_positions)
    masked_lm_weights
        The weight matrix containing 0 or 1 to mark the actual effect of masked positions
        shape (batch_size, num_masked_positions)
    """
    N = self._max_num_masked_position
    # Only valid tokens that are not special tokens are allowed to be masked
    valid_candidates = np.ones_like(input_ids, dtype=np.bool)
    ignore_tokens = [self.vocab.cls_id, self.vocab.sep_id, self.vocab.pad_id]

    for ignore_token in ignore_tokens:
        # TODO(zheyuye), Update when the += operation is supported
        valid_candidates = valid_candidates * \
            np.not_equal(input_ids, ignore_token)
    valid_lengths = valid_lengths.astype(np.float32)
    valid_candidates = valid_candidates.astype(np.float32)
    num_masked_position = mxnp.maximum(
        1, np.minimum(N, round(valid_lengths * self._mask_prob)))

    # Get the masking probability of each position
    sample_probs = self._proposal_distribution * valid_candidates
    sample_probs /= mxnp.sum(sample_probs, axis=-1, keepdims=True)
    sample_probs = npx.stop_gradient(sample_probs)
    gumbels = mxnp.random.gumbel(np.zeros_like(sample_probs))
    # Following the instructions of the official repo, avoid duplicate positions
    # with top-k sampling as in https://github.com/google-research/electra/issues/41
    masked_positions = npx.topk(
        mxnp.log(sample_probs) + gumbels, k=N,
        axis=-1, ret_typ='indices', dtype=np.int32)

    masked_weights = npx.sequence_mask(
        mxnp.ones_like(masked_positions),
        sequence_length=num_masked_position,
        use_sequence_length=True, axis=1, value=0)
    masked_positions = masked_positions * masked_weights
    length_masks = npx.sequence_mask(
        mxnp.ones_like(input_ids, dtype=np.float32),
        sequence_length=valid_lengths,
        use_sequence_length=True, axis=1, value=0)
    unmasked_tokens = select_vectors_by_position(
        input_ids, masked_positions) * masked_weights
    masked_weights = masked_weights.astype(np.float32)
    replaced_positions = (
        mxnp.random.uniform(
            mxnp.zeros_like(masked_positions),
            mxnp.ones_like(masked_positions)) < self._replace_prob) * masked_positions
    # Deal with multiple zero values in replaced_positions, which would cause
    # [CLS] to be replaced
    filled = mxnp.where(
        replaced_positions, self.vocab.mask_id, self.vocab.cls_id).astype(np.int32)
    # Mask the token by replacing it with [MASK]
    masked_input_ids = update_vectors_by_position(input_ids, filled, replaced_positions)

    # Note: masked_positions is likely to contain multiple zero values if the number of
    # masked positions has not reached the maximum. However, this case hardly occurs since
    # valid_length is almost always equal to max_seq_length.
    masked_input = self.MaskedInput(input_ids=masked_input_ids,
                                    masks=length_masks,
                                    unmasked_tokens=unmasked_tokens,
                                    masked_positions=masked_positions,
                                    masked_weights=masked_weights)
    return masked_input
def forward(self, source: np.ndarray, source_length: np.ndarray, restrict_lexicon: Optional[lexicon.TopKLexicon], raw_constraint_list: List[Optional[constrained.RawConstraintList]], raw_avoid_list: List[Optional[constrained.RawConstraintList]], max_output_lengths: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[Optional[np.ndarray]], List[Optional[constrained.ConstrainedHypothesis]]]: """ Translates multiple sentences using beam search. :param source: Source ids. Shape: (batch_size, bucket_key, num_factors). :param source_length: Valid source lengths. Shape: (batch_size,). :param restrict_lexicon: Lexicon to use for vocabulary restriction. :param raw_constraint_list: A list of optional lists containing phrases (as lists of target word IDs) that must appear in each output. :param raw_avoid_list: A list of optional lists containing phrases (as lists of target word IDs) that must NOT appear in each output. :param max_output_lengths: ndarray of maximum output lengths per input in source. Shape: (batch_size,). Dtype: int32. :return List of best hypotheses indices, list of best word indices, array of accumulated length-normalized negative log-probs, hypotheses lengths, predicted lengths of references (if any), constraints (if any). """ batch_size = source.shape[0] logger.debug("beam_search batch size: %d", batch_size) # Maximum beam search iterations (determined by longest input with eos) max_iterations = max_output_lengths.max().item() logger.debug("max beam search iterations: %d", max_iterations) sample_best_hyp_indices = None if self._sample is not None: utils.check_condition(restrict_lexicon is None, "Sampling is not available when working with a restricted lexicon.") sample_best_hyp_indices = np.arange(0, batch_size * self.beam_size, dtype='int32', ctx=self.context) # General data structure: batch_size * beam_size blocks in total; # a full beam for each sentence, followed by the next beam-block for the next sentence and so on # best word_indices (also act as input: (batch*beam, num_target_factors best_word_indices = np.full((batch_size * self.beam_size, self.num_target_factors), fill_value=self.bos_id, ctx=self.context, dtype='int32') # offset for hypothesis indices in batch decoding offset = np.repeat(np.arange(0, batch_size * self.beam_size, self.beam_size, dtype='int32', ctx=self.context), self.beam_size) # locations of each batch item when first dimension is (batch * beam) batch_indices = np.arange(0, batch_size * self.beam_size, self.beam_size, dtype='int32', ctx=self.context) first_step_mask = np.full((batch_size * self.beam_size, 1), fill_value=np.inf, ctx=self.context, dtype=self.dtype) first_step_mask[batch_indices] = 0.0 # Best word and hypotheses indices across beam search steps from topk operation. best_hyp_indices_list = [] # type: List[np.ndarray] best_word_indices_list = [] # type: List[np.ndarray] lengths = np.zeros((batch_size * self.beam_size, 1), ctx=self.context, dtype='int32') finished = np.zeros((batch_size * self.beam_size, 1), ctx=self.context, dtype='int32') # Extending max_output_lengths to shape (batch_size * beam_size, 1) max_output_lengths = np.repeat(np.expand_dims(max_output_lengths, axis=1), self.beam_size, axis=0) # scores_accumulated: chosen smallest scores in scores (ascending). 
scores_accumulated = np.zeros((batch_size * self.beam_size, 1), ctx=self.context, dtype=self.dtype) output_vocab_size = self.output_vocab_size # If using a top-k lexicon, select param rows for logit computation that correspond to the # target vocab for this sentence. vocab_slice_ids = None # type: Optional[np.ndarrays] if restrict_lexicon: source_words = np.squeeze(np.split(source, self.num_source_factors, axis=2)[0], axis=2) vocab_slice_ids, output_vocab_size, raw_constraint_list = _get_vocab_slice_ids(restrict_lexicon, source_words, raw_constraint_list, self.eos_id, beam_size=1) pad_dist = np.full((batch_size * self.beam_size, output_vocab_size - 1), fill_value=np.inf, ctx=self.context, dtype=self.dtype) eos_dist = np.full((batch_size * self.beam_size, output_vocab_size), fill_value=np.inf, ctx=self.context, dtype=self.dtype) eos_dist[:, C.EOS_ID] = 0 unk_dist = None if self.prevent_unk: unk_dist = np.zeros_like(eos_dist) unk_dist[:, C.UNK_ID] = np.inf # pylint: disable=E1137 # Initialize the beam to track constraint sets, where target-side lexical constraints are present constraints = constrained.init_batch(raw_constraint_list, self.beam_size, self.bos_id, self.eos_id) if self.global_avoid_trie or any(raw_avoid_list): avoid_states = constrained.AvoidBatch(batch_size, self.beam_size, avoid_list=raw_avoid_list, global_avoid_trie=self.global_avoid_trie) avoid_states.consume(best_word_indices[:, 0]) # constraints operate only on primary target factor # (0) encode source sentence, returns a list model_states, estimated_reference_lengths = self._inference.encode_and_initialize(source, source_length) # repeat states to beam_size model_states = _repeat_states(model_states, self.beam_size, self._inference.state_structure()) # repeat estimated_reference_lengths to shape (batch_size * beam_size, 1) estimated_reference_lengths = np.repeat(estimated_reference_lengths, self.beam_size, axis=0) # Records items in the beam that are inactive. At the beginning (t==1), there is only one valid or active # item on the beam for each sentence inactive = np.zeros((batch_size * self.beam_size, 1), dtype='int32', ctx=self.context) t = 1 for t in range(1, max_iterations + 1): # max_iterations + 1 required to get correct results # (1) obtain next predictions and advance models' state # target_dists: (batch_size * beam_size, target_vocab_size) target_dists, model_states, target_factors = self._inference.decode_step(best_word_indices, model_states, vocab_slice_ids) # (2) Produces the accumulated cost of target words in each row. # There is special treatment for finished and inactive rows: inactive rows are inf everywhere; # finished rows are inf everywhere except column zero, which holds the accumulated model score scores, lengths = self._update_scores(target_dists, finished, inactive, scores_accumulated, lengths, max_output_lengths, unk_dist, pad_dist, eos_dist) # Mark entries that should be blocked as having a score of np.inf if self.global_avoid_trie or any(raw_avoid_list): block_indices = avoid_states.avoid() if len(block_indices) > 0: scores[block_indices] = np.inf if self._sample is not None: target_dists[block_indices] = np.inf # (3) Get beam_size winning hypotheses for each sentence block separately. Only look as # far as the active beam size for each sentence. 
if self._sample is not None: best_hyp_indices, best_word_indices, scores_accumulated = self._sample(scores, target_dists, finished, sample_best_hyp_indices) else: # On the first timestep, all hypotheses have identical histories, so force topk() to choose extensions # of the first row only by setting all other rows to inf if t == 1: scores += first_step_mask best_hyp_indices, best_word_indices, scores_accumulated = self._top(scores, offset) # Constraints for constrained decoding are processed sentence by sentence if any(raw_constraint_list): best_hyp_indices, best_word_indices, scores_accumulated, constraints, inactive = constrained.topk( t, batch_size, self.beam_size, inactive, scores, constraints, best_hyp_indices, best_word_indices, scores_accumulated) # Map from restricted to full vocab ids if needed if restrict_lexicon: best_word_indices = np.take(vocab_slice_ids, best_word_indices, axis=0) # (4) Normalize the scores of newly finished hypotheses. Note that after this until the # next call to topk(), hypotheses may not be in sorted order. _sort_inputs = [best_hyp_indices, best_word_indices, finished, scores_accumulated, lengths, estimated_reference_lengths] if target_factors is not None: _sort_inputs.append(target_factors) best_word_indices, finished, scores_accumulated, lengths, estimated_reference_lengths = \ self._sort_norm_and_update_finished(*_sort_inputs) # Collect best hypotheses, best word indices best_word_indices_list.append(best_word_indices) best_hyp_indices_list.append(best_hyp_indices) if self._should_stop(finished, batch_size): break # (5) update models' state with winning hypotheses (ascending) model_states = self._sort_states(best_hyp_indices, *model_states) logger.debug("Finished after %d out of %d steps.", t, max_iterations) # (9) Sort the hypotheses within each sentence (normalization for finished hyps may have unsorted them). scores_accumulated_shape = scores_accumulated.shape folded_accumulated_scores = scores_accumulated.reshape((batch_size, -1)) indices = np.argsort(folded_accumulated_scores.astype('float32', copy=False), axis=1).reshape((-1,)) best_hyp_indices = np.unravel_index(indices, scores_accumulated_shape)[0].astype('int32') + offset scores_accumulated = scores_accumulated.take(best_hyp_indices, axis=0) best_hyp_indices_list.append(best_hyp_indices) lengths = lengths.take(best_hyp_indices, axis=0) all_best_hyp_indices = np.stack(best_hyp_indices_list, axis=1) all_best_word_indices = np.stack(best_word_indices_list, axis=2) constraints = [constraints[x] for x in best_hyp_indices.tolist()] return all_best_hyp_indices, \ all_best_word_indices, \ scores_accumulated, \ lengths.astype('int32', copy=False), \ estimated_reference_lengths, \ constraints
############### 2.1.5. Saving Memory ###############
# import
from mxnet import np, npx
npx.set_np()

# save the memory id
x = np.arange(12).reshape(3, 4)
y = np.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
before = id(y)
y = y + x
id(y) == before

# we can assign the result of an operation to a previously allocated array with slice notation
z = np.zeros_like(y)  # same shape as y, but filled with zeros
print('id(z):', id(z))
z[:] = x + y
print('id(z):', id(z))

# if the value of x is not reused in subsequent computations, we can also use x[:] = x + y
# or x += y to reduce the memory overhead of the operation
before = id(x)
x += y
id(x) == before
def forward(self, inputs, token_types, valid_length, original_tokens, masked_positions): """Getting the mlm scores of each masked positions from a generator, then produces the corrupted tokens sampling from a gumbel distribution. We also get the ground-truth and scores of the replaced token detection which is output by a discriminator. The ground-truth is an array with same shape as the input using 1 stand for original token and 0 for replacement. Notice: There is a problem when the masked positions have duplicate indexs. Try to avoid that in the data preprocessing process. In addition, loss calculation should be done in the training scripts as well. Parameters ---------- F inputs The masked input - layout = 'NT' Shape (batch_size, seq_length) - layout = 'TN' Shape (seq_length, batch_size) token_types The token types. - layout = 'NT' Shape (batch_size, seq_length) - layout = 'TN' Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. valid_length The valid length of each sequence. Shape (batch_size,) original_tokens The original tokens that appear in the unmasked input sequence. Shape (batch_size, num_masked_positions). masked_positions : The masked position of the sequence. Shape (batch_size, num_masked_positions). Returns ------- mlm_scores The masked language model score. Shape (batch_size, num_masked_positions, vocab_size) rtd_scores The replaced-token-detection score. Predicts whether the tokens are replaced or not. - layout = 'NT' Shape (batch_size, seq_length) - layout = 'TN' Shape (seq_length, batch_size) replaced_inputs Shape (batch_size, num_masked_positions) labels - layout = 'NT' Shape (batch_size, seq_length) - layout = 'TN' Shape (seq_length, batch_size) """ if self._uniform_generator: # generate the corrupt tokens randomly with a mlm_scores vector whose value is all 0 zero_logits = np.zeros((1, 1, self.vocab_size), dtype=self._dtype) mlm_scores = np.expand_dims(np.zeros_like(masked_positions, dtype=self._dtype), axis=-1) mlm_scores = mlm_scores + zero_logits else: _, _, mlm_scores = self.generator(inputs, token_types, valid_length, masked_positions) corrupted_tokens, fake_data, labels = self.get_corrupted_tokens( inputs, original_tokens, masked_positions, mlm_scores) # The discriminator takes the same input as the generator and the token_ids are # replaced with fake data _, _, rtd_scores = self.discriminator(fake_data, token_types, valid_length) return mlm_scores, rtd_scores, corrupted_tokens, labels
def dropout(X, drop_prob): assert 0 <= drop_prob <= 1 if drop_prob == 1: return np.zeros_like(X) mask = np.random.uniform(0, 1, X.shape) > drop_prob return mask.astype(np.float32) * X / (1.0 - drop_prob)
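# Minimal check of the inverted-dropout scaling implemented above (assumes the
# MXNet numpy interface, matching the np.random.uniform call in the function):
#
#     X = np.ones((2, 8))
#     out = dropout(X, 0.5)
#     # surviving entries are scaled to 1 / (1 - 0.5) = 2.0 and the rest are 0,
#     # so the expected value of each entry remains 1.0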
def get_initial_embedding(self, inputs, token_types=None): """Get the initial token embeddings that considers the token type and positional embeddings Parameters ---------- F inputs - layout = 'NT' Shape (batch_size, seq_length) - layout = 'TN' Shape (seq_length, batch_size) token_types - layout = 'NT' Shape (batch_size, seq_length) - layout = 'TN' Shape (seq_length, batch_size) If None, it will be initialized as all zero Returns ------- embedding The initial embedding that will be fed into the encoder """ if self._layout == 'NT': batch_axis, time_axis = 0, 1 elif self._layout == 'TN': batch_axis, time_axis = 1, 0 else: raise NotImplementedError word_embedding = self.word_embed(inputs) if self.trigram_embed: if self._layout == 'NT': word_embedding = np.concatenate([ np.pad(word_embedding[:, 1:], ((0, 0), (0, 1), (0, 0))), word_embedding, np.pad(word_embedding[:, :-1], ((0, 0), (1, 0), (0, 0))) ], axis=-1) elif self._layout == 'TN': word_embedding = np.concatenate([ np.pad(word_embedding[1:, :], ((0, 1), (0, 0), (0, 0))), word_embedding, np.pad(word_embedding[:-1, :], ((1, 0), (0, 0), (0, 0))) ], axis=-1) else: raise NotImplementedError # Projecting the embedding into units only for word embedding if self.trigram_embed or self.embed_size != self.units: word_embedding = self.embed_factorized_proj(word_embedding) if token_types is None: token_types = np.zeros_like(inputs) type_embedding = self.token_type_embed(token_types) embedding = word_embedding + type_embedding if self.pos_embed_type is not None: positional_embedding =\ self.token_pos_embed(npx.arange_like(embedding, axis=time_axis)) positional_embedding = np.expand_dims(positional_embedding, axis=batch_axis) embedding = embedding + positional_embedding # Extra layer normalization plus dropout embedding = self.embed_layer_norm(embedding) embedding = self.embed_dropout(embedding) return embedding