Example #1
    def get_logits(self, hidden):
        """Get all the logits.

        Parameters
        ----------
        hidden
            The hidden representation
            Shape (..., in_units)

        Returns
        -------
        logits
            Shape (..., |V|)

        """
        if self._cutoffs is None:
            if self._in_units != self._embed_size:
                hidden = self.inter_proj_l[0](hidden)
            logits = self.out_proj_l[0](hidden)
            return logits
        else:
            all_logits = []
            if self._div_val == 1.0:
                if self._in_units == self._embed_size:
                    all_scores = self.out_proj_l[0](hidden)
                    tail_cluster_scores = self.tail_cluster_score_proj(hidden)
                else:
                    inter_hidden = self.inter_proj_l[0](hidden)
                    all_scores = self.out_proj_l[0](inter_hidden)
                    tail_cluster_scores = self.tail_cluster_score_proj(
                        inter_hidden)
                all_scores_l = np.split(all_scores, self._cutoffs, axis=-1)
                head_scores = all_scores_l[0]
            else:
                inter_hidden = self.inter_proj_l[0](hidden)
                head_scores = self.out_proj_l[0](inter_hidden)
                tail_cluster_scores = self.tail_cluster_score_proj(
                    inter_hidden)
            head_tail_cluster_logits = npx.log_softmax(
                np.concatenate([head_scores, tail_cluster_scores], axis=-1),
                axis=-1)
            head_logits, tail_cluster_logits = \
                np.split(head_tail_cluster_logits, [self._cutoffs[0]], axis=-1)
            tail_cluster_logits = np.split(tail_cluster_logits,
                                           self._num_tail_clusters,
                                           axis=-1)
            all_logits.append(head_logits)
            for i in range(1, len(self._cutoffs) + 1):
                if self._div_val == 1.0:
                    ele_scores = all_scores_l[i]
                else:
                    ele_scores = self.out_proj_l[i](
                        self.inter_proj_l[i](hidden))
                ele_logits = npx.log_softmax(ele_scores, axis=-1)
                ele_logits = tail_cluster_logits[-i] + ele_logits
                all_logits.append(ele_logits)
            return np.concatenate(all_logits, axis=-1)
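
Example 1 implements an adaptive (cutoff-based) softmax: np.split carves the flat score matrix at the vocabulary cutoffs, the head block is normalized jointly with one extra score per tail cluster, and each tail block is normalized locally and shifted by its cluster's log-probability. A minimal sketch of that arithmetic in plain NumPy (whose split/concatenate semantics match the mxnet.np calls above; all sizes are illustrative):

import numpy as np  # plain NumPy stand-in for mxnet.np

def log_softmax(x, axis=-1):
    # numerically stable stand-in for npx.log_softmax
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

batch, vocab = 2, 10
cutoffs = [4, 7]                        # head = [0, 4), tails = [4, 7) and [7, 10)
scores = np.random.randn(batch, vocab)

# index-list form of np.split: cut the flat scores at the cutoffs
head, tail1, tail2 = np.split(scores, cutoffs, axis=-1)

# head logits are normalized together with one score per tail cluster ...
cluster_scores = np.random.randn(batch, len(cutoffs))
joint = log_softmax(np.concatenate([head, cluster_scores], axis=-1))
head_logits, cluster_logits = joint[:, :cutoffs[0]], joint[:, cutoffs[0]:]

# ... and each tail is normalized locally, then shifted by its cluster's log-prob
tail1_logits = cluster_logits[:, 0:1] + log_softmax(tail1)
tail2_logits = cluster_logits[:, 1:2] + log_softmax(tail2)
full_logits = np.concatenate([head_logits, tail1_logits, tail2_logits], axis=-1)
assert np.allclose(np.exp(full_logits).sum(axis=-1), 1.0)  # a proper distribution
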
Example #2
    def forward(self,
                inputs: np.ndarray,
                previous_states: Optional[np.ndarray] = None,
                input_lengths: Optional[np.ndarray] = None,
                bias: Optional[np.ndarray] = None,
                *args) -> Tuple[np.ndarray, np.ndarray]:  # mypy: ignore
        """
        Computes multi-head attention on a set of inputs, serving as queries, keys, and values.
        If sequence lengths are provided, they will be used to mask the attention scores.
        A bias mask may also be used to mask the attention scores.
        May also use a cache of previously computed inputs.
        Returns an ndarray of shape (max_length, batch, output_depth) along with the updated states.

        :param inputs: Input data. Shape: (max_length, batch, input_depth).
        :param previous_states: Optional ndarray of previously computed keys and values
                                (concatenated on the depth axis) to prepend along the time axis.
                                Shape: (prev_length, batch, 2 * depth_att).
        :param input_lengths: Optional lengths of inputs to mask attention scores. Shape: (batch, 1).
        :param bias: Optional 3d bias tensor to mask attention scores.
        :return: ndarray of shape (max_length, batch, output_depth) and the updated states.
        """

        proj = self.ff_in(inputs)
        queries, kv_1, kv_2 = np.split(proj, 3, axis=2)
        states = np.concatenate((kv_1, kv_2), axis=2)

        if previous_states is not None:
            states = np.concatenate((previous_states, states), axis=0)

        return self._attend(queries, states, lengths=input_lengths,
                            bias=bias), states
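
The fused projection in Example 2 computes queries, keys, and values with a single matmul and separates them with the integer-sections form of np.split; keys and values then travel together as one "states" tensor, so a decoder cache grows by a single concatenate per step. A sketch of those shape mechanics in plain NumPy (time-major layout as in the code above; sizes illustrative):

import numpy as np

seq_len, batch, depth = 5, 2, 8
proj = np.random.randn(seq_len, batch, 3 * depth)   # one fused Q/K/V projection

queries, kv_1, kv_2 = np.split(proj, 3, axis=2)     # 3 equal chunks on the depth axis
states = np.concatenate((kv_1, kv_2), axis=2)       # (seq_len, batch, 2 * depth)

previous_states = np.random.randn(3, batch, 2 * depth)  # cache from 3 earlier steps
states = np.concatenate((previous_states, states), axis=0)
assert states.shape == (3 + seq_len, batch, 2 * depth)
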
Example #3
    def forward(
        self, 
        hidden_states, 
        self_attn_mask, 
        position_embeddings, 
        mem_states=None, 
        mem_attn_mask=None, 
        mem_position_embeddings=None
    ): 
        """
        hidden_states: 
            - layout = 'NT'
                Shape (B, L_seq, d_model)
            - layout = 'TN'
                Shape (L_seq, B, d_model)
        """
        # NT -> NTK: (B, L_seq, inner_dim) -> (B, L_seq, num_heads, n_kv)
        # TN -> TNK: (L_seq, B, inner_dim) -> (L_seq, B, num_heads, n_kv)
        def shape(x):
            return x.reshape(-2, -2, self._num_heads, -1)

        # 1. self-attention
        self_query, self_key, self_value = np.split(
            self.self_attn_qkv(self.self_attn_layer_norm(hidden_states)), 
            indices_or_sections=3, 
            axis=-1
        )
        out, [_, self_attn_weights] = self.self_attn(
            shape(self_query), 
            shape(self_key), 
            shape(self_value), 
            mask=self_attn_mask, 
            edge_scores=position_embeddings
        )
        out = self.dropout(self.self_attn_proj(out))
        out = hidden_states + out

        # 2. cross-attention, if needed
        if self._is_decoder: 
            hidden_states = out
            cross_query, cross_key, cross_value = (
                self.cross_attn_q(self.cross_attn_layer_norm(out)), 
                self.cross_attn_k(mem_states), 
                self.cross_attn_v(mem_states)
            )
            out, [_, cross_attn_weights] = self.cross_attn(
                shape(cross_query), 
                shape(cross_key), 
                shape(cross_value), 
                mask=mem_attn_mask, 
                edge_scores=mem_position_embeddings
            )
            out = self.dropout(self.cross_attn_proj(out))
            out = hidden_states + out

        # 3. feed forward
        out = self.ffn(out)

        return out
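
The shape helper in Example 3 relies on MXNet's special reshape codes: -2 copies an existing dimension from the input and -1 infers the remainder. In plain NumPy the same split into attention heads is written with explicit sizes (names illustrative):

import numpy as np

batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
x = np.random.randn(batch, seq_len, num_heads * head_dim)

# equivalent of x.reshape(-2, -2, num_heads, -1): keep the first two axes,
# then split the hidden axis into (num_heads, head_dim)
x_heads = x.reshape(x.shape[0], x.shape[1], num_heads, -1)
assert x_heads.shape == (batch, seq_len, num_heads, head_dim)
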
Example #4
    def forward(self, x, layer_states):
        """

        Parameters
        ----------
        x
            - layout = 'NT'
                Shape (batch_size, seq_length, C_in)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_in)

        layer_states
            - layout = 'NT'
                Shape (2, batch_size, prev_len, C_in)
            - layout = 'TN'
                Shape (2, prev_len, batch_size, C_in)
        """
        x = self.ln(x)
        if self._layout == 'NT':
            batch_axis, time_axis = 0, 1
        else:
            batch_axis, time_axis = 1, 0

        query, key, value = np.split(self.qkv(x), 3, axis=-1)
        prev_len = None
        if layer_states is not None:
            # layer_states stacks (key, value) on a leading axis of size 2,
            # so its time axis sits one position further right
            prev_len = npx.shape_array(layer_states)[time_axis + 1]
            prev_key, prev_value = layer_states[0], layer_states[1]
            key = np.concatenate([prev_key, key], axis=time_axis)
            value = np.concatenate([prev_value, value], axis=time_axis)
        new_states = np.stack([key, value], axis=0)

        # gen mask
        query_pos = npx.arange_like(query, axis=time_axis)
        if prev_len is not None:
            query_pos = query_pos + prev_len
        key_pos = npx.arange_like(key, axis=time_axis)
        # (query_len, key_len)
        mask = (npx.reshape(key_pos,
                            (1, -1)) <= npx.reshape(query_pos,
                                                    (-1, 1))).astype(
                                                        self._dtype)
        # broadcast to (batch_size, query_len, key_len)
        mask = npx.broadcast_like(np.expand_dims(mask, axis=0),
                                  query,
                                  lhs_axes=0,
                                  rhs_axes=batch_axis)

        query = npx.reshape(query, (-2, -2, self._num_heads, -1))
        key = npx.reshape(key, (-2, -2, self._num_heads, -1))
        value = npx.reshape(value, (-2, -2, self._num_heads, -1))

        out, [_, attn_weight] = self.attention_cell(query, key, value, mask)
        out = self.out_proj(out)
        out = self.hidden_dropout(out)

        return out, new_states
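
Example 4 derives its causal mask from position ranges: the positions of the new queries are shifted by the cached prefix length, so each query may attend to all cached keys plus the new keys at or before its own position. The same comparison in plain NumPy (sizes illustrative):

import numpy as np

prev_len, query_len = 3, 4
key_len = prev_len + query_len

query_pos = np.arange(query_len) + prev_len   # new tokens sit after the cache
key_pos = np.arange(key_len)                  # cached + new key positions

# (query_len, key_len): 1 where key position <= query position
mask = (key_pos.reshape(1, -1) <= query_pos.reshape(-1, 1)).astype(np.float32)
assert mask[0].sum() == prev_len + 1          # first query: cache + itself
assert mask[-1].sum() == key_len              # last query: everything
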
Example #5
    def dnp_func(a, b=None, split_inputs=(), ret_type=list):
        """
        Dummy Doc:
        dnp_func is using the same np.xxx operators
        """
        ret_lst = []
        # unsupported indexing case
        ret_lst.append(a[:, a[1, :] > 0])
        # unsupported operator
        ret_lst.append(np.nonzero(b))
        # unsupported operator case
        ret_lst.append(tuple(np.split(split_inputs[0], split_inputs[1])))

        return ret_type(ret_lst)
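
One detail worth noting in Example 5: np.split always returns a list of sub-arrays, which is why the result is wrapped in tuple() before being appended. A quick check in plain NumPy:

import numpy as np

a = np.arange(12).reshape(3, 4)
parts = np.split(a, (1, 2))        # split along axis 0 at indices 1 and 2
print(type(parts), len(parts))     # <class 'list'> 3
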
Example #6
    def forward(self, data, valid_length):  # pylint: disable=arguments-differ
        # Embed the primary token IDs and, if configured, each source factor with its own weight
        average_factors_embeds = []  # type: List[np.ndarray]
        concat_factors_embeds = []  # type: List[np.ndarray]
        sum_factors_embeds = []  # type: List[np.ndarray]
        if self.config.num_factors > 1 and self.config.factor_configs is not None:
            data, *data_factors = (np.squeeze(x, axis=2) for x in np.split(data, self.config.num_factors, axis=2))
            for i, (factor_data, factor_config) in enumerate(zip(data_factors,
                                                                 self.config.factor_configs)):
                factor_weight = self.factor_weights[i]
                factor_embedding = npx.embedding(factor_data,
                                                 input_dim=factor_config.vocab_size,
                                                 weight=factor_weight.data(),
                                                 output_dim=factor_config.num_embed)
                if factor_config.combine == C.FACTORS_COMBINE_CONCAT:
                    concat_factors_embeds.append(factor_embedding)
                elif factor_config.combine == C.FACTORS_COMBINE_SUM:
                    sum_factors_embeds.append(factor_embedding)
                elif factor_config.combine == C.FACTORS_COMBINE_AVERAGE:
                    average_factors_embeds.append(factor_embedding)
                else:
                    raise ValueError("Unknown combine value for factors: %s" % factor_config.combine)
        else:
            data = np.squeeze(data, axis=2)

        embed = npx.embedding(data,
                              weight=self.weight.data(),
                              input_dim=self.config.vocab_size,
                              output_dim=self.config.num_embed,
                              dtype=self._dtype,
                              sparse_grad=False)

        if self.config.num_factors > 1 and self.config.factor_configs is not None:
            if average_factors_embeds:
                embed = npx.add_n(embed, *average_factors_embeds) / (len(average_factors_embeds) + 1)
            if sum_factors_embeds:
                embed = npx.add_n(embed, *sum_factors_embeds)
            if concat_factors_embeds:
                embed = np.concatenate((embed, *concat_factors_embeds), axis=2)

        if self.config.dropout > 0:
            embed = npx.dropout(data=embed, p=self.config.dropout)

        return embed, np.copy(valid_length)  # See https://github.com/apache/incubator-mxnet/issues/14228
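
Example 6 separates the primary token IDs from the factor IDs with one split along the factor axis, followed by a squeeze to drop the singleton dimension. The indexing step alone, in plain NumPy (factor counts illustrative):

import numpy as np

batch, seq_len, num_factors = 2, 5, 3
data = np.random.randint(0, 100, size=(batch, seq_len, num_factors))

# one (batch, seq_len, 1) slice per factor, then drop the trailing axis
data_main, *data_factors = (np.squeeze(x, axis=2)
                            for x in np.split(data, num_factors, axis=2))
assert data_main.shape == (batch, seq_len)
assert len(data_factors) == num_factors - 1
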
Example #7
    def forward(self, data, mem, rel_positions, mask, query_r_bias,
                query_k_bias):
        """

        Parameters
        ----------
        data
            The input data.
            layout = 'NT':
                Shape (batch_size, query_length, units)
            layout = 'TN':
                Shape (query_length, batch_size, units)
        mem
            The memory.
            layout = 'NT':
                Shape (batch_size, mem_length, units)
            layout = 'TN':
                Shape (mem_length, batch_size, units)
        rel_positions
            The relative positions between data and [mem, data]
            Shape (query_length, mem_length + query_length).
            A positive value means that query is after the memory, i.e.,
            query_location - mem_location.
        mask
            Mask between the query and the memory + query.
            1 --> will be used, 0 --> won't be used.
            Shape (batch_size, query_length, mem_length + query_length)
        query_r_bias
            The query bias for calculating the relative scores.
            Shape (num_heads, query_head_units)
        query_k_bias
            The key bias for calculating the relative scores.
            Shape (num_heads, query_head_units)

        Returns
        -------
        out
            - layout = 'NT'
                Shape (batch_size, query_length, units)
            - layout = 'TN'
                Shape (query_length, batch_size, units)
        """
        if self._layout == 'NT':
            context = np.concatenate([mem, data], axis=1)
        elif self._layout == 'TN':
            context = np.concatenate([mem, data], axis=0)
        else:
            raise NotImplementedError
        if self._pre_norm:
            query = self.attn_query(self.layer_norm(data))
            key_value = self.attn_kv(self.layer_norm(context))
            key, value = np.split(key_value, 2, axis=-1)
        else:
            query = self.attn_query(data)
            key_value = self.attn_kv(context)
            key, value = np.split(key_value, 2, axis=-1)
        query = npx.reshape(query, (-2, -2, self._num_heads, -1))
        key = npx.reshape(key, (-2, -2, self._num_heads, -1))
        value = npx.reshape(value, (-2, -2, self._num_heads, -1))
        # Compute attention
        rel_score = self.rel_pos_score_cell(rel_positions,
                                            query + query_r_bias)
        out, _ = self.attn_cell(query + query_k_bias, key, value, mask,
                                rel_score)
        out = self.dropout_layer(out)
        if self._pre_norm:
            out = data + out
        else:
            out = self.layer_norm(data + out)
        out = self.ffn(out)
        return out
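
Example 7 toggles between pre-norm and post-norm residual placement on its _pre_norm flag. The structural difference, shown with stand-in functions for the normalization and the attention/FFN sublayer (a sketch, not the module's real layers):

import numpy as np

def layer_norm(x):   # stand-in for the LayerNorm block
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-5)

def sublayer(x):     # stand-in for attention or the FFN
    return 0.5 * x

x = np.random.randn(2, 4, 8)
pre_norm_out = x + sublayer(layer_norm(x))    # normalize the input, keep a raw residual
post_norm_out = layer_norm(x + sublayer(x))   # add the residual first, then normalize
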
Example #8
def test_np_split():
    class TestSplit(HybridBlock):
        def __init__(self, indices_or_sections, axis=None):
            super(TestSplit, self).__init__()
            self._axis = axis
            self._indices_or_sections = indices_or_sections

        def hybrid_forward(self, F, a, *args, **kwargs):
            return F.np.split(a,
                              indices_or_sections=self._indices_or_sections,
                              axis=self._axis)

    def get_indices(axis_size):
        if axis_size == 0:
            axis_size = random.randint(3, 6)
        samples = random.randint(1, axis_size - 1)
        indices = sorted(
            random.sample([i for i in range(1, axis_size)], samples))
        indices = tuple(indices)
        return indices

    dim = random.randint(0, 3)
    shape = [0] + [random.randint(2, 4) for i in range(dim)]
    for hybridize in [True, False]:
        for axis in range(len(shape)):
            indices = get_indices(shape[axis])
            sections = 7 if shape[axis] == 0 else shape[axis]
            for indices_or_sections in [indices, sections]:
                # test gluon
                test_split = TestSplit(axis=axis,
                                       indices_or_sections=indices_or_sections)
                if hybridize:
                    test_split.hybridize()

                a = mx.nd.random.uniform(-1.0, 1.0,
                                         shape=shape).as_np_ndarray()
                a.attach_grad()
                expected_ret = _np.split(
                    a.asnumpy(),
                    indices_or_sections=indices_or_sections,
                    axis=axis)
                with mx.autograd.record():
                    y = test_split(a)
                assert len(y) == len(expected_ret)
                for mx_out, np_out in zip(y, expected_ret):
                    assert_almost_equal(mx_out.asnumpy(),
                                        np_out,
                                        rtol=1e-3,
                                        atol=1e-5)

                mx.autograd.backward(y)

                assert_almost_equal(a.grad.asnumpy(),
                                    _np.ones(a.shape),
                                    rtol=1e-3,
                                    atol=1e-5)

                # test imperative
                mx_outs = np.split(a,
                                   indices_or_sections=indices_or_sections,
                                   axis=axis)
                np_outs = _np.split(a.asnumpy(),
                                    indices_or_sections=indices_or_sections,
                                    axis=axis)
                for mx_out, np_out in zip(mx_outs, np_outs):
                    assert_almost_equal(mx_out.asnumpy(),
                                        np_out,
                                        rtol=1e-3,
                                        atol=1e-5)
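
A complementary behavior the test above relies on: with an integer indices_or_sections, np.split demands an even division of the axis, and a zero-length axis divides evenly into any number of sections, which is why sections=7 is valid for the zero-size dimension. In plain NumPy:

import numpy as np

a = np.arange(12).reshape(3, 4)
np.split(a, 2, axis=1)        # OK: 4 columns split evenly into 2 sections
try:
    np.split(a, 3, axis=1)    # 4 columns do not divide evenly into 3
except ValueError as err:
    print(err)                # raises ValueError for an unequal division

empty = np.empty((0, 4))
assert len(np.split(empty, 7, axis=0)) == 7   # 0 divides evenly into 7 empty parts
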
Example #9
    def forward(self,
                source: np.ndarray,
                source_length: np.ndarray,
                restrict_lexicon: Optional[lexicon.TopKLexicon],
                raw_constraint_list: List[Optional[constrained.RawConstraintList]],
                raw_avoid_list: List[Optional[constrained.RawConstraintList]],
                max_output_lengths: np.ndarray) -> Tuple[np.ndarray,
                                                         np.ndarray,
                                                         np.ndarray,
                                                         np.ndarray,
                                                         List[Optional[np.ndarray]],
                                                         List[Optional[constrained.ConstrainedHypothesis]]]:
        """
        Translates multiple sentences using beam search.

        :param source: Source ids. Shape: (batch_size, bucket_key, num_factors).
        :param source_length: Valid source lengths. Shape: (batch_size,).
        :param restrict_lexicon: Lexicon to use for vocabulary restriction.
        :param raw_constraint_list: A list of optional lists containing phrases (as lists of target word IDs)
               that must appear in each output.
        :param raw_avoid_list: A list of optional lists containing phrases (as lists of target word IDs)
               that must NOT appear in each output.
        :param max_output_lengths: ndarray of maximum output lengths per input in source.
                Shape: (batch_size,). Dtype: int32.
        :return: List of best hypotheses indices, list of best word indices,
                array of accumulated length-normalized negative log-probs, hypotheses lengths,
                predicted lengths of references (if any), constraints (if any).
        """
        batch_size = source.shape[0]
        logger.debug("beam_search batch size: %d", batch_size)

        # Maximum beam search iterations (determined by longest input with eos)
        max_iterations = max_output_lengths.max().item()
        logger.debug("max beam search iterations: %d", max_iterations)

        sample_best_hyp_indices = None
        if self._sample is not None:
            utils.check_condition(restrict_lexicon is None,
                                  "Sampling is not available when working with a restricted lexicon.")
            sample_best_hyp_indices = np.arange(0, batch_size * self.beam_size, dtype='int32', ctx=self.context)

        # General data structure: batch_size * beam_size blocks in total;
        # a full beam for each sentence, followed by the next beam-block for the next sentence and so on

        # best word indices (also act as input). Shape: (batch * beam, num_target_factors)
        best_word_indices = np.full((batch_size * self.beam_size, self.num_target_factors),
                                    fill_value=self.bos_id, ctx=self.context, dtype='int32')

        # offset for hypothesis indices in batch decoding
        offset = np.repeat(np.arange(0, batch_size * self.beam_size, self.beam_size,
                                     dtype='int32', ctx=self.context), self.beam_size)

        # locations of each batch item when first dimension is (batch * beam)
        batch_indices = np.arange(0, batch_size * self.beam_size, self.beam_size, dtype='int32', ctx=self.context)
        first_step_mask = np.full((batch_size * self.beam_size, 1),
                                  fill_value=np.inf, ctx=self.context, dtype=self.dtype)
        first_step_mask[batch_indices] = 0.0

        # Best word and hypotheses indices across beam search steps from topk operation.
        best_hyp_indices_list = []  # type: List[np.ndarray]
        best_word_indices_list = []  # type: List[np.ndarray]

        lengths = np.zeros((batch_size * self.beam_size, 1), ctx=self.context, dtype='int32')
        finished = np.zeros((batch_size * self.beam_size, 1), ctx=self.context, dtype='int32')

        # Extending max_output_lengths to shape (batch_size * beam_size, 1)
        max_output_lengths = np.repeat(np.expand_dims(max_output_lengths, axis=1), self.beam_size, axis=0)

        # scores_accumulated: chosen smallest scores in scores (ascending).
        scores_accumulated = np.zeros((batch_size * self.beam_size, 1), ctx=self.context, dtype=self.dtype)

        output_vocab_size = self.output_vocab_size

        # If using a top-k lexicon, select param rows for logit computation that correspond to the
        # target vocab for this sentence.
        vocab_slice_ids = None  # type: Optional[np.ndarray]
        if restrict_lexicon:
            source_words = np.squeeze(np.split(source, self.num_source_factors, axis=2)[0], axis=2)
            vocab_slice_ids, output_vocab_size, raw_constraint_list = _get_vocab_slice_ids(restrict_lexicon,
                                                                                           source_words,
                                                                                           raw_constraint_list,
                                                                                           self.eos_id, beam_size=1)

        pad_dist = np.full((batch_size * self.beam_size, output_vocab_size - 1),
                           fill_value=np.inf, ctx=self.context, dtype=self.dtype)
        eos_dist = np.full((batch_size * self.beam_size, output_vocab_size),
                           fill_value=np.inf, ctx=self.context, dtype=self.dtype)
        eos_dist[:, C.EOS_ID] = 0
        unk_dist = None
        if self.prevent_unk:
            unk_dist = np.zeros_like(eos_dist)
            unk_dist[:, C.UNK_ID] = np.inf  # pylint: disable=E1137

        # Initialize the beam to track constraint sets, where target-side lexical constraints are present
        constraints = constrained.init_batch(raw_constraint_list, self.beam_size, self.bos_id, self.eos_id)

        if self.global_avoid_trie or any(raw_avoid_list):
            avoid_states = constrained.AvoidBatch(batch_size, self.beam_size,
                                                  avoid_list=raw_avoid_list,
                                                  global_avoid_trie=self.global_avoid_trie)
            avoid_states.consume(best_word_indices[:, 0])  # constraints operate only on primary target factor

        # (0) encode source sentence, returns a list
        model_states, estimated_reference_lengths = self._inference.encode_and_initialize(source, source_length)
        # repeat states to beam_size
        model_states = _repeat_states(model_states, self.beam_size, self._inference.state_structure())
        # repeat estimated_reference_lengths to shape (batch_size * beam_size, 1)
        estimated_reference_lengths = np.repeat(estimated_reference_lengths, self.beam_size, axis=0)

        # Records items in the beam that are inactive. At the beginning (t==1), there is only one valid or active
        # item on the beam for each sentence
        inactive = np.zeros((batch_size * self.beam_size, 1), dtype='int32', ctx=self.context)
        t = 1
        for t in range(1, max_iterations + 1):  # max_iterations + 1 required to get correct results
            # (1) obtain next predictions and advance models' state
            # target_dists: (batch_size * beam_size, target_vocab_size)
            target_dists, model_states, target_factors = self._inference.decode_step(best_word_indices,
                                                                                     model_states,
                                                                                     vocab_slice_ids)

            # (2) Produces the accumulated cost of target words in each row.
            # There is special treatment for finished and inactive rows: inactive rows are inf everywhere;
            # finished rows are inf everywhere except column zero, which holds the accumulated model score
            scores, lengths = self._update_scores(target_dists,
                                                  finished,
                                                  inactive,
                                                  scores_accumulated,
                                                  lengths,
                                                  max_output_lengths,
                                                  unk_dist,
                                                  pad_dist,
                                                  eos_dist)

            # Mark entries that should be blocked as having a score of np.inf
            if self.global_avoid_trie or any(raw_avoid_list):
                block_indices = avoid_states.avoid()
                if len(block_indices) > 0:
                    scores[block_indices] = np.inf
                    if self._sample is not None:
                        target_dists[block_indices] = np.inf

            # (3) Get beam_size winning hypotheses for each sentence block separately. Only look as
            # far as the active beam size for each sentence.
            if self._sample is not None:
                best_hyp_indices, best_word_indices, scores_accumulated = self._sample(scores,
                                                                                       target_dists,
                                                                                       finished,
                                                                                       sample_best_hyp_indices)
            else:
                # On the first timestep, all hypotheses have identical histories, so force topk() to choose extensions
                # of the first row only by setting all other rows to inf
                if t == 1:
                    scores += first_step_mask

                best_hyp_indices, best_word_indices, scores_accumulated = self._top(scores, offset)

            # Constraints for constrained decoding are processed sentence by sentence
            if any(raw_constraint_list):
                best_hyp_indices, best_word_indices, scores_accumulated, constraints, inactive = constrained.topk(
                    t,
                    batch_size,
                    self.beam_size,
                    inactive,
                    scores,
                    constraints,
                    best_hyp_indices,
                    best_word_indices,
                    scores_accumulated)

            # Map from restricted to full vocab ids if needed
            if restrict_lexicon:
                best_word_indices = np.take(vocab_slice_ids, best_word_indices, axis=0)

            # (4) Normalize the scores of newly finished hypotheses. Note that after this until the
            # next call to topk(), hypotheses may not be in sorted order.
            _sort_inputs = [best_hyp_indices, best_word_indices, finished, scores_accumulated, lengths,
                            estimated_reference_lengths]
            if target_factors is not None:
                _sort_inputs.append(target_factors)
            best_word_indices, finished, scores_accumulated, lengths, estimated_reference_lengths = \
                self._sort_norm_and_update_finished(*_sort_inputs)

            # Collect best hypotheses, best word indices
            best_word_indices_list.append(best_word_indices)
            best_hyp_indices_list.append(best_hyp_indices)

            if self._should_stop(finished, batch_size):
                break

            # (5) update models' state with winning hypotheses (ascending)
            model_states = self._sort_states(best_hyp_indices, *model_states)

        logger.debug("Finished after %d out of %d steps.", t, max_iterations)

        # (9) Sort the hypotheses within each sentence (normalization for finished hyps may have unsorted them).
        scores_accumulated_shape = scores_accumulated.shape
        folded_accumulated_scores = scores_accumulated.reshape((batch_size, -1))
        indices = np.argsort(folded_accumulated_scores.astype('float32', copy=False), axis=1).reshape((-1,))
        best_hyp_indices = np.unravel_index(indices, scores_accumulated_shape)[0].astype('int32') + offset
        scores_accumulated = scores_accumulated.take(best_hyp_indices, axis=0)
        best_hyp_indices_list.append(best_hyp_indices)
        lengths = lengths.take(best_hyp_indices, axis=0)
        all_best_hyp_indices = np.stack(best_hyp_indices_list, axis=1)
        all_best_word_indices = np.stack(best_word_indices_list, axis=2)
        constraints = [constraints[x] for x in best_hyp_indices.tolist()]

        return all_best_hyp_indices, \
               all_best_word_indices, \
               scores_accumulated, \
               lengths.astype('int32', copy=False), \
               estimated_reference_lengths, \
               constraints
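
The offset array in Example 9 is what lets the beam search address its flat (batch * beam) layout: per-sentence hypothesis indices in [0, beam_size) become row indices after adding each sentence's block offset. A sketch with small sizes (plain NumPy):

import numpy as np

batch_size, beam_size = 3, 2
offset = np.repeat(np.arange(0, batch_size * beam_size, beam_size, dtype='int32'),
                   beam_size)
print(offset)                                       # [0 0 2 2 4 4]

best_hyp_in_beam = np.array([1, 0, 0, 1, 1, 0], dtype='int32')  # winners within each beam
flat_rows = best_hyp_in_beam + offset               # rows of the (batch * beam, ...) arrays
print(flat_rows)                                    # [1 0 2 3 5 4]
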
Example #10
    def forward(self,
                source: np.ndarray,
                source_length: np.ndarray,
                restrict_lexicon: Optional[lexicon.TopKLexicon],
                raw_constraint_list: List[Optional[constrained.RawConstraintList]],
                raw_avoid_list: List[Optional[constrained.RawConstraintList]],
                max_output_lengths: np.ndarray) -> Tuple[np.ndarray,
                                                         np.ndarray,
                                                         np.ndarray,
                                                         np.ndarray,
                                                         List[Optional[np.ndarray]],
                                                         List[Optional[constrained.ConstrainedHypothesis]]]:
        """
        Translates a single sentence (batch_size=1) using greedy search.

        :param source: Source ids. Shape: (batch_size=1, bucket_key, num_factors).
        :param source_length: Valid source lengths. Shape: (batch_size=1,).
        :param restrict_lexicon: Lexicon to use for vocabulary restriction.
        :param raw_constraint_list: A list of optional lists containing phrases (as lists of target word IDs)
                that must appear in each output.
        :param raw_avoid_list: A list of optional lists containing phrases (as lists of target word IDs)
                that must NOT appear in each output.
        :param max_output_lengths: ndarray of maximum output lengths per input in source.
                Shape: (batch_size=1,). Dtype: int32.
        :return: List of best hypotheses indices, list of best word indices,
                array of accumulated length-normalized negative log-probs, hypotheses lengths,
                predicted lengths of references (if any), constraints (if any).
        """
        batch_size = source.shape[0]
        assert batch_size == 1, "Greedy Search does not support batch_size != 1"

        # Maximum search iterations (determined by longest input with eos)
        max_iterations = max_output_lengths.max().item()
        logger.debug("max greedy search iterations: %d", max_iterations)

        # best word index (also acts as input). Shape: (batch_size, num_target_factors)
        best_word_index = np.full((batch_size, self.num_target_factors),
                                  fill_value=self.bos_id, ctx=self.context, dtype='int32')
        outputs = []  # type: List[np.ndarray]

        vocab_slice_ids = None  # type: Optional[np.ndarray]
        # If using a top-k lexicon, select param rows for logit computation that correspond to the
        # target vocab for this sentence.
        if restrict_lexicon:
            source_words = np.squeeze(np.split(source, self.num_source_factors, axis=2)[0], axis=2)
            vocab_slice_ids, _, raw_constraint_list = _get_vocab_slice_ids(restrict_lexicon, source_words,
                                                                           raw_constraint_list,
                                                                           self.eos_id, beam_size=1)

        # (0) encode source sentence, returns a list
        model_states, _ = self._inference.encode_and_initialize(source, source_length)
        # TODO: check for disabled predicted output length

        t = 1
        for t in range(1, max_iterations + 1):
            scores, model_states, target_factors = self._inference.decode_step(best_word_index,
                                                                               model_states,
                                                                               vocab_slice_ids=vocab_slice_ids)
            # shape: (batch*beam=1, 1)
            best_word_index = self.work_block(scores, vocab_slice_ids, target_factors)
            outputs.append(best_word_index)

            if best_word_index == self.eos_id or best_word_index == C.PAD_ID:
                break

        logger.debug("Finished after %d out of %d steps.", t, max_iterations)

        # shape: (1, num_factors, length)
        stacked_outputs = np.stack(outputs, axis=2)
        length = np.array([t], dtype='int32')  # shape (1,)
        hyp_indices = np.zeros((1, t + 1), dtype='int32')
        score = np.array([-1.])  # TODO: return unnormalized proper score

        return hyp_indices, stacked_outputs, score, length, None, []  # type: ignore
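
Example 10 collects one (1, num_target_factors) prediction per step and stacks them on a new trailing axis to produce the (1, num_factors, length) output noted in the comment above. The stacking step in plain NumPy (sizes illustrative):

import numpy as np

num_factors, steps = 2, 4
outputs = [np.random.randint(0, 10, size=(1, num_factors)) for _ in range(steps)]

stacked = np.stack(outputs, axis=2)     # (1, num_factors) per step -> (1, num_factors, steps)
assert stacked.shape == (1, num_factors, steps)
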