def decode_step(self,
                step_input: np.ndarray,
                states: List[np.ndarray],
                vocab_slice_ids: Optional[np.ndarray] = None):
    outputs = []  # type: List[np.ndarray]
    new_states = []  # type: List[np.ndarray]
    factor_outputs = []  # type: List[List[np.ndarray]]
    state_index = 0
    for model, model_state_structure in zip(self._models, self.state_structure()):
        # Slice this model's share of the flat ensemble state list.
        model_states = states[state_index:state_index + len(model_state_structure)]
        state_index += len(model_state_structure)
        logits, model_states, target_factor_outputs = model.decode_step(step_input, model_states, vocab_slice_ids)
        probs = npx.softmax(logits, axis=-1, temperature=self._softmax_temperature)
        outputs.append(probs)
        target_factor_probs = [npx.softmax(tfo, axis=-1) for tfo in target_factor_outputs]
        factor_outputs.append(target_factor_probs)
        new_states += model_states
    # Combine per-model distributions into ensemble scores; lower is better,
    # hence the argmin over interpolated scores below.
    scores = self._interpolation(outputs)

    target_factors = None  # type: Optional[np.ndarray]
    if factor_outputs:
        # target factors are greedily 'decoded'.
        factor_predictions = [npx.cast(np.expand_dims(np.argmin(self._interpolation(fs), axis=-1), axis=1), dtype='int32')
                              for fs in zip(*factor_outputs)]
        if factor_predictions:
            target_factors = factor_predictions[0] if len(factor_predictions) == 1 \
                else np.concatenate(factor_predictions, axis=1)
    return scores, new_states, target_factors
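
# ---------------------------------------------------------------------------
# Illustration: why `np.argmin` (not argmax) selects the best factor above.
# A minimal standalone sketch of what an interpolation like
# `self._interpolation` could look like -- linear interpolation that averages
# per-model probability distributions and returns *negative log*
# probabilities, so the most likely token has the smallest score. The helper
# name `linear_interpolation` and the use of plain NumPy (standing in for
# mxnet.numpy) are assumptions for illustration; the actual implementation
# may differ.
import numpy

def linear_interpolation(predictions):
    # Average the per-model probability distributions (shape: batch x vocab),
    # then negate the log so that lower scores mean more likely tokens.
    return -numpy.log(numpy.mean(predictions, axis=0))

# Two toy "model" distributions over a 3-word vocabulary for one hypothesis.
model_a = numpy.array([[0.7, 0.2, 0.1]])
model_b = numpy.array([[0.6, 0.3, 0.1]])
toy_scores = linear_interpolation([model_a, model_b])
assert numpy.argmin(toy_scores, axis=-1) == 0  # token 0 has the highest avg. prob.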
def test_argmin():
    A = np.ones((INT_OVERFLOW, 2))
    A[10][1] = -1
    A.attach_grad()
    with mx.autograd.record():
        B = np.argmin(A)
    print(B)
    # With no axis given, argmin runs over the flattened array: the single -1
    # sits at flat index 10 * 2 + 1 = 21.
    assert B == 21
    B.backward()
    # argmin is not differentiable, so the gradient is all zeros.
    assert A.grad.shape == (INT_OVERFLOW, 2)
    assert A.grad[0][0] == 0
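
# ---------------------------------------------------------------------------
# Standalone sanity check of the flat-index arithmetic asserted above, using
# plain NumPy and a small array instead of an INT_OVERFLOW-sized one (the
# shape (8, 2) is an arbitrary stand-in chosen for illustration).
import numpy

small = numpy.ones((8, 2))
small[5][1] = -1  # minimum at row 5, column 1
# Without an axis argument, argmin searches the flattened array, so the index
# of the minimum is row * num_cols + col = 5 * 2 + 1 = 11.
assert numpy.argmin(small) == 5 * 2 + 1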