Example #1
0
    def forward(self, best_hyp_indices, best_word_indices,
                finished, scores_accumulated, lengths, reference_lengths,
                factors=None):
        """
        Reorder per-hypothesis beam state by the selected hypotheses, normalize the
        scores of hypotheses whose finished status changed at this step, and
        recompute the finished flags.

        :param best_hyp_indices: Indices of the selected hypotheses. Used to reorder
            ``finished``, ``lengths``, ``reference_lengths`` and ``factors`` along axis 0.
        :param best_word_indices: 1-D array of word ids chosen for each selected hypothesis.
        :param finished: Finished flags from the previous step, in the order before reordering.
        :param scores_accumulated: Accumulated scores. This method does NOT reorder them;
            presumably they are already aligned with ``best_hyp_indices`` (confirm with caller).
        :param lengths: Hypothesis lengths, in the order before reordering.
        :param reference_lengths: Reference lengths passed to the scorer, in the order
            before reordering.
        :param factors: Optional secondary target factor ids. They are reordered and
            appended as extra columns after the word ids.
        :return: Tuple ``(best_word_indices, finished, scores_accumulated, lengths,
            reference_lengths)``. ``best_word_indices`` gains a trailing axis, plus one
            column per secondary factor when ``factors`` is given. ``finished`` is int32
            with a trailing axis of size 1.
        """
        # Reorder fixed-size beam data according to best_hyp_indices (ascending)
        finished = np.take(finished, best_hyp_indices, axis=0)
        lengths = np.take(lengths, best_hyp_indices, axis=0)
        reference_lengths = np.take(reference_lengths, best_hyp_indices, axis=0)

        # Hypotheses are finished if they are extended with <pad> or <eos>.
        # Compute this once and reuse it for both normalization and the new flags,
        # so the two uses of the stopping criterion cannot drift apart.
        all_finished = np.expand_dims(np.logical_or(best_word_indices == self.pad_id,
                                                    best_word_indices == self.eos_id),
                                      axis=1)

        # Normalize hypotheses that JUST finished. logical_xor flags every hypothesis
        # whose finished status differs from the previous step. This presumably means
        # only newly finished ones, assuming finished hypotheses are always extended
        # with <pad> or <eos> upstream. TODO confirm.
        newly_finished = np.logical_xor(all_finished, finished)

        scores_accumulated = np.where(newly_finished,
                                      self._scorer(scores_accumulated,
                                                   npx.cast(lengths, self.dtype),
                                                   reference_lengths),
                                      scores_accumulated)

        # Finished flags for the next step, as int32 with a trailing axis of size 1.
        finished = npx.cast(all_finished, 'int32')

        # Concatenate sorted secondary target factors to best_word_indices. Shape: (batch*beam, num_factors)
        best_word_indices = np.expand_dims(best_word_indices, axis=1)

        if factors is not None:
            secondary_factors = np.take(factors, best_hyp_indices, axis=0)
            best_word_indices = np.concatenate((best_word_indices, secondary_factors), axis=1)

        return best_word_indices, finished, scores_accumulated, lengths, reference_lengths
