Example #1
    def _grammar_step(self, logits, next_cell_states, decode_states, actions, gmr_mask):
        """跟进文法约束完成一步解码逻辑

        Args:
            logits (Variable): shape = [batch_size, beam_size, vocab_size]
            next_cell_states (Variable|structure): cell states output by this step's RNN cell
            decode_states (StateWrapper): decoding states carried over from the previous step
            actions (Variable): grammar actions for the current step
            gmr_mask (Variable): grammar mask restricting the legal next tokens

        Returns: (OutputWrapper, StateWrapper)

        Raises: NULL

        """
        # decode token logits that satisfy the grammar constraints
        logits, valid_table_mask = self._output_layer(logits, actions, gmr_mask, decode_states.valid_table_mask)

        # initialize vocab size
        self._vocab_size = logits.shape[-1]
        self._vocab_size_tensor = layers.fill_constant(shape=[1], dtype='int64', value=logits.shape[-1])

        # compute log probs and mask out the finished beams
        step_log_probs = layers.log(layers.softmax(logits))
        step_log_probs = self._mask_finished_probs(step_log_probs, decode_states.finished)

        scores = layers.reshape(step_log_probs, [-1, self._beam_size * self._vocab_size])
        topk_scores, topk_indices = layers.topk(input=scores, k=self._beam_size)
        topk_scores = layers.reshape(topk_scores, shape=[-1])
        topk_indices = layers.reshape(topk_indices, shape=[-1])

        # beam that each top-k entry comes from
        beam_indices = layers.elementwise_floordiv(topk_indices, self._vocab_size_tensor)
        # token id of each top-k entry
        token_indices = layers.elementwise_mod(topk_indices, self._vocab_size_tensor)

        # regather step_log_probs according to where each top-k entry came from
        next_log_probs = nn_utils.batch_gather(
                layers.reshape(step_log_probs, [-1, self._beam_size * self._vocab_size]),
                topk_indices)
        def _beam_gather(x, beam_indices):
            """reshape x to beam dim, and gather each beam_indices
            Args:
                x (TYPE): NULL
            Returns: Variable
            """
            x = self.split_batch_beams(x)
            return nn_utils.batch_gather(x, beam_indices)
        next_cell_states = layers.utils.map_structure(lambda x: _beam_gather(x, beam_indices),
                                                      next_cell_states)
        next_finished = _beam_gather(decode_states.finished, beam_indices)
        next_lens = _beam_gather(decode_states.lengths, beam_indices)

        next_lens = layers.elementwise_add(next_lens,
                layers.cast(layers.logical_not(next_finished), next_lens.dtype))
        next_finished = layers.logical_or(next_finished,
                layers.equal(token_indices, self._end_token_tensor))

        decode_output = OutputWrapper(topk_scores, token_indices, beam_indices)
        decode_states = StateWrapper(next_cell_states, next_log_probs, next_finished, next_lens, valid_table_mask)

        return decode_output, decode_states
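
The core of _grammar_step is scoring all beam_size * vocab_size continuations jointly and then recovering the source beam and token from each flattened top-k index. A minimal NumPy sketch of that index arithmetic (shapes and names are illustrative, not part of the original code):

import numpy as np

batch_size, beam_size, vocab_size = 2, 3, 5
step_log_probs = np.log(np.random.rand(batch_size, beam_size, vocab_size))

# flatten beams and tokens into one axis, then take the top beam_size entries
scores = step_log_probs.reshape(batch_size, beam_size * vocab_size)
topk_indices = np.argsort(-scores, axis=-1)[:, :beam_size]
topk_scores = np.take_along_axis(scores, topk_indices, axis=-1)

# floordiv recovers the source beam, mod recovers the token id,
# mirroring elementwise_floordiv / elementwise_mod above
beam_indices = topk_indices // vocab_size
token_indices = topk_indices % vocab_size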
Example #2
def _process_type_midd(condition, decoder, grammar_stack, next_inputs,
                       predicted_ids):
    """Process when output type is MID

    Args:
        condition (Variable): mask of batch items whose current output type is MID
        decoder (TYPE): decoder exposing grammar_desc/grammar_action/grammar_mask lookups
        grammar_stack (StackData): [in/out] stack of pending grammar symbols
        next_inputs (TYPE): [in/out] inputs for the next decoding step
        predicted_ids (Variable): grammar ids predicted at the current step

    Returns: None (results are written back in place via layers.assign)

    Raises: NULL
    """
    midd_pred_ids = fluider.elementwise_mul(predicted_ids,
                                            condition,
                                            axis=0,
                                            force=True)
    ## get grammar desc
    # the grammar-rule sequence for each decoded grammar id; e.g. if the decoded
    # result is SingleSQL, the corresponding sequence is Select Filter
    # shape = [batch_size, grammar.max_desc_len]
    gmr_desc = decoder.grammar_desc(midd_pred_ids)
    # length of each grammar-rule sequence; e.g. SingleSQL --> Select Filter has length 2
    # shape = [batch_size, 1]
    gmr_desc_lens = decoder.grammar_desc_lens(midd_pred_ids)
    # shape = [batch_size, 1]
    gmr_desc_pos = tensor.zeros_like(gmr_desc_lens)

    ## generate next grammar mask by first token in desc
    next_output = nn_utils.batch_gather(gmr_desc, gmr_desc_pos)
    next_actions = decoder.grammar_action(next_output)
    next_gmr_mask = decoder.grammar_mask(next_output)

    ## push left grammar tokens to stack
    gmr_stack_tmp, gmr_stack_pos_tmp = _push_to_stack(gmr_desc, gmr_desc_pos,
                                                      gmr_desc_lens,
                                                      grammar_stack)

    ## save results where condition is True
    new_gmr_stack, new_gmr_stack_pos, new_actions, new_gmr_mask = nn_utils.ifelse(
        condition,
        [gmr_stack_tmp, gmr_stack_pos_tmp, next_actions, next_gmr_mask], [
            grammar_stack.data, grammar_stack.pos, next_inputs.action,
            next_inputs.gmr_mask
        ])
    layers.utils.map_structure(
        layers.assign,
        [new_gmr_stack, new_gmr_stack_pos, new_actions, new_gmr_mask], [
            grammar_stack.data, grammar_stack.pos, next_inputs.action,
            next_inputs.gmr_mask
        ])
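
Conceptually, _process_type_midd expands a non-terminal: it looks up the rule body for the predicted grammar id, feeds the first symbol into the next step, and pushes the remaining symbols onto the stack so they are expanded later. A plain-Python sketch of that control flow (the rule table and names are made up for illustration):

# hypothetical rule table: grammar id -> rule body
GRAMMAR_DESC = {"SingleSQL": ["Select", "Filter"]}

def process_mid(pred_id, stack):
    desc = GRAMMAR_DESC[pred_id]      # e.g. ["Select", "Filter"]
    next_symbol = desc[0]             # drives the next decoding step
    stack.extend(reversed(desc[1:]))  # the rest is expanded later, LIFO
    return next_symbol

stack = []
assert process_mid("SingleSQL", stack) == "Select"
assert stack == ["Filter"]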
Example #3
    def pop(cls, stack_data, mask=True, in_place=True):
        """pop data in stack_data

        Args:
            stack_data (StackData): (data, pos) with shape ([batch_size, stack_len], [batch_size, 1])
            mask (bool): whether to mask the output of already-empty stacks. Defaults to True
            in_place (bool): whether to update stack_data in place. Defaults to True

        Returns: (Variable1, Variable2, StackData)
            Variable1: the popped values
                       dtype=stack_data.data.dtype
                       shape=[-1]
            Variable2: whether each position is valid; False for stacks that
                       were already empty on input.
                       dtype=bool
                       shape=[-1]
            StackData: the updated stack (stack_data itself when in_place is True)
        Raises: NULL
        """
        data = stack_data.data
        pos = stack_data.pos

        # only non-empty stacks may be popped (are valid)
        valid_pos = layers.logical_not(cls.empty(stack_data))
        new_pos_delta = layers.cast(valid_pos, dtype=pos.dtype)
        new_pos = layers.elementwise_sub(pos, new_pos_delta)

        # shape = [batch_size]
        output = nn_utils.batch_gather(data, new_pos)
        # mask the output of empty stacks
        if mask:
            # shape = [batch_size, 1]
            mask_tag = layers.cast(
                new_pos_delta,
                dtype=data.dtype) if data.dtype != pos.dtype else new_pos_delta
            mask_tag = layers.squeeze(mask_tag, [1])
            output = layers.elementwise_mul(output, mask_tag)

        # zero out the popped position
        updates = layers.zeros_like(output)
        new_data = nn_utils.batch_scatter(data,
                                          new_pos,
                                          updates,
                                          overwrite=True,
                                          in_place=in_place)

        if in_place:
            layers.assign(new_pos, pos)
            return output, valid_pos, stack_data
        else:
            return output, valid_pos, StackData(new_data, new_pos)
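
pop operates on a whole batch of stacks at once: the position pointer is decremented only where a stack is non-empty, the value at the new top is gathered (and masked for empty stacks), and the popped slot is zeroed. A NumPy sketch of those semantics, with pos flattened to one dimension for brevity:

import numpy as np

def batch_pop(data, pos):
    """data: [batch_size, stack_len]; pos: [batch_size], items per stack."""
    valid = pos > 0                          # only non-empty stacks may pop
    new_pos = pos - valid.astype(pos.dtype)  # step back one slot if non-empty
    idx = np.arange(data.shape[0])
    output = data[idx, new_pos] * valid      # mask the output of empty stacks
    data[idx, new_pos] = 0                   # clear the popped slot in place
    return output, valid, new_pos

data = np.array([[7, 8, 0],
                 [0, 0, 0]])
pos = np.array([2, 0])        # first stack holds 2 items, second is empty
print(batch_pop(data, pos))   # (array([8, 0]), array([ True, False]), array([1, 0]))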
Example #4
def _select_column(condition,
                   inputs,
                   column_enc,
                   column_len,
                   ptr_net,
                   grammar,
                   column2table_mask,
                   name=None):
    """select_column.

    Args:
        condition (Variable): mask of batch items that are currently selecting a column
        inputs (Variable): shape = [batch_size, max_len, hidden_size]. max_len is always 1 at infer time
        column_enc (Variable): column encodings
        column_len (Variable): number of columns per batch item
        ptr_net (TYPE): pointer network that scores columns against inputs
        grammar (TYPE): grammar config providing grammar_size, MAX_TABLE, MAX_COLUMN, MAX_VALUE
        column2table_mask (Variable): mask mapping each column id to its table
        name (str): name scope. Defaults to None

    Returns: (Variable, Variable)

    Raises: NULL
    """
    condition = layers.cast(condition, dtype='float32')

    column_mask = layers.sequence_mask(column_len,
                                       maxlen=grammar.MAX_COLUMN,
                                       dtype='float32')
    column_mask = layers.reshape(column_mask, [-1, grammar.MAX_COLUMN])
    predicts = ptr_net.forward(inputs, column_enc, column_mask)

    pred_ids = layers.argmax(predicts, axis=-1)
    valid_table_mask = nn_utils.batch_gather(column2table_mask, pred_ids)

    ## pad with -INF out to the full vocab size
    zeros_l = tensor.fill_constant_batch_size_like(
        predicts,
        shape=[-1, grammar.grammar_size + grammar.MAX_TABLE],
        dtype='float32',
        value=-INF)
    zeros_r = tensor.fill_constant_batch_size_like(
        predicts, shape=[-1, grammar.MAX_VALUE], dtype='float32', value=-INF)
    final_output = tensor.concat([zeros_l, predicts, zeros_r], axis=-1)
    true_final_output = layers.elementwise_mul(final_output, condition, axis=0)
    true_valid_table_mask = layers.elementwise_mul(valid_table_mask,
                                                   condition,
                                                   axis=0)
    return true_final_output, true_valid_table_mask
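
The final concat places the column scores inside one global output vocabulary laid out as [grammar rules | tables | columns | values]; the non-column blocks are filled with -INF so softmax assigns them zero probability. A NumPy sketch of that layout (block sizes are illustrative):

import numpy as np

INF = 1e9
grammar_size, max_table, max_column, max_value = 4, 3, 5, 2
column_logits = np.random.rand(2, max_column)   # [batch_size, MAX_COLUMN]

# pad both sides with -INF so only column positions can win
pad_l = np.full((2, grammar_size + max_table), -INF)
pad_r = np.full((2, max_value), -INF)
final_output = np.concatenate([pad_l, column_logits, pad_r], axis=-1)

# after softmax, all probability mass stays inside the column block
probs = np.exp(final_output) / np.exp(final_output).sum(-1, keepdims=True)
assert np.allclose(probs[:, :grammar_size + max_table], 0.0)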
Example #5
    def step_gmr_type(self, gmr_seq, gmr_pos):
        """get type of grammar on gmr_pos of gmr_seq

        Args:
            gmr_seq (Variable): grammar id sequence
            gmr_pos (Variable): position to look up in each sequence

        Returns: Variable

        Raises: NULL

        """
        gmr_id = nn_utils.batch_gather(gmr_seq, gmr_pos)
        output = self.grammar_type(gmr_id, False)
        return output
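
nn_utils.batch_gather, used throughout these examples, selects one element per batch row at a given position. Its effect can be reproduced with NumPy fancy indexing (a sketch of the semantics, not the actual implementation):

import numpy as np

def batch_gather(x, pos):
    """x: [batch_size, seq_len]; pos: [batch_size, 1] -> [batch_size]."""
    return x[np.arange(x.shape[0]), pos.reshape(-1)]

gmr_seq = np.array([[10, 11, 12],
                    [20, 21, 22]])
gmr_pos = np.array([[2], [0]])
print(batch_gather(gmr_seq, gmr_pos))  # [12 20]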
Example #6
def _push_to_stack(gmr_desc, gmr_pos, gmr_lens, gmr_stack_info):
    """push grammar id in gmr_desc from gmr_pos to gmr_lens to
    gmr_stack. and update step_gmr_pos

    Args:
        gmr_desc (Variable): grammar-rule sequences, shape = [batch_size, max_desc_len]
        gmr_pos (Variable): current position within each rule sequence
        gmr_lens (Variable): length of each rule sequence
        gmr_stack_info (tuple): [in/out] (gmr_stack, gmr_stack_pos)

    Returns: tuple (gmr_stack, gmr_stack_pos)

    Raises: NULL
    """
    gmr_stack, gmr_stack_pos = gmr_stack_info
    mv_step = layers.cast(layers.greater_than(gmr_lens,
                                              layers.zeros_like(gmr_lens)),
                          dtype=gmr_lens.dtype)
    gmr_mv_pos = layers.elementwise_sub(gmr_lens, mv_step)

    cond = layers.reduce_any(layers.greater_than(gmr_mv_pos, gmr_pos))
    while_op = layers.While(cond)
    with while_op.block():
        gmr_ids = nn_utils.batch_gather(gmr_desc, gmr_mv_pos)
        gmr_stack_tmp, gmr_stack_pos_tmp = data_structure.Stack.push(
            gmr_stack_info, gmr_ids, in_place=False)

        mv_cond = layers.greater_than(gmr_mv_pos, gmr_pos)
        gmr_mv_pos_tmp = fluider.elementwise_sub(gmr_mv_pos,
                                                 mv_cond,
                                                 force=True)
        new_gmr_stack, new_gmr_stack_pos = nn_utils.ifelse(
            mv_cond, [gmr_stack_tmp, gmr_stack_pos_tmp],
            [gmr_stack, gmr_stack_pos])
        layers.utils.map_structure(layers.assign,
                                   [new_gmr_stack, new_gmr_stack_pos],
                                   [gmr_stack, gmr_stack_pos])
        layers.assign(gmr_mv_pos_tmp, gmr_mv_pos)
        layers.assign(
            layers.reduce_any(layers.greater_than(gmr_mv_pos, gmr_pos)), cond)
    return gmr_stack, gmr_stack_pos
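
The While block walks each rule body from its last symbol back toward gmr_pos, pushing as it goes, so the symbols later pop off in left-to-right order. Per batch element the loop is equivalent to this plain-Python sketch (names are illustrative):

def push_desc(gmr_desc, gmr_pos, gmr_len, stack):
    """Push gmr_desc[gmr_pos + 1 : gmr_len] onto stack, last symbol first."""
    mv_pos = gmr_len - 1                # start at the last symbol
    while mv_pos > gmr_pos:
        stack.append(gmr_desc[mv_pos])  # push, then step toward gmr_pos
        mv_pos -= 1
    return stack

# SingleSQL -> Select Filter: "Select" is consumed now (position 0),
# "Filter" is pushed so it pops at the next step.
print(push_desc(["Select", "Filter"], 0, 2, []))  # ['Filter']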