Exemplo n.º 1
0
    def decode_step(self,
                    step_input: np.ndarray,
                    states: List[np.ndarray],
                    vocab_slice_ids: Optional[np.ndarray] = None):
        """Advance every model of the ensemble by one decoding step.

        The flat ``states`` list holds the states of all models back to back.
        Each model consumes as many consecutive entries as its element of
        ``self.state_structure()`` has items.

        :param step_input: Input for this step, passed unchanged to every model.
        :param states: Concatenated decoder states of all models.
        :param vocab_slice_ids: Optional ids forwarded to each model's ``decode_step``.
        :return: Tuple ``(scores, new_states, target_factors)``:
            - ``scores`` combines (via ``self._interpolation``) the per-model
              softmax outputs, computed with temperature
              ``self._softmax_temperature``.
            - ``new_states`` is the concatenation of every model's updated
              states.
            - ``target_factors`` holds the greedily chosen target-factor ids,
              cast to int32. It is ``None`` when no factor predictions are
              produced.
        """
        probs_per_model = []  # type: List[np.ndarray]
        updated_states = []  # type: List[np.ndarray]
        factor_probs_per_model = []  # type: List[List[np.ndarray]]
        offset = 0
        for model, structure in zip(self._models, self.state_structure()):
            # Carve this model's slice out of the flat state list.
            width = len(structure)
            own_states = states[offset:offset + width]
            offset += width
            logits, own_states, factor_logits = model.decode_step(step_input, own_states, vocab_slice_ids)
            probs_per_model.append(npx.softmax(logits, axis=-1, temperature=self._softmax_temperature))
            # Factor outputs are normalized without the temperature.
            factor_probs_per_model.append([npx.softmax(f, axis=-1) for f in factor_logits])
            updated_states.extend(own_states)
        scores = self._interpolation(probs_per_model)

        target_factors = None  # type: Optional[np.ndarray]
        if not factor_probs_per_model:
            return scores, updated_states, target_factors

        # Target factors are decoded greedily. For each factor position,
        # combine the models' distributions and keep the lowest-scoring index.
        # NOTE(review): argmin presumably works because ``_interpolation``
        # yields negative log-probabilities (lower is better) — confirm.
        # zip(*...) stops at the model with the fewest factor outputs.
        predictions = []  # type: List[np.ndarray]
        for same_factor_probs in zip(*factor_probs_per_model):
            best = np.argmin(self._interpolation(same_factor_probs), axis=-1)
            # Add a trailing axis so factors can be stacked column-wise
            # (presumably giving (batch, num_factors) — confirm).
            predictions.append(npx.cast(np.expand_dims(best, axis=1), dtype='int32'))

        if len(predictions) == 1:
            target_factors = predictions[0]
        elif predictions:
            target_factors = np.concatenate(predictions, axis=1)
        return scores, updated_states, target_factors
Exemplo n.º 2
0
def test_argmin():
    """Check ``np.argmin`` and its backward pass on a very large array.

    A single -1 is planted in an array of ones with shape
    ``(INT_OVERFLOW, 2)``. ``INT_OVERFLOW`` is defined elsewhere and is
    presumably large enough to exceed 32-bit indexing — confirm.

    The flat argmin must point at the planted element. After ``backward``,
    the gradient must have ``A``'s shape, and a sampled entry must be zero.
    """
    row, col = 10, 1
    A = np.ones((INT_OVERFLOW, 2))
    A[row][col] = -1
    A.attach_grad()
    with mx.autograd.record():
        B = np.argmin(A)
    # Without an axis, argmin indexes the flattened (row-major) array, so
    # the expected position is derived from the planted coordinates instead
    # of being hard-coded.
    assert B == row * A.shape[1] + col
    B.backward()
    assert A.grad.shape == (INT_OVERFLOW, 2)
    assert A.grad[0][0] == 0