def forward(self, best_hyp_indices, best_word_indices,
            finished, scores_accumulated, lengths, reference_lengths,
            factors=None):

    # Reorder fixed-size beam data according to best_hyp_indices (ascending)
    finished = np.take(finished, best_hyp_indices, axis=0)
    lengths = np.take(lengths, best_hyp_indices, axis=0)
    reference_lengths = np.take(reference_lengths, best_hyp_indices, axis=0)

    # Normalize hypotheses that JUST finished
    all_finished = np.expand_dims(np.logical_or(best_word_indices == self.pad_id,
                                                best_word_indices == self.eos_id),
                                  axis=1)
    newly_finished = np.logical_xor(all_finished, finished)

    scores_accumulated = np.where(newly_finished,
                                  self._scorer(scores_accumulated,
                                               npx.cast(lengths, self.dtype),
                                               reference_lengths),
                                  scores_accumulated)

    # Recompute finished. Hypotheses are finished if they are extended with <pad> or <eos>
    finished = np.logical_or(best_word_indices == self.pad_id,
                             best_word_indices == self.eos_id)
    finished = npx.cast(np.expand_dims(finished, axis=1), 'int32')

    # Concatenate sorted secondary target factors to best_word_indices. Shape: (batch*beam, num_factors)
    best_word_indices = np.expand_dims(best_word_indices, axis=1)

    if factors is not None:
        secondary_factors = np.take(factors, best_hyp_indices, axis=0)
        best_word_indices = np.concatenate((best_word_indices, secondary_factors), axis=1)

    return best_word_indices, finished, scores_accumulated, lengths, reference_lengths
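# Minimal sketch (not part of the original class): how logical_xor isolates the
# hypotheses that finished at exactly this step. 'finished' holds the reordered
# state from the previous step; 'all_finished' marks rows whose newly chosen word
# is <eos> or <pad>. Shapes and values below are illustrative only.
from mxnet import np

finished = np.array([[1], [0], [0]])       # already finished before this step
all_finished = np.array([[1], [1], [0]])   # finished after adding best_word_indices
newly_finished = np.logical_xor(all_finished, finished)
print(newly_finished)  # only row 1 is newly finished, so only its score is length-normalized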
def decode_step(self,
                step_input: np.ndarray,
                states: List[np.ndarray],
                vocab_slice_ids: Optional[np.ndarray] = None):
    outputs = []  # type: List[np.ndarray]
    new_states = []  # type: List[np.ndarray]
    factor_outputs = []  # type: List[List[np.ndarray]]
    state_index = 0
    for model, model_state_structure in zip(self._models, self.state_structure()):
        model_states = states[state_index:state_index + len(model_state_structure)]
        state_index += len(model_state_structure)
        logits, model_states, target_factor_outputs = model.decode_step(step_input, model_states, vocab_slice_ids)
        probs = npx.softmax(logits, axis=-1, temperature=self._softmax_temperature)
        outputs.append(probs)
        target_factor_probs = [npx.softmax(tfo, axis=-1) for tfo in target_factor_outputs]
        factor_outputs.append(target_factor_probs)
        new_states += model_states
    scores = self._interpolation(outputs)

    target_factors = None  # type: Optional[np.ndarray]
    if factor_outputs:
        # target factors are greedily 'decoded'.
        factor_predictions = [npx.cast(np.expand_dims(np.argmin(self._interpolation(fs), axis=-1), axis=1), dtype='int32')
                              for fs in zip(*factor_outputs)]
        if factor_predictions:
            target_factors = factor_predictions[0] if len(factor_predictions) == 1 \
                else np.concatenate(factor_predictions, axis=1)
    return scores, new_states, target_factors
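# Hypothetical sketch of the interpolation callable used above. self._interpolation
# is selected at construction time and its implementation is not shown here, so
# this is only an assumed linear variant: average the per-model probability
# distributions and return negative log probabilities. Lower scores are then
# better, which is consistent with np.argmin being applied to the interpolated
# factor scores in decode_step.
from mxnet import np

def linear_interpolation(predictions):
    # predictions: list of (batch*beam, vocab_size) probability arrays, one per model
    return -np.log(sum(predictions) / len(predictions))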
def test_cast():
    A = np.ones((INT_OVERFLOW, 2))
    A.attach_grad()
    with mx.autograd.record():
        B = npx.cast(A, dtype='float16')
    assert B.shape == (INT_OVERFLOW, 2)
    assert B[0][0] == 1
    B.backward()
    assert A.grad.shape == (INT_OVERFLOW, 2)
    assert A.grad[0][0] == 1
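# Small-scale sanity sketch (assumed setup, not the original test): the test above
# needs an INT_OVERFLOW-sized tensor, which is expensive to allocate; the same
# behaviour can be checked on a tiny array. npx.cast acts as an identity for
# gradient purposes, so the gradient flowing back to the float32 input is all ones.
import mxnet as mx
from mxnet import np, npx

x = np.ones((2, 2))
x.attach_grad()
with mx.autograd.record():
    y = npx.cast(x, dtype='float16')
y.backward()
print(x.grad)  # all ones, matching the large-tensor assertions above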
def decode_step(self,
                step_input: np.ndarray,
                states: List,
                vocab_slice_ids: Optional[np.ndarray] = None):
    logits, states, target_factor_outputs = self._model.decode_step(step_input, states, vocab_slice_ids)
    if not self._skip_softmax:
        logits = npx.log_softmax(logits, axis=-1, temperature=self._softmax_temperature)
    scores = -logits

    target_factors = None  # type: Optional[np.ndarray]
    if target_factor_outputs:
        # target factors are greedily 'decoded'.
        factor_predictions = [npx.cast(np.expand_dims(np.argmax(tfo, axis=1), axis=1), dtype='int32')
                              for tfo in target_factor_outputs]
        target_factors = factor_predictions[0] if len(factor_predictions) == 1 \
            else np.concatenate(factor_predictions, axis=1)
    return scores, states, target_factors
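# Minimal sketch (illustrative shapes only): the greedy factor 'decoding' used
# above. Each tfo is assumed to have shape (batch*beam, factor_vocab_size); the
# per-row argmax is reshaped to (batch*beam, 1) and cast to int32 so it can later
# be concatenated with the primary best_word_indices.
from mxnet import np, npx

tfo = np.array([[0.1, 0.7, 0.2],
                [0.6, 0.3, 0.1]])
factor_prediction = npx.cast(np.expand_dims(np.argmax(tfo, axis=1), axis=1), dtype='int32')
print(factor_prediction.shape)  # (2, 1): the winning factor indices 1 and 0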