Example #1
    def decode_forward_with_state(self, last_encoded, all_encoded, mask,
                                  input_data, state, new_seq):
        """Run one decoder step, initializing the state when new_seq is True."""
        if new_seq:
            # A new sentence: derive the initial hidden state from the
            # encoder's last step and zero the initial cell state.
            last_encoded.copyto(self.init_state_executor.arg_dict["encoded"])
            self.init_state_executor.forward()
            init_hs = self.init_state_executor.outputs[0]
            self.decode_executor.arg_dict["target_l0_init_c"][:] = 0.0
            state = LSTMState(
                c=self.decode_executor.arg_dict["target_l0_init_c"], h=init_hs)
            # The source annotations and mask are copied once per sentence
            # and reused for every subsequent decode step.
            all_encoded.copyto(self.decode_executor.arg_dict["attended"])
            mask.copyto(self.decode_executor.arg_dict["encoded_mask"])
        # Feed the previous target token and the carried-over LSTM state,
        # then run one decode step.
        input_data.copyto(self.decode_executor.arg_dict["target"])
        state.c.copyto(self.decode_executor.arg_dict["target_l0_init_c"])
        state.h.copyto(self.decode_executor.arg_dict["target_l0_init_h"])
        self.decode_executor.forward()

        # Output order matches the mx.sym.Group built by the decoder symbol:
        # softmax probabilities, then (c, h) per layer, then attention weights.
        prob = self.decode_executor.outputs[0]
        c = self.decode_executor.outputs[1]
        h = self.decode_executor.outputs[2]
        attention_weights = self.decode_executor.outputs[3]

        return prob, attention_weights, LSTMState(c=c, h=h)
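
For context, this is roughly how the method is driven during decoding. Below is a minimal greedy-decode sketch, assuming a `model` object that exposes the same `encode` method and executors as above, plus a `revert_vocab` id-to-token dict; these names and the batch-of-1 shapes are assumptions, not part of this module.

import mxnet as mx

def greedy_decode(model, input_ndarray, mask_ndarray, revert_vocab,
                  bos_id, eos_id, max_len=50):
    # Encode once, then feed the previous prediction back step by step.
    last_encoded, all_encoded = model.encode(input_ndarray, mask_ndarray)
    target = mx.nd.array([bos_id])  # previous-token feed, batch of 1
    state, words = None, []
    for t in range(max_len):
        prob, _, state = model.decode_forward_with_state(
            last_encoded, all_encoded, mask_ndarray, target, state, t == 0)
        next_id = int(mx.nd.argmax(prob, axis=1).asscalar())
        if next_id == eos_id:
            break
        words.append(revert_vocab.get(next_id, ''))
        target[:] = next_id
    return ' '.join(words)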
Example #2
import mxnet as mx


def lstm_attention_decode_symbol(t_num_lstm_layer, t_seq_len, t_vocab_size,
                                 t_num_hidden, t_num_embed, t_num_label,
                                 t_dropout, attention, source_seq_len):
    """Build the symbol graph for one step of an attention LSTM decoder."""
    data = mx.sym.Variable("target")
    seqidx = 0

    embed_weight = mx.sym.Variable("target_embed_weight")
    cls_weight = mx.sym.Variable("target_cls_weight")
    cls_bias = mx.sym.Variable("target_cls_bias")

    input_weight = mx.sym.Variable("target_input_weight")

    # One parameter cell and one (c, h) init-state pair per decoder layer.
    param_cells = []
    last_states = []

    for i in range(t_num_lstm_layer):
        param_cells.append(
            LSTMParam(i2h_weight=mx.sym.Variable("target_l%d_i2h_weight" % i),
                      i2h_bias=mx.sym.Variable("target_l%d_i2h_bias" % i),
                      h2h_weight=mx.sym.Variable("target_l%d_h2h_weight" % i),
                      h2h_bias=mx.sym.Variable("target_l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("target_l%d_init_c" % i),
                          h=mx.sym.Variable("target_l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == t_num_lstm_layer

    # Embed the previous target token.
    hidden = mx.sym.Embedding(data=data,
                              input_dim=t_vocab_size,
                              output_dim=t_num_embed,
                              weight=embed_weight,
                              name="target_embed")

    # Attend over the encoder annotations, conditioned on the first layer's
    # previous hidden state.
    all_encoded = mx.sym.Variable("attended")
    encoded = mx.sym.SliceChannel(data=all_encoded,
                                  axis=1,
                                  num_outputs=source_seq_len)
    weights, weighted_encoded = attention.attend(attended=encoded,
                                                 concat_attended=all_encoded,
                                                 state=last_states[0].h,
                                                 attend_masks=None,
                                                 use_masking=False)
    # Fuse the attention context with the token embedding before the LSTM.
    con = mx.sym.Concat(hidden, weighted_encoded)
    hidden = mx.sym.FullyConnected(data=con,
                                   num_hidden=t_num_embed,
                                   weight=input_weight,
                                   no_bias=True,
                                   name='input_fc')

    # Run the stacked LSTM for one time step; dropout only between layers.
    for i in range(t_num_lstm_layer):
        dp = 0. if i == 0 else t_dropout
        next_state = lstm(t_num_hidden,
                          indata=hidden,
                          prev_state=last_states[i],
                          param=param_cells[i],
                          seqidx=seqidx,
                          layeridx=i,
                          dropout=dp)
        hidden = next_state.h
        last_states[i] = next_state

    fc = mx.sym.FullyConnected(data=hidden,
                               num_hidden=t_num_label,
                               weight=cls_weight,
                               bias=cls_bias,
                               name='target_pred')
    sm = mx.sym.SoftmaxOutput(data=fc, name='target_softmax')
    # Group everything the inference driver reads back: probabilities first,
    # then each layer's (c, h), and the attention weights last.
    output = [sm]
    for state in last_states:
        output.append(state.c)
        output.append(state.h)
    output.append(weights)
    return mx.sym.Group(output)
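
The order of mx.sym.Group above is the contract that decode_forward_with_state in Example #1 reads back by index. A minimal binding sketch follows, assuming a one-layer decoder and illustrative shapes (hidden size 512, 40 source steps, annotations of width 1024); the attention object and the lstm() helper come from the surrounding module.

import mxnet as mx

# `attention` is assumed to be the module's attention object.
sym = lstm_attention_decode_symbol(t_num_lstm_layer=1, t_seq_len=1,
                                   t_vocab_size=30000, t_num_hidden=512,
                                   t_num_embed=256, t_num_label=30000,
                                   t_dropout=0.0, attention=attention,
                                   source_seq_len=40)
# simple_bind infers the parameter shapes from the input shapes given here.
decode_executor = sym.simple_bind(ctx=mx.gpu(0),
                                  target=(1,),
                                  attended=(1, 40, 1024),
                                  target_l0_init_c=(1, 512),
                                  target_l0_init_h=(1, 512))
# outputs[0]: softmax; outputs[1 + 2*i], outputs[2 + 2*i]: layer i's (c, h);
# outputs[-1]: attention weights.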
Example #3
import mxnet as mx
import numpy as np


def translate_one_with_beam(max_decode_len, sentence, model_buckets,
                            unroll_len, source_vocab, target_vocab,
                            revert_vocab, target_ndarray, beam_size,
                            eos_index):
    input_length = len(sentence)
    cur_model = get_bucket_model(model_buckets, input_length)
    # The source sentence and its mask are tiled across the beam dimension.
    input_ndarray = mx.nd.zeros((beam_size, unroll_len))
    mask_ndarray = mx.nd.zeros((beam_size, unroll_len))

    # Every hypothesis starts from '<s>' with zero accumulated cost.
    beam = [[
        BeamNode(father=-1,
                 content='<s>',
                 score=0.0,
                 acc_score=0.0,
                 finish=False,
                 finishLen=0) for i in range(beam_size)
    ]]
    beam_state = [None]

    MakeInput_beam(sentence, source_vocab, unroll_len, input_ndarray,
                   mask_ndarray, beam_size)
    last_encoded, all_encoded = cur_model.encode(
        input_ndarray,
        mask_ndarray)  # last_encoded is the encoder's final-step hidden state
    for i in range(max_decode_len):
        # One decode step for the whole beam; i == 0 triggers state init.
        MakeTargetInput_beam(beam[-1], target_vocab, target_ndarray)
        prob, attention_weights, new_state = cur_model.decode_forward_with_state(
            last_encoded, all_encoded, mask_ndarray, target_ndarray,
            beam_state[-1], i == 0)
        # Costs are negative log-probabilities, so smaller is better.
        log_prob = -mx.nd.log(prob)
        finished_beam = [t for t, x in enumerate(beam[-1]) if x.finish]
        for idx in range(beam_size):
            if not beam[-1][idx].finish:
                # Length normalization: acc_score is a running average of
                # per-word costs over the finishLen words generated so far.
                log_prob[idx] = (log_prob[idx] + beam[-1][idx].acc_score *
                                 beam[-1][idx].finishLen) / (
                                     beam[-1][idx].finishLen + 1)
            else:
                # Finished hypotheses keep their score unchanged.
                log_prob[idx] = beam[-1][idx].acc_score
        # A finished hypothesis may only be extended with '</s>'.
        for idx in finished_beam:
            log_prob[idx][:eos_index] = np.inf
            log_prob[idx][eos_index + 1:] = np.inf

        # _smallest returns, for the beam_size lowest costs, each chosen
        # hypothesis's parent row (indexes) and word id (outputs).
        (indexes, outputs), chosen_costs = _smallest(log_prob.asnumpy(),
                                                     beam_size,
                                                     only_first_row=(i == 0))
        next_chars = [revert_vocab.get(idx, '') for idx in outputs]

        # Reorder the LSTM states so row idx holds its chosen parent's state.
        next_state_h = mx.nd.empty(new_state.h.shape, ctx=mx.gpu(0))
        next_state_c = mx.nd.empty(new_state.c.shape, ctx=mx.gpu(0))
        for idx in range(beam_size):
            next_state_h[idx] = new_state.h[int(indexes[idx])]
            next_state_c[idx] = new_state.c[int(indexes[idx])]
        next_state = LSTMState(c=next_state_c, h=next_state_h)
        beam_state.append(next_state)

        # Extend each hypothesis; finished ones carry their node forward
        # unchanged (same content, same score).
        next_beam = [
            BeamNode(
                father=indexes[idx],
                content=next_chars[idx] if not beam[-1][indexes[idx]].finish
                else beam[-1][indexes[idx]].content,
                score=chosen_costs[idx] - beam[-1][indexes[idx]].acc_score,
                acc_score=chosen_costs[idx],
                finish=(next_chars[idx] == '</s>'
                        or beam[-1][indexes[idx]].finish),
                finishLen=(beam[-1][indexes[idx]].finishLen
                           if beam[-1][indexes[idx]].finish else
                           (beam[-1][indexes[idx]].finishLen + 1)))
            for idx in range(beam_size)
        ]
        beam.append(next_beam)
        # Stop early once every hypothesis has produced '</s>'.
        if all(node.finish for node in beam[-1]):
            break
    # Backtrace each hypothesis from the final beam row via the father links;
    # row 0 holds '<s>', and repeated '</s>' from finished rows is filtered.
    all_result = []
    all_score = []
    for beam_idx in range(beam_size):
        ptr = beam_idx
        result = []
        for idx in range(len(beam) - 1, 0, -1):
            word = beam[idx][ptr].content
            if word != '</s>':
                result.append(word)
            ptr = beam[idx][ptr].father
        result = result[::-1]
        all_result.append(' '.join(result))
        all_score.append(beam[-1][beam_idx].acc_score)

    return all_result, all_score
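
A typical call site ranks the returned hypotheses by accumulated cost (lower is better, since scores are averaged negative log-probabilities). The sketch below assumes the vocabularies and bucketed models were loaded by the surrounding code, and that target_ndarray holds one previous-token id per beam row.

import mxnet as mx

beam_size = 5
target_ndarray = mx.nd.zeros((beam_size,))  # assumed shape: one id per beam
results, scores = translate_one_with_beam(
    max_decode_len=60, sentence=sentence, model_buckets=model_buckets,
    unroll_len=40, source_vocab=source_vocab, target_vocab=target_vocab,
    revert_vocab=revert_vocab, target_ndarray=target_ndarray,
    beam_size=beam_size, eos_index=target_vocab['</s>'])
best = min(range(beam_size), key=lambda k: scores[k])
print(results[best], scores[best])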