Example no. 1
    def update(self, labels: mx.nd.NDArray, embeddings0: mx.nd.NDArray,
               embeddings1: mx.nd.NDArray):
        """

        :param labels: NDArray.
        :param embeddings0: NDArray.
        :param embeddings1: NDArray.
        :return:
        """

        embeddings0 = embeddings0.asnumpy() if not isinstance(
            embeddings0, np.ndarray) else embeddings0
        embeddings1 = embeddings1.asnumpy() if not isinstance(
            embeddings1, np.ndarray) else embeddings1
        labels = labels.asnumpy() if not isinstance(labels,
                                                    np.ndarray) else labels

        if self._dist_type == 0:
            diff = np.subtract(embeddings0, embeddings1)
            dists = np.sqrt(np.sum(np.square(diff), 1))
        else:
            dists = 1 - np.sum(np.multiply(embeddings0, embeddings1), axis=1) / \
                    (np.linalg.norm(embeddings0, axis=1) * np.linalg.norm(embeddings1, axis=1))

        self._dists += [d for d in dists]
        self._issame += [l for l in labels]
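Below is a minimal NumPy sketch of the two distance modes used above (Euclidean when _dist_type == 0, cosine distance otherwise) on a pair of toy embeddings; the values in the comments follow from the formulas, not from the original project:

import numpy as np

e0 = np.array([[1.0, 0.0], [0.0, 1.0]])
e1 = np.array([[1.0, 0.0], [1.0, 0.0]])

# Euclidean distance per pair (_dist_type == 0)
euclidean = np.sqrt(np.sum(np.square(e0 - e1), axis=1))   # [0.0, 1.414...]

# Cosine distance per pair (any other _dist_type)
cosine = 1 - np.sum(e0 * e1, axis=1) / (
    np.linalg.norm(e0, axis=1) * np.linalg.norm(e1, axis=1))   # [0.0, 1.0]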
Example no. 2
def prune(scores: mx.nd.NDArray, finished: mx.nd.NDArray,
          inf_array: mx.nd.NDArray, beam_size: int,
          prune_threshold: float) -> mx.nd.NDArray:
    """
    Returns a 0-1 array indicating which hypotheses are inactive based on pruning.
    Hypotheses whose score is more than prune_threshold worse than the best finished hypothesis
    in their sentence are marked as inactive.

    :param scores: Hypotheses scores. Shape: (batch * beam, 1).
    :param finished: 0-1 array indicating which hypotheses are finished. Shape: (batch * beam,).
    :param inf_array: Auxiliary array filled with infinity. Shape: (batch * beam,).
    :param beam_size: Beam size.
    :param prune_threshold: Pruning threshold.
    :return: NDArray of inactive items. Shape: (batch * beam,).
    """
    scores = scores.reshape((-1, beam_size))
    finished = finished.reshape((-1, beam_size))
    inf_array = inf_array.reshape((-1, beam_size))

    # best finished scores. Shape: (batch, 1)
    best_finished_scores = mx.nd.where(finished, scores,
                                       inf_array).min(axis=1, keepdims=True)
    inactive = mx.nd.cast((scores - best_finished_scores) > prune_threshold,
                          dtype='int32').reshape((-1))
    return inactive
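A hedged usage sketch for prune() with a single sentence and a beam of three; the toy values and the expected result are worked out from the code above:

import mxnet as mx
import numpy as np

beam_size = 3
scores = mx.nd.array([[1.0], [2.5], [9.0]])         # (batch * beam, 1)
finished = mx.nd.array([1.0, 0.0, 1.0])             # hypotheses 0 and 2 are finished
inf_array = mx.nd.full((beam_size,), val=np.inf)

inactive = prune(scores, finished, inf_array, beam_size, prune_threshold=5.0)
# best finished score is 1.0; only hypothesis 2 is more than 5.0 worse -> [0, 0, 1]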
Example no. 3
def bbox_overlaps(anchors: mx.nd.NDArray, gt: mx.nd.NDArray):
    """
    Get IoU of the anchors and ground truth bounding boxes.
    The shapes of anchors and gt should be (N, 4) and (M, 4),
    so the shape of the return value is (N, M).
    """
    N, M = anchors.shape[0], gt.shape[0]
    anchors_mat = anchors.reshape((N, 1, 4)).broadcast_to((N, M, 4)).reshape(
        (-1, 4))
    gt_mat = gt.reshape((1, M, 4)).broadcast_to((N, M, 4)).reshape((-1, 4))
    # inter
    x0 = nd.max(nd.stack(anchors_mat[:, 0], gt_mat[:, 0]), axis=0)
    y0 = nd.max(nd.stack(anchors_mat[:, 1], gt_mat[:, 1]), axis=0)
    x1 = nd.min(nd.stack(anchors_mat[:, 2], gt_mat[:, 2]), axis=0)
    y1 = nd.min(nd.stack(anchors_mat[:, 3], gt_mat[:, 3]), axis=0)

    inter = _get_area(
        nd.concatenate([
            x0.reshape((-1, 1)),
            y0.reshape((-1, 1)),
            x1.reshape((-1, 1)),
            y1.reshape((-1, 1))
        ],
                       axis=1))
    outer = _get_area(anchors_mat) + _get_area(gt_mat) - inter
    iou = inter / outer
    iou = iou.reshape((N, M))
    return iou
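_get_area is a helper from the surrounding project and is not shown here; a plausible stand-in (an assumption, not the original implementation) together with a toy call could look like this:

from mxnet import nd

def _get_area(boxes: nd.NDArray) -> nd.NDArray:
    # Area of (x0, y0, x1, y1) boxes, clamped at zero so empty intersections contribute nothing.
    width = nd.maximum(boxes[:, 2] - boxes[:, 0], 0)
    height = nd.maximum(boxes[:, 3] - boxes[:, 1], 0)
    return width * height

anchors = nd.array([[0, 0, 2, 2], [0, 0, 1, 1]])
gt = nd.array([[1, 1, 3, 3]])
iou = bbox_overlaps(anchors, gt)    # IoU matrix of shape (2, 1)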
Example no. 4
    def decode_step(self,
                    step_input: mx.nd.NDArray,
                    states: List,
                    vocab_slice_ids: Optional[mx.nd.NDArray] = None):
        batch_beam_size, num_target_factors = step_input.shape
        print('step_input', step_input.asnumpy())

        internal_lengths, num_decode_step_calls = states
        num_decode_step_calls = num_decode_step_calls.asscalar()
        if num_decode_step_calls == 0:  # first call to decode_step, we expect step input to be all <bos>
            assert (step_input.asnumpy() == C.BOS_ID).all()

        if step_input[:, 0].asscalar() == C.BOS_ID:
            # predict word id 4 given <bos>
            scores = mx.nd.array([0, 0, 0, 0, 1])
        elif step_input[:, 0].asscalar() == C.EOS_ID:
            # predict pad given <eos>
            scores = mx.nd.array([1, 0, 0, 0, 0])
        else:
            # otherwise always predict pad
            scores = mx.nd.array([0, 0, 0, 0, 1])

        # topk is minimizing
        scores *= -1
        #outputs = mx.nd.array([self.predictor.get(inp, C.PAD_ID) for inp in step_input.asnumpy().tolist()], ctx=step_input.context)
        #scores = mx.nd.one_hot(outputs, depth=self.output_vocab_size)

        internal_lengths += 1
        num_decode_step_calls += 1

        self.states = states = [
            internal_lengths,
            mx.nd.array([num_decode_step_calls], dtype='int32')
        ]
        return scores, states, None
Example no. 5
    def reorder(self, indices: mx.nd.NDArray) -> None:
        """
        Reorders the avoid list according to the selected row indices.
        This can produce duplicates, but this is fixed if state changes occur in consume().

        :param indices: An mx.nd.NDArray containing indices of hypotheses to select.
        """
        if self.global_avoid_states:
            self.global_avoid_states = [self.global_avoid_states[x] for x in indices.asnumpy()]

        if self.local_avoid_states:
            self.local_avoid_states = [self.local_avoid_states[x] for x in indices.asnumpy()]
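A small illustration of the gather that reorder() performs, using plain strings in place of the trie-based state objects of the original code:

import mxnet as mx

global_avoid_states = ['state_a', 'state_b', 'state_c']
indices = mx.nd.array([2, 0, 0], dtype='int32')

reordered = [global_avoid_states[int(x)] for x in indices.asnumpy()]
# ['state_c', 'state_a', 'state_a'] -- duplicates are possible, as the docstring notes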
Example no. 6
    def reorder(self, indices: mx.nd.NDArray) -> None:
        """
        Reorders the avoid list according to the selected row indices.
        This can produce duplicates, but this is fixed if state changes occur in consume().

        :param indices: An mx.nd.NDArray containing indices of hypotheses to select.
        """
        if self.global_avoid_states:
            self.global_avoid_states = [self.global_avoid_states[x] for x in indices.asnumpy()]

        if self.local_avoid_states:
            self.local_avoid_states = [self.local_avoid_states[x] for x in indices.asnumpy()]
Example no. 7
def topk(
    scores: mx.nd.NDArray, k: int, offset: mx.nd.NDArray
) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]:
    """
    Get the lowest k elements per sentence from a `scores` matrix.

    :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
    :param k: The number of smallest scores to return.
    :param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
    :return: The row indices, column indices and values of the k smallest items in matrix.
    """
    # (batch_size, beam_size * target_vocab_size)
    folded_scores = scores.reshape((-1, k * scores.shape[-1]))
    batch_size = folded_scores.shape[0]

    # pylint: disable=unbalanced-tuple-unpacking
    values, indices = mx.nd.topk(folded_scores,
                                 axis=1,
                                 k=k,
                                 ret_typ='both',
                                 is_ascend=True)
    indices = mx.nd.cast(indices, 'int32').reshape((-1, ))
    best_hyp_indices, best_word_indices = mx.nd.unravel_index(
        indices, scores.shape)

    if batch_size > 1:
        # Offsetting the indices to match the shape of the scores matrix
        best_hyp_indices += offset

    values = values.reshape((-1, 1))
    return best_hyp_indices, best_word_indices, values
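A sketch of calling this topk() on a toy scores matrix (two sentences, beam of two, vocabulary of three); the shapes are the only thing being illustrated here:

import mxnet as mx

batch_size, beam_size, vocab_size = 2, 2, 3
scores = mx.nd.array([[0.1, 0.5, 0.9],
                      [0.2, 0.8, 0.7],
                      [0.4, 0.3, 0.6],
                      [0.9, 0.2, 0.1]])     # (batch * beam, vocab)
offset = mx.nd.repeat(
    mx.nd.arange(0, batch_size * beam_size, beam_size, dtype='int32'),
    beam_size)                              # [0, 0, 2, 2]

best_hyp, best_word, values = topk(scores, k=beam_size, offset=offset)
# best_hyp: the row (hypothesis) each winner came from, shifted by offset per sentence
# best_word: the winning column (target word id); values: k smallest scores, shape (batch * k, 1)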
Example no. 8
def topk(scores: mx.nd.NDArray,
         offset: mx.nd.NDArray,
         k: int) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]:
    """
    Get the lowest k elements per sentence from a `scores` matrix.
    At the first timestep, the shape of scores is (batch, target_vocabulary_size).
    At subsequent steps, the shape is (batch * k, target_vocabulary_size).

    :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
    :param offset: Array (shape: batch_size * k) containing offsets to add to the hypothesis indices in batch decoding.
    :param k: The number of smallest scores to return.
    :return: The row indices, column indices and values of the k smallest items in matrix.
    """

    # Compute the batch size from the offsets and k. We don't know the batch size because it is
    # either 1 (at timestep 1) or k (at timesteps 2+).
    # (batch_size, beam_size * target_vocab_size)
    batch_size = int(offset.shape[-1] / k)
    folded_scores = scores.reshape((batch_size, -1))

    # pylint: disable=unbalanced-tuple-unpacking
    values, indices = mx.nd.topk(folded_scores, axis=1, k=k, ret_typ='both', is_ascend=True)
    indices = mx.nd.cast(indices, 'int32').reshape((-1,))
    best_hyp_indices, best_word_indices = mx.nd.unravel_index(indices, shape=(batch_size * k, scores.shape[-1]))

    if batch_size > 1:
        # Offsetting the indices to match the shape of the scores matrix
        best_hyp_indices += offset

    values = values.reshape((-1, 1))
    return best_hyp_indices, best_word_indices, values
Example no. 9
def smallest_k_mx(
    matrix: mx.nd.NDArray,
    k: int,
    only_first_row: bool = False
) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """
    Find the smallest elements in an NDArray.

    :param matrix: Any matrix.
    :param k: The number of smallest elements to return.
    :param only_first_row: If True the search is constrained to the first row of the matrix.
    :return: The row indices, column indices and values of the k smallest items in matrix.
    """

    if only_first_row:
        folded_matrix = mx.nd.reshape(matrix[0], shape=(-1))
    else:
        folded_matrix = matrix.reshape(-1)
    # pylint: disable=unbalanced-tuple-unpacking
    values, indices = mx.nd.topk(folded_matrix,
                                 axis=None,
                                 k=k,
                                 ret_typ='both',
                                 is_ascend=True,
                                 dtype="int32")

    indices = np.unravel_index(
        indices.reshape(-1).astype(np.int64).asnumpy(), matrix.shape)
    logger.info(indices)
    return indices, values
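A toy call, just to make the return structure concrete (the logger.info call above will also print the unravelled indices):

import mxnet as mx

matrix = mx.nd.array([[3.0, 1.0, 2.0],
                      [0.5, 4.0, 0.1]])
(rows, cols), values = smallest_k_mx(matrix, k=2)
# rows/cols are NumPy index arrays for the two smallest entries: 0.1 at (1, 2) and 0.5 at (1, 0)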
Example no. 10
def isfinite(data: mx.nd.NDArray) -> mx.nd.NDArray:
    """Performs an element-wise check to determine if the NDArray contains an infinite element or not.
       TODO: remove this funciton after upgrade to MXNet 1.4.* in favor of mx.ndarray.contrib.isfinite()
    """
    is_data_not_nan = data == data
    is_data_not_infinite = data.abs() != np.inf
    return mx.nd.logical_and(is_data_not_infinite, is_data_not_nan)
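A quick check of the helper on a toy array; NaN fails the data == data test and infinities fail the abs() != inf test:

import mxnet as mx
import numpy as np

data = mx.nd.array([1.0, np.inf, -np.inf, np.nan, 0.0])
print(isfinite(data).asnumpy())   # [1. 0. 0. 0. 1.]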
Example no. 11
def encode_mx_ndarray(v: mx.nd.NDArray) -> Any:
    return {
        "__kind__": kind_inst,
        "class": "mxnet.nd.array",
        "args": encode([v.asnumpy().tolist()]),
        "kwargs": {"dtype": encode(v.dtype)},
    }
Example no. 12
def split(data: mx.nd.NDArray,
          num_outputs: int,
          axis: int = 1,
          squeeze_axis: bool = False) -> List[mx.nd.NDArray]:
    """
    Version of mxnet.ndarray.split that always returns a list.  The original
    implementation only returns a list if num_outputs > 1:
    https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.split

    Splits an array along a particular axis into multiple sub-arrays.

    :param data: The input.
    :param num_outputs: Number of splits. Note that this should evenly divide
                        the length of the axis.
    :param axis: Axis along which to split.
    :param squeeze_axis: If true, removes the axis with length 1 from the shapes
                         of the output arrays.
    :return: List of NDArrays resulting from the split.
    """
    ndarray_or_list = data.split(num_outputs=num_outputs,
                                 axis=axis,
                                 squeeze_axis=squeeze_axis)
    if num_outputs == 1:
        return [ndarray_or_list]
    return ndarray_or_list
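Usage sketch showing why the wrapper exists: with num_outputs == 1 the underlying NDArray.split returns a bare NDArray, while this helper always hands back a list:

import mxnet as mx

data = mx.nd.zeros((2, 4))
parts = split(data, num_outputs=4, axis=1)   # list of 4 NDArrays, each of shape (2, 1)
whole = split(data, num_outputs=1, axis=1)   # list containing a single (2, 4) NDArray
assert isinstance(whole, list) and len(whole) == 1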
Example no. 13
    def update(self, preds: mx.nd.NDArray, labels: mx.nd.NDArray):
        """
        Update confusion table used to compute scores.

        :param preds: predicted classes for samples in batch.
        :type preds: mx.ndarray.NDArray

        :param labels: actual labels for samples in batch.
        :type labels: mx.ndarray.NDArray
        """
        preds = preds.reshape((-1, 1))
        labels = labels.reshape((-1, 1))

        preds_binary = (preds == self.all_labels)
        labels_binary = (labels == self.all_labels)

        self.confusion_matrix += mx.nd.dot(labels_binary.T, preds_binary)
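A hedged sketch of the same update performed outside the class, reconstructing the fields the metric presumably keeps (all_labels as a row of class ids and confusion_matrix as a C x C accumulator; the surrounding class is not shown above, so these names are assumptions):

import mxnet as mx

num_classes = 3
all_labels = mx.nd.arange(num_classes).reshape((1, num_classes))   # class ids 0..C-1
confusion_matrix = mx.nd.zeros((num_classes, num_classes))

preds = mx.nd.array([0, 2, 2, 1]).reshape((-1, 1))
labels = mx.nd.array([0, 1, 2, 1]).reshape((-1, 1))

preds_binary = (preds == all_labels)     # broadcast comparison gives one-hot rows, (N, C)
labels_binary = (labels == all_labels)

confusion_matrix += mx.nd.dot(labels_binary.T, preds_binary)
# rows index the true class, columns the predicted class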
Example no. 14
    def consume(self, word_ids: mx.nd.NDArray) -> None:
        """
        Consumes a word for each trie, updating respective states.

        :param word_ids: The set of word IDs.
        """
        word_ids = word_ids.asnumpy().tolist()
        for i, word_id in enumerate(word_ids):
            if self.global_avoid_states:
                self.global_avoid_states[i] = self.global_avoid_states[i].consume(word_id)
            if self.local_avoid_states:
                self.local_avoid_states[i] = self.local_avoid_states[i].consume(word_id)
Example no. 15
    def consume(self, word_ids: mx.nd.NDArray) -> None:
        """
        Consumes a word for each trie, updating respective states.

        :param word_ids: The set of word IDs.
        """
        word_ids = word_ids.asnumpy().tolist()
        for i, word_id in enumerate(word_ids):
            if self.global_avoid_states:
                self.global_avoid_states[i] = self.global_avoid_states[i].consume(word_id)
            if self.local_avoid_states:
                self.local_avoid_states[i] = self.local_avoid_states[i].consume(word_id)
Example no. 16
    def update(self, labels: mx.nd.NDArray, embeddings0: mx.nd.NDArray,
               embeddings1: mx.nd.NDArray):
        """

        :param labels: NDArray.
        :param embeddings0: NDArray.
        :param embeddings1: NDArray.
        :return:
        """

        embeddings0 = embeddings0.asnumpy() if not isinstance(
            embeddings0, np.ndarray) else embeddings0
        embeddings1 = embeddings1.asnumpy() if not isinstance(
            embeddings1, np.ndarray) else embeddings1
        labels = labels.asnumpy() if not isinstance(labels,
                                                    np.ndarray) else labels

        diff = np.subtract(embeddings0, embeddings1)
        dists = np.sqrt(np.sum(np.square(diff), 1))
        self._dists += [d for d in dists]

        self._issame += [l for l in labels]
Example no. 17
def topk(
    scores: mx.nd.NDArray, k: int, batch_size: int, offset: mx.nd.NDArray,
    use_mxnet_topk: bool
) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]:
    """
    Get the lowest k elements per sentence from a `scores` matrix.

    :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
    :param k: The number of smallest scores to return.
    :param batch_size: Number of sentences being decoded at once.
    :param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
    :param use_mxnet_topk: True to use the mxnet implementation or False to use the numpy one.
    :return: The row indices, column indices and values of the k smallest items in matrix.
    """
    # (batch_size, beam_size * target_vocab_size)
    folded_scores = scores.reshape((batch_size, k * scores.shape[-1]))

    if use_mxnet_topk:
        # pylint: disable=unbalanced-tuple-unpacking
        values, indices = mx.nd.topk(folded_scores,
                                     axis=1,
                                     k=k,
                                     ret_typ='both',
                                     is_ascend=True)
        best_hyp_indices, best_word_indices = mx.nd.array(np.unravel_index(
            indices.astype(np.int32).asnumpy().ravel(), scores.shape),
                                                          dtype='int32',
                                                          ctx=scores.context)

    else:
        folded_scores = folded_scores.asnumpy()
        # Get the scores
        # Indexes into folded_scores: (batch_size, beam_size)
        flat_idxs = np.argpartition(folded_scores, range(k))[:, :k]
        # Score values: (batch_size, beam_size)
        values = mx.nd.array(
            folded_scores[np.arange(folded_scores.shape[0])[:, None],
                          flat_idxs],
            ctx=scores.context)
        best_hyp_indices, best_word_indices = mx.nd.array(np.unravel_index(
            flat_idxs.ravel(), scores.shape),
                                                          dtype='int32',
                                                          ctx=scores.context)

    if batch_size > 1:
        # Offsetting the indices to match the shape of the scores matrix
        best_hyp_indices += offset

    values = values.reshape((-1, 1))
    return best_hyp_indices, best_word_indices, values
Example no. 18
    def run_decoder(self,
                    sequences: mx.nd.NDArray,
                    bucket_key: Tuple[int, int],
                    model_state: 'ModelState') -> Tuple[mx.nd.NDArray, mx.nd.NDArray, 'ModelState']:
        """
        Runs forward pass of the single-step decoder.

        :return: Probability distribution over next word, attention scores, updated model state.
        """
        batch = mx.io.DataBatch(
            data=[sequences.as_in_context(self.context)] + model_state.states,
            label=None,
            bucket_key=bucket_key,
            provide_data=self._get_decoder_data_shapes(bucket_key))
        self.decoder_module.forward(data_batch=batch, is_train=False)
        probs, attention_probs, *model_state.states = self.decoder_module.get_outputs()
        return probs, attention_probs, model_state
Example no. 19
def smallest_k_mx_batched(
    matrix: mx.nd.NDArray,
    k: int,
    batch_size: int,
    offset: mx.nd.NDArray,
    only_first_row: bool = False
) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """
    Find the smallest elements in an NDArray.

    :param matrix: Any matrix.
    :param k: The number of smallest elements to return.
    :param batch_size: Number of sentences being decoded at once.
    :param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
    :param only_first_row: If True the search is constrained to the first row of the matrix.
    :return: The hypothesis, position and word indices plus the values of the k smallest items per sentence.
    """

    if only_first_row:
        folded_matrix = matrix.reshape((-4, batch_size, -1, 0, 0))
        folded_matrix = folded_matrix[:, 0, :, :]
    else:
        folded_matrix = matrix

    folded_matrix = folded_matrix.reshape((batch_size, -1))
    # pylint: disable=unbalanced-tuple-unpacking
    values, indices = mx.nd.topk(folded_matrix,
                                 axis=1,
                                 k=k,
                                 ret_typ='both',
                                 is_ascend=True,
                                 dtype="int32")
    # best_hyp_indices, best_hyp_pos_indices, best_word_indices = mx.nd.array(np.unravel_index(indices.astype(np.int32).asnumpy().ravel(), matrix.shape),
    #                       dtype='int32',
    #                       ctx=matrix.context)
    best_hyp_indices, best_hyp_pos_indices, best_word_indices = mx.nd.array(
        mx.nd.unravel_index(
            indices.astype(np.int32).reshape(-1), matrix.shape),
        dtype='int32',
        ctx=matrix.context)
    if batch_size > 1:
        best_hyp_indices += offset


#    assert(mx.nd.nansum(values.reshape(-1) - matrix[best_hyp_indices, best_hyp_pos_indices, best_word_indices]) == 0)
    return (best_hyp_indices, best_hyp_pos_indices,
            best_word_indices), values.reshape(-1)
Example no. 20
def topk(scores: mx.nd.NDArray,
         k: int,
         batch_size: int,
         offset: mx.nd.NDArray,
         use_mxnet_topk: bool) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]:
    """
    Get the lowest k elements per sentence from a `scores` matrix.

    :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
    :param k: The number of smallest scores to return.
    :param batch_size: Number of sentences being decoded at once.
    :param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
    :param use_mxnet_topk: True to use the mxnet implementation or False to use the numpy one.
    :return: The row indices, column indices and values of the k smallest items in matrix.
    """
    # (batch_size, beam_size * target_vocab_size)
    folded_scores = scores.reshape((batch_size, k * scores.shape[-1]))

    if use_mxnet_topk:
        # pylint: disable=unbalanced-tuple-unpacking
        values, indices = mx.nd.topk(folded_scores, axis=1, k=k, ret_typ='both', is_ascend=True)
        best_hyp_indices, best_word_indices = mx.nd.array(np.unravel_index(indices.astype(np.int32).asnumpy().ravel(),
                                                                           scores.shape),
                                                          dtype='int32',
                                                          ctx=scores.context)

    else:
        folded_scores = folded_scores.asnumpy()
        # Get the scores
        # Indexes into folded_scores: (batch_size, beam_size)
        flat_idxs = np.argpartition(folded_scores, range(k))[:, :k]
        # Score values: (batch_size, beam_size)
        values = mx.nd.array(folded_scores[np.arange(folded_scores.shape[0])[:, None], flat_idxs], ctx=scores.context)
        best_hyp_indices, best_word_indices = mx.nd.array(np.unravel_index(flat_idxs.ravel(), scores.shape),
                                                          dtype='int32', ctx=scores.context)

    if batch_size > 1:
        # Offsetting the indices to match the shape of the scores matrix
        best_hyp_indices += offset

    values = values.reshape((-1, 1))
    return best_hyp_indices, best_word_indices, values
Example no. 21
    def run_decoder(self,
                    encoded_source: mx.nd.NDArray,
                    dynamic_source: mx.nd.NDArray,
                    source_length: mx.nd.NDArray,
                    previous_word_id: mx.nd.NDArray,
                    previous_hidden: mx.nd.NDArray,
                    decoder_states: List[mx.nd.NDArray],
                    bucket_key: int) -> Tuple[mx.nd.NDArray, mx.nd.NDArray,
                                              mx.nd.NDArray, mx.nd.NDArray,
                                              List[mx.nd.NDArray]]:
        """
        Runs forward pass of the single-step decoder.

        :param encoded_source: Encoded source sentence.
        :param dynamic_source: Dynamic encoding of source sentence.
        :param source_length: Source length.
        :param previous_word_id: Previous predicted word id.
        :param previous_hidden: Previous hidden decoder state.
        :param decoder_states: Decoder states.
        :param bucket_key: Bucket key.
        :return: Probability distribution over next word, attention scores, dynamic source encoding,
                 next hidden state, next decoder states.
        """

        data = [encoded_source,
                dynamic_source,
                source_length,
                previous_word_id.as_in_context(self.context),
                previous_hidden] + decoder_states

        decoder_batch = mx.io.DataBatch(
            data=data,
            label=None, bucket_key=bucket_key, provide_data=self._get_decoder_data_shapes(bucket_key))
        # run forward pass
        self.decoder_module.forward(data_batch=decoder_batch, is_train=False)
        # collect outputs
        softmax_out, attention_probs, dynamic_source, next_hidden, *next_layer_states = \
            self.decoder_module.get_outputs()

        return softmax_out, attention_probs, dynamic_source, next_hidden, next_layer_states
Example no. 22
def split(data: mx.nd.NDArray,
          num_outputs: int,
          axis: int = 1,
          squeeze_axis: bool = False) -> List[mx.nd.NDArray]:
    """
    Version of mxnet.ndarray.split that always returns a list.  The original
    implementation only returns a list if num_outputs > 1:
    https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.split

    Splits an array along a particular axis into multiple sub-arrays.

    :param data: The input.
    :param num_outputs: Number of splits. Note that this should evenly divide
                        the length of the axis.
    :param axis: Axis along which to split.
    :param squeeze_axis: If true, removes the axis with length 1 from the shapes
                         of the output arrays.
    :return: List of NDArrays resulting from the split.
    """
    ndarray_or_list = data.split(num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis)
    if num_outputs == 1:
        return [ndarray_or_list]
    return ndarray_or_list
Example no. 23
def numpy_topk(
    scores: mx.nd.NDArray, k: int, offset: mx.nd.NDArray
) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]:
    """
    Get the lowest k elements per sentence from a `scores` matrix using an intermediary Numpy conversion.
    This should be equivalent to sockeye.utils.topk() and is used as a comparative implementation in testing.

    :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
    :param k: The number of smallest scores to return.
    :param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
    :return: The row indices, column indices and values of the k smallest items in matrix.
    """
    # (batch_size, beam_size * target_vocab_size)
    folded_scores = scores.reshape((-1, k * scores.shape[-1]))
    batch_size = folded_scores.shape[0]

    folded_scores = folded_scores.asnumpy()
    # Get the scores
    # Indexes into folded_scores: (batch_size, beam_size)
    flat_idxs = np.argpartition(folded_scores, range(k))[:, :k]
    # Score values: (batch_size, beam_size)
    values = mx.nd.array(folded_scores[np.arange(folded_scores.shape[0])[:,
                                                                         None],
                                       flat_idxs],
                         ctx=scores.context)
    best_hyp_indices, best_word_indices = mx.nd.array(np.unravel_index(
        flat_idxs.ravel(), scores.shape),
                                                      dtype='int32',
                                                      ctx=scores.context)

    if batch_size > 1:
        # Offsetting the indices to match the shape of the scores matrix
        best_hyp_indices += offset

    values = values.reshape((-1, 1))
    return best_hyp_indices, best_word_indices, values
Example no. 24
    def forward(
        self, source: mx.nd.NDArray, source_length: mx.nd.NDArray,
        restrict_lexicon: Optional[lexicon.TopKLexicon],
        raw_constraint_list: List[Optional[constrained.RawConstraintList]],
        raw_avoid_list: List[Optional[constrained.RawConstraintList]],
        max_output_lengths: mx.nd.NDArray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray,
               List[Optional[np.ndarray]],
               List[Optional[constrained.ConstrainedHypothesis]]]:
        """
        Translates multiple sentences using beam search.

        :param source: Source ids. Shape: (batch_size, bucket_key, num_factors).
        :param source_length: Valid source lengths. Shape: (batch_size,).
        :param restrict_lexicon: Lexicon to use for vocabulary restriction.
        :param raw_constraint_list: A list of optional lists containing phrases (as lists of target word IDs)
               that must appear in each output.
        :param raw_avoid_list: A list of optional lists containing phrases (as lists of target word IDs)
               that must NOT appear in each output.
        :param max_output_lengths: NDArray of maximum output lengths per input in source.
                Shape: (batch_size,). Dtype: int32.
        :return: List of best hypotheses indices, list of best word indices,
                array of accumulated length-normalized negative log-probs, hypotheses lengths,
                predicted lengths of references (if any), constraints (if any).
        """
        batch_size = source.shape[0]
        logger.debug("beam_search batch size: %d", batch_size)

        # Maximum beam search iterations (determined by longest input with eos)
        max_iterations = max_output_lengths.max().asscalar()
        logger.debug("max beam search iterations: %d", max_iterations)

        sample_best_hyp_indices = None
        if self._sample is not None:
            utils.check_condition(
                restrict_lexicon is None,
                "Sampling is not available when working with a restricted lexicon."
            )
            sample_best_hyp_indices = mx.nd.arange(0,
                                                   batch_size * self.beam_size,
                                                   dtype='int32',
                                                   ctx=self.context)

        # General data structure: batch_size * beam_size blocks in total;
        # a full beam for each sentence, followed by the next beam-block for the next sentence and so on

        best_word_indices = mx.nd.full((batch_size * self.beam_size, ),
                                       val=self.bos_id,
                                       ctx=self.context,
                                       dtype='int32')

        # offset for hypothesis indices in batch decoding
        offset = mx.nd.repeat(
            mx.nd.arange(0,
                         batch_size * self.beam_size,
                         self.beam_size,
                         dtype='int32',
                         ctx=self.context), self.beam_size)

        # locations of each batch item when first dimension is (batch * beam)
        batch_indices = mx.nd.arange(0,
                                     batch_size * self.beam_size,
                                     self.beam_size,
                                     dtype='int32',
                                     ctx=self.context)
        first_step_mask = mx.nd.full((batch_size * self.beam_size, 1),
                                     val=np.inf,
                                     ctx=self.context,
                                     dtype=self.dtype)
        first_step_mask[batch_indices] = 1.0
        pad_dist = mx.nd.full(
            (batch_size * self.beam_size, self.output_vocab_size - 1),
            val=np.inf,
            ctx=self.context,
            dtype=self.dtype)
        eos_dist = mx.nd.full(
            (batch_size * self.beam_size, self.output_vocab_size),
            val=np.inf,
            ctx=self.context,
            dtype=self.dtype)
        eos_dist[:, C.EOS_ID] = 0

        # Best word and hypotheses indices across beam search steps from topk operation.
        best_hyp_indices_list = []  # type: List[mx.nd.NDArray]
        best_word_indices_list = []  # type: List[mx.nd.NDArray]

        lengths = mx.nd.zeros((batch_size * self.beam_size, ),
                              ctx=self.context,
                              dtype='int32')
        finished = mx.nd.zeros((batch_size * self.beam_size, ),
                               ctx=self.context,
                               dtype='int32')

        # Extending max_output_lengths to shape (batch_size * beam_size,)
        max_output_lengths = mx.nd.repeat(max_output_lengths, self.beam_size)

        # scores_accumulated: chosen smallest scores in scores (ascending).
        scores_accumulated = mx.nd.zeros((batch_size * self.beam_size, 1),
                                         ctx=self.context,
                                         dtype=self.dtype)

        # If using a top-k lexicon, select param rows for logit computation that correspond to the
        # target vocab for this sentence.
        vocab_slice_ids = None  # type: Optional[mx.nd.NDArray]
        if restrict_lexicon:
            source_words = utils.split(source,
                                       num_outputs=self.num_source_factors,
                                       axis=2,
                                       squeeze_axis=True)[0]
            vocab_slice_ids = restrict_lexicon.get_trg_ids(
                source_words.astype("int32").asnumpy())
            if any(raw_constraint_list):
                # Add the constraint IDs to the list of permissible IDs, and then project them into the reduced space
                constraint_ids = np.array([
                    word_id for sent in raw_constraint_list for phr in sent
                    for word_id in phr
                ])
                vocab_slice_ids = np.lib.arraysetops.union1d(
                    vocab_slice_ids, constraint_ids)
                full_to_reduced = dict(
                    (val, i) for i, val in enumerate(vocab_slice_ids))
                raw_constraint_list = [[[full_to_reduced[x] for x in phr]
                                        for phr in sent]
                                       for sent in raw_constraint_list]
            # Pad to a multiple of 8.
            vocab_slice_ids = np.pad(vocab_slice_ids,
                                     (0, 7 - ((len(vocab_slice_ids) - 1) % 8)),
                                     mode='constant',
                                     constant_values=self.eos_id)
            vocab_slice_ids = mx.nd.array(vocab_slice_ids,
                                          ctx=self.context,
                                          dtype='int32')

            if vocab_slice_ids.shape[0] < self.beam_size + 1:
                # This fixes an edge case for toy models, where the number of vocab ids from the lexicon is
                # smaller than the beam size.
                logger.warning(
                    "Padding vocab_slice_ids (%d) with EOS to have at least %d+1 elements to expand",
                    vocab_slice_ids.shape[0], self.beam_size)
                n = self.beam_size - vocab_slice_ids.shape[0] + 1
                vocab_slice_ids = mx.nd.concat(vocab_slice_ids,
                                               mx.nd.full((n, ),
                                                          val=self.eos_id,
                                                          ctx=self.context,
                                                          dtype='int32'),
                                               dim=0)

            pad_dist = mx.nd.full(
                (batch_size * self.beam_size, vocab_slice_ids.shape[0] - 1),
                val=np.inf,
                ctx=self.context)
            eos_dist = mx.nd.full(
                (batch_size * self.beam_size, vocab_slice_ids.shape[0]),
                val=np.inf,
                ctx=self.context)
            eos_dist[:, C.EOS_ID] = 0

        # Initialize the beam to track constraint sets, where target-side lexical constraints are present
        constraints = constrained.init_batch(raw_constraint_list,
                                             self.beam_size, self.bos_id,
                                             self.eos_id)

        if self.global_avoid_trie or any(raw_avoid_list):
            avoid_states = constrained.AvoidBatch(
                batch_size,
                self.beam_size,
                avoid_list=raw_avoid_list,
                global_avoid_trie=self.global_avoid_trie)
            avoid_states.consume(best_word_indices)

        # (0) encode source sentence, returns a list
        model_states, estimated_reference_lengths = self._inference.encode_and_initialize(
            source, source_length)
        # repeat states to beam_size
        model_states = _repeat_states(model_states, self.beam_size,
                                      self._inference.state_structure())

        # Records items in the beam that are inactive. At the beginning (t==1), there is only one valid or active
        # item on the beam for each sentence
        inactive = mx.nd.zeros((batch_size * self.beam_size),
                               dtype='int32',
                               ctx=self.context)
        t = 1
        for t in range(
                1, max_iterations + 1
        ):  # TODO: max_iterations + 1 is the MINIMUM to get correct results right now
            # (1) obtain next predictions and advance models' state
            # target_dists: (batch_size * beam_size, target_vocab_size)
            target_dists, model_states = self._inference.decode_step(
                best_word_indices, model_states, vocab_slice_ids)

            # (2) Produces the accumulated cost of target words in each row.
            # There is special treatment for finished and inactive rows: inactive rows are inf everywhere;
            # finished rows are inf everywhere except column zero, which holds the accumulated model score
            scores, lengths = self._update_scores(target_dists, finished,
                                                  inactive, scores_accumulated,
                                                  lengths, max_output_lengths,
                                                  pad_dist, eos_dist)

            # Mark entries that should be blocked as having a score of np.inf
            if self.global_avoid_trie or any(raw_avoid_list):
                block_indices = avoid_states.avoid()
                if len(block_indices) > 0:
                    scores[block_indices] = np.inf
                    if self._sample is not None:
                        target_dists[block_indices] = np.inf

            # (3) Get beam_size winning hypotheses for each sentence block separately. Only look as
            # far as the active beam size for each sentence.
            if self._sample is not None:
                best_hyp_indices, best_word_indices, scores_accumulated = self._sample(
                    scores, target_dists, finished, sample_best_hyp_indices)
            else:
                # On the first timestep, all hypotheses have identical histories, so force topk() to choose extensions
                # of the first row only by setting all other rows to inf
                if t == 1:
                    scores *= first_step_mask

                best_hyp_indices, best_word_indices, scores_accumulated = self._top(
                    scores, offset)

            # Constraints for constrained decoding are processed sentence by sentence
            if any(raw_constraint_list):
                best_hyp_indices, best_word_indices, scores_accumulated, constraints, inactive = constrained.topk(
                    t, batch_size, self.beam_size, inactive, scores,
                    constraints, best_hyp_indices, best_word_indices,
                    scores_accumulated)

            # Map from restricted to full vocab ids if needed
            if restrict_lexicon:
                best_word_indices = vocab_slice_ids.take(best_word_indices)

            # (4) Normalize the scores of newly finished hypotheses. Note that after this until the
            # next call to topk(), hypotheses may not be in sorted order.
            finished, scores_accumulated, lengths, estimated_reference_lengths = self._sort_norm_and_update_finished(
                best_hyp_indices, best_word_indices, finished,
                scores_accumulated, lengths, estimated_reference_lengths)

            # Collect best hypotheses, best word indices
            best_hyp_indices_list.append(best_hyp_indices)
            best_word_indices_list.append(best_word_indices)

            if self._should_stop(finished, batch_size):
                break

            # (5) update models' state with winning hypotheses (ascending)
            model_states = self._sort_states(best_hyp_indices, *model_states)

        logger.debug("Finished after %d out of %d steps.", t, max_iterations)

        # (9) Sort the hypotheses within each sentence (normalization for finished hyps may have unsorted them).
        folded_accumulated_scores = scores_accumulated.reshape(
            (batch_size, self.beam_size * scores_accumulated.shape[-1]))
        indices = mx.nd.cast(mx.nd.argsort(
            folded_accumulated_scores.astype('float32'), axis=1),
                             dtype='int32').reshape((-1, ))
        best_hyp_indices, _ = mx.nd.unravel_index(
            indices, scores_accumulated.shape) + offset
        scores_accumulated = scores_accumulated.take(best_hyp_indices)
        best_hyp_indices_list.append(best_hyp_indices)
        lengths = lengths.take(best_hyp_indices)
        all_best_hyp_indices = mx.nd.stack(*best_hyp_indices_list, axis=1)
        all_best_word_indices = mx.nd.stack(*best_word_indices_list, axis=1)
        constraints = [constraints[x] for x in best_hyp_indices.asnumpy()]

        return all_best_hyp_indices.asnumpy(), \
               all_best_word_indices.asnumpy(), \
               scores_accumulated.asnumpy(), \
               lengths.asnumpy().astype('int32'), \
               estimated_reference_lengths.asnumpy(), \
               constraints
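The offset array built near the top of forward() is what shifts the per-sentence indices coming out of the topk step back into the flat (batch * beam) layout, and first_step_mask is what restricts the first timestep to one live hypothesis per sentence. A standalone sketch of those two constructions for batch_size = 2 and beam_size = 3:

import mxnet as mx
import numpy as np

batch_size, beam_size = 2, 3

offset = mx.nd.repeat(
    mx.nd.arange(0, batch_size * beam_size, beam_size, dtype='int32'),
    beam_size)
print(offset.asnumpy())          # [0 0 0 3 3 3]

batch_indices = mx.nd.arange(0, batch_size * beam_size, beam_size, dtype='int32')
first_step_mask = mx.nd.full((batch_size * beam_size, 1), val=np.inf)
first_step_mask[batch_indices] = 1.0    # only rows 0 and 3 stay finite at t == 1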