def update(self, labels: mx.nd.NDArray, embeddings0: mx.nd.NDArray, embeddings1: mx.nd.NDArray):
    """
    Accumulate per-pair distances and labels for one batch.

    ``self._dist_type == 0`` selects Euclidean distance; any other value selects
    cosine distance ``1 - cos(e0, e1)``. A pair in which either embedding has zero
    norm has no defined direction, so its cosine similarity is taken as 0
    (distance 1) instead of producing NaN from 0/0.

    :param labels: NDArray (or numpy array) of per-pair labels, appended to ``self._issame``.
    :param embeddings0: NDArray (or numpy array) of first embeddings, one row per pair.
    :param embeddings1: NDArray (or numpy array) of second embeddings, one row per pair.
    :return: None. Distances are appended to ``self._dists``.
    """
    # numpy inputs are used directly; anything else is converted via .asnumpy().
    if not isinstance(embeddings0, np.ndarray):
        embeddings0 = embeddings0.asnumpy()
    if not isinstance(embeddings1, np.ndarray):
        embeddings1 = embeddings1.asnumpy()
    if not isinstance(labels, np.ndarray):
        labels = labels.asnumpy()

    if self._dist_type == 0:
        diff = np.subtract(embeddings0, embeddings1)
        dists = np.sqrt(np.sum(np.square(diff), 1))
    else:
        dots = np.sum(np.multiply(embeddings0, embeddings1), axis=1)
        norms = np.linalg.norm(embeddings0, axis=1) * np.linalg.norm(embeddings1, axis=1)
        # Only divide where the denominator is non-zero; zero-norm rows keep
        # similarity 0 from the pre-filled output (no NaN, no warning).
        cos_sim = np.divide(dots, norms, out=np.zeros_like(norms), where=norms != 0)
        dists = 1 - cos_sim

    self._dists += list(dists)
    self._issame += list(labels)
def decode_step(self, step_input: mx.nd.NDArray, states: List, vocab_slice_ids: Optional[mx.nd.NDArray] = None):
    """
    Scripted decoder step: emits fixed scores based on the previous word.

    :param step_input: 2-D NDArray of previous word ids (the shape is unpacked into
        two values). Only ``step_input[:, 0]`` is inspected, via ``asscalar()``.
        NOTE(review): ``asscalar()`` needs a single element, so this presumably only
        supports batch_beam_size == 1 — confirm.
    :param states: ``[internal_lengths, num_decode_step_calls]``.
    :param vocab_slice_ids: Unused.
    :return: ``(scores, new_states, None)``. ``new_states`` is also stored on ``self.states``.
    """
    # Unpacking enforces a 2-D step_input; the sizes themselves are not used.
    _batch_beam_size, _num_target_factors = step_input.shape

    internal_lengths, num_decode_step_calls = states
    num_decode_step_calls = num_decode_step_calls.asscalar()
    if num_decode_step_calls == 0:
        # First call to decode_step: the step input must be all <bos>.
        assert (step_input.asnumpy() == C.BOS_ID).all()

    prev_word = step_input[:, 0].asscalar()
    if prev_word == C.BOS_ID:
        # Predict word id 4 given <bos>.
        scores = mx.nd.array([0, 0, 0, 0, 1])
    elif prev_word == C.EOS_ID:
        # Predict pad (index 0) given <eos>.
        scores = mx.nd.array([1, 0, 0, 0, 0])
    else:
        # NOTE(review): an earlier comment here said "otherwise always predict pad",
        # but these scores select id 4 (same as the <bos> branch). Behavior kept
        # unchanged; confirm which is intended.
        scores = mx.nd.array([0, 0, 0, 0, 1])

    # The search picks the minimum score (topk is minimizing), so negate to make
    # the entry holding 1 the best candidate.
    scores *= -1

    internal_lengths += 1
    num_decode_step_calls += 1

    self.states = states = [
        internal_lengths,
        mx.nd.array([num_decode_step_calls], dtype='int32'),
    ]
    return scores, states, None
def reorder(self, indices: mx.nd.NDArray) -> None:
    """
    Reorders the avoid list according to the selected row indices.
    This can produce duplicates, but this is fixed if state changes occur in consume().

    :param indices: An mx.nd.NDArray containing indices of hypotheses to select.
    """
    # Nothing to reorder: skip the asnumpy() conversion entirely.
    if not (self.global_avoid_states or self.local_avoid_states):
        return
    # Convert the indices once and reuse them for both lists.
    rows = indices.asnumpy()
    if self.global_avoid_states:
        self.global_avoid_states = [self.global_avoid_states[x] for x in rows]
    if self.local_avoid_states:
        self.local_avoid_states = [self.local_avoid_states[x] for x in rows]
def encode_mx_ndarray(v: mx.nd.NDArray) -> Any:
    """
    Serialize an MXNet NDArray as a ``kind_inst`` record for ``mxnet.nd.array``.

    The values (as nested Python lists from ``asnumpy().tolist()``) become the
    single positional argument, and the dtype becomes the ``dtype`` keyword
    argument; both are passed through ``encode`` first.
    """
    host_values = v.asnumpy().tolist()
    encoded_args = encode([host_values])
    encoded_dtype = encode(v.dtype)
    return {
        "__kind__": kind_inst,
        "class": "mxnet.nd.array",
        "args": encoded_args,
        "kwargs": {"dtype": encoded_dtype},
    }
def consume(self, word_ids: mx.nd.NDArray) -> None:
    """
    Feed one word to every hypothesis' avoid-trie state.

    Word ``i`` is passed to ``consume()`` of entry ``i`` in the global and then the
    local avoid-state list (each only if that list is non-empty), and the entry is
    replaced in place by the returned state.

    :param word_ids: The set of word IDs.
    """
    for position, word in enumerate(word_ids.asnumpy().tolist()):
        for avoid_states in (self.global_avoid_states, self.local_avoid_states):
            if avoid_states:
                avoid_states[position] = avoid_states[position].consume(word)
def update(self, labels: mx.nd.NDArray, embeddings0: mx.nd.NDArray, embeddings1: mx.nd.NDArray):
    """
    Accumulate per-pair Euclidean distances and labels for one batch.

    :param labels: NDArray (or numpy array) of per-pair labels, appended to ``self._issame``.
    :param embeddings0: NDArray (or numpy array) of first embeddings, one row per pair.
    :param embeddings1: NDArray (or numpy array) of second embeddings, one row per pair.
    :return: None. Distances are appended to ``self._dists``.
    """
    # numpy inputs are used directly; anything else is converted via .asnumpy().
    if not isinstance(embeddings0, np.ndarray):
        embeddings0 = embeddings0.asnumpy()
    if not isinstance(embeddings1, np.ndarray):
        embeddings1 = embeddings1.asnumpy()
    if not isinstance(labels, np.ndarray):
        labels = labels.asnumpy()

    # Row-wise L2 distance between the two embedding sets.
    diff = np.subtract(embeddings0, embeddings1)
    dists = np.sqrt(np.sum(np.square(diff), 1))

    self._dists += list(dists)
    self._issame += list(labels)