Example #2
0
    def decode_step(self,
                    step_input: np.ndarray,
                    states: List[np.ndarray],
                    vocab_slice_ids: Optional[np.ndarray] = None):
        """
        Run one decoding step with every model of the ensemble and combine the results.

        The flat ``states`` list is split among the models in order. Each model
        consumes as many consecutive entries as its entry in ``self.state_structure()``
        has elements.

        :param step_input: Input for this step, forwarded unchanged to every model.
        :param states: Concatenated decoder states of all models, in model order.
        :param vocab_slice_ids: Optional vocabulary slice ids, forwarded to every model.
        :return: Tuple ``(scores, new_states, target_factors)``.
            ``scores`` is ``self._interpolation`` applied to the per-model
            temperature-scaled softmax outputs.
            ``new_states`` is the concatenation of every model's updated states.
            ``target_factors`` is None, or int32 predictions for the secondary target
            factors with one column per factor.
        """
        model_probs = []  # type: List[np.ndarray]
        updated_states = []  # type: List[np.ndarray]
        per_model_factor_probs = []  # type: List[List[np.ndarray]]
        offset = 0
        for model, structure in zip(self._models, self.state_structure()):
            num_states = len(structure)
            own_states = states[offset:offset + num_states]
            offset += num_states
            logits, own_states, factor_logits = model.decode_step(step_input, own_states, vocab_slice_ids)
            # The temperature applies to the main output only, not to the factor outputs.
            model_probs.append(npx.softmax(logits, axis=-1, temperature=self._softmax_temperature))
            per_model_factor_probs.append([npx.softmax(f, axis=-1) for f in factor_logits])
            updated_states.extend(own_states)
        scores = self._interpolation(model_probs)

        target_factors = None  # type: Optional[np.ndarray]
        if per_model_factor_probs:
            # Target factors are decoded greedily: each factor is interpolated across
            # models and one value is picked. argmin assumes self._interpolation returns
            # a cost (e.g. negative log-probability), so the minimum is the most likely
            # value. Confirm against the interpolation functions.
            predictions = []
            for factor_group in zip(*per_model_factor_probs):
                best = np.argmin(self._interpolation(factor_group), axis=-1)
                predictions.append(npx.cast(np.expand_dims(best, axis=1), dtype='int32'))
            # Models without factor outputs contribute empty lists, so zip may yield nothing.
            if predictions:
                if len(predictions) == 1:
                    target_factors = predictions[0]
                else:
                    target_factors = np.concatenate(predictions, axis=1)
        return scores, updated_states, target_factors
def test_cast():
    """Check npx.cast to float16, forward and backward, on an array of shape (INT_OVERFLOW, 2)."""
    # INT_OVERFLOW is defined elsewhere in this module. It is presumably a size
    # beyond the 32-bit index range; confirm.
    expected_shape = (INT_OVERFLOW, 2)
    src = np.ones(expected_shape)
    src.attach_grad()
    with mx.autograd.record():
        out = npx.cast(src, dtype='float16')
    assert out.shape == expected_shape
    assert out[0][0] == 1
    # The gradient of a cast passes the ones-valued head gradient straight through.
    out.backward()
    assert src.grad.shape == expected_shape
    assert src.grad[0][0] == 1
Example #4
0
    def decode_step(self,
                    step_input: np.ndarray,
                    states: List,
                    vocab_slice_ids: Optional[np.ndarray] = None):
        """
        Run one decoding step with the single wrapped model.

        :param step_input: Input for this step, forwarded to the model.
        :param states: Decoder states of the model.
        :param vocab_slice_ids: Optional vocabulary slice ids, forwarded to the model.
        :return: Tuple ``(scores, states, target_factors)``.
            ``scores`` is the negated temperature-scaled log-softmax of the logits, or
            the negated raw logits when ``self._skip_softmax`` is set.
            ``states`` are the model's updated states.
            ``target_factors`` is None, or the int32 argmax of each secondary factor
            output with one column per factor.
        """
        logits, states, target_factor_outputs = self._model.decode_step(step_input, states, vocab_slice_ids)
        if self._skip_softmax:
            scores = -logits
        else:
            scores = -npx.log_softmax(logits, axis=-1, temperature=self._softmax_temperature)

        if not target_factor_outputs:
            return scores, states, None

        # Target factors are decoded greedily. argmax over axis 1 assumes each factor
        # output is 2-D (rows x factor vocabulary). TODO confirm.
        predictions = [npx.cast(np.expand_dims(np.argmax(factor_logits, axis=1), axis=1), dtype='int32')
                       for factor_logits in target_factor_outputs]
        if len(predictions) == 1:
            target_factors = predictions[0]
        else:
            target_factors = np.concatenate(predictions, axis=1)
        return scores, states, target_factors