Ejemplo n.º 1
0
    def get_metrics(self, inputs, outputs):
        """Assemble the loss and metric variables.

        Args:
            inputs: mapping of feed variables; reads "tgt_idx", "tgt_label",
                "label_idx" and "label".
            outputs: mapping of model outputs; reads "enc_out".

        Returns:
            dict with "loss", "nsp_loss" and "nsp_acc"; when ``self.use_mlm``
            is set the masked token loss is added to "loss" and also
            reported under "token_lm_loss".
        """
        encoder_out = outputs["enc_out"]
        target_labels = inputs["tgt_label"]
        sentence_labels = inputs["label"]

        # Token-level loss, averaged over positions whose label is not 1.
        token_logits = self._calc_logits(encoder_out, inputs["tgt_idx"])
        token_loss = layers.softmax_with_cross_entropy(
            logits=token_logits, label=target_labels)
        skipped_label = layers.fill_constant(shape=[1], dtype="int64", value=1)
        token_mask = layers.not_equal(target_labels, skipped_label)
        token_mask = layers.cast(token_mask, self.dtype)
        masked_loss_sum = layers.reduce_sum(token_loss * token_mask)
        # The epsilon guards against a batch with no position to score.
        kept_token_count = layers.reduce_sum(token_mask) + 1e-10
        mean_lm_loss = masked_loss_sum / kept_token_count

        # Sentence-level classification loss and accuracy.
        pooled = self._get_pooled_output(encoder_out, inputs["label_idx"])
        sentence_logits = self._get_classifier_output(pooled, name="next_sent")
        sentence_loss, sentence_probs = layers.softmax_with_cross_entropy(
            logits=sentence_logits, label=sentence_labels, return_softmax=True)
        nsp_acc = layers.accuracy(sentence_probs, sentence_labels)
        mean_nsp_loss = layers.mean(sentence_loss)

        metrics = {}
        total_loss = mean_nsp_loss
        if self.use_mlm:
            total_loss = total_loss + mean_lm_loss
            metrics["token_lm_loss"] = mean_lm_loss
        metrics.update(
            loss=total_loss, nsp_loss=mean_nsp_loss, nsp_acc=nsp_acc)
        return metrics
Ejemplo n.º 2
0
    def _get_metrics(self, inputs, outputs):
        """Build the training loss and metric variables.

        Args:
            inputs: mapping of feed variables; reads "tgt_pos", "tgt_label",
                "label_pos" and "label".
            outputs: mapping of model outputs; reads "enc_out".

        Returns:
            dict with "loss" (token LM loss plus next-sentence loss),
            "lm_loss", "nsp_loss" and "nsp_acc".
        """
        metrics = {}
        fc_out = self._calc_logits(enc_out=outputs["enc_out"],
                                   seq_pos=inputs["tgt_pos"])
        # The cross entropy is taken against the target token ids. "tgt_pos"
        # is only the position argument handed to _calc_logits above; using it
        # as the label would not line up with the mask built from "tgt_label".
        lm_loss = layers.softmax_with_cross_entropy(logits=fc_out,
                                                    label=inputs["tgt_label"])
        # Positions whose label equals 1 are excluded from the average
        # (presumably the padding id -- confirm against the data reader).
        need_cal = layers.not_equal(
            inputs["tgt_label"],
            layers.fill_constant(shape=[1], dtype="int64", value=1))
        need_cal = layers.cast(need_cal, self.dtype)
        # The epsilon keeps the division defined when nothing is selected.
        mean_lm_loss = layers.reduce_sum(
            lm_loss * need_cal) / (layers.reduce_sum(need_cal) + 1e-10)

        pooled_out = self._get_pooled_output(outputs["enc_out"],
                                             inputs["label_pos"])
        # Two-way classifier head with explicitly named parameters.
        nsp_fc_out = layers.fc(input=pooled_out,
                               size=2,
                               param_attr=fluid.ParamAttr(
                                   name="next_sent_fc.w_0",
                                   initializer=self.param_initializer),
                               bias_attr="next_sent_fc.b_0")
        nsp_loss, nsp_softmax = layers.softmax_with_cross_entropy(
            logits=nsp_fc_out, label=inputs["label"], return_softmax=True)

        nsp_acc = layers.accuracy(nsp_softmax, inputs["label"])
        mean_nsp_loss = layers.mean(nsp_loss)

        metrics["loss"] = mean_lm_loss + mean_nsp_loss
        metrics["lm_loss"] = mean_lm_loss
        metrics["nsp_loss"] = mean_nsp_loss
        metrics["nsp_acc"] = nsp_acc
        return metrics
Ejemplo n.º 3
0
    def _build_decoder(self,
                       enc_last_hidden,
                       enc_last_cell,
                       mode='train',
                       beam_size=10):
        """Build the decoder part of the graph for the requested mode.

        Args:
            enc_last_hidden: final encoder hidden state(s). Passed straight to
                ``basic_lstm`` in 'train' mode and indexed per layer
                (``enc_last_hidden[i]``) in the search modes.
            enc_last_cell: final encoder cell state(s), used the same way.
            mode: 'train', 'beam_search' or 'greedy_search'.
            beam_size: number of hypotheses kept per step; only read in
                'beam_search' mode.

        Returns:
            * 'train': the LSTM output multiplied by ``softmax_weight``
              (last dimension ``self.tar_vocab_size``).
            * 'beam_search': ``final_seq``, the ``beam_size`` rows with the
              highest score among the alive and finished sequences.
            * 'greedy_search': ``full_ids``, the concatenated chosen ids.
            * any other mode: a message is printed and None is returned
              implicitly.

        NOTE(review): ``INF`` and ``alpha`` are read here but not defined in
        this method; presumably module-level names -- confirm they exist.
        NOTE(review): the search branches build their initial tensors with a
        leading dimension of 1 or ``beam_size``, which looks like it assumes a
        single source sentence per run -- confirm with the caller.
        """
        # Output projection shared by every mode: hidden_size -> tar_vocab_size.
        softmax_weight = layers.create_parameter([self.hidden_size, self.tar_vocab_size], dtype="float32", name="softmax_weight", \
                    default_initializer=fluid.initializer.UniformInitializer(low=-self.init_scale, high=self.init_scale))
        if mode == 'train':
            # Run the LSTM over the target embedding, starting from the
            # encoder's final states.
            dec_output, dec_last_hidden, dec_last_cell = basic_lstm( self.tar_emb, enc_last_hidden, enc_last_cell, \
                    self.hidden_size, num_layers=self.num_layers, \
                    batch_first=self.batch_first, \
                    dropout_prob=self.dropout, \
                    param_attr = ParamAttr( initializer=fluid.initializer.UniformInitializer(low=-self.init_scale, high=self.init_scale) ), \
                    bias_attr = ParamAttr( initializer = fluid.initializer.Constant(0.0) ))

            dec_output = layers.matmul(dec_output, softmax_weight)

            return dec_output
        elif mode == 'beam_search' or mode == 'greedy_search':
            # One LSTM unit per layer, named "basic_lstm_layers_<i>".
            # NOTE(review): presumably these names are meant to line up with
            # the parameters created by basic_lstm in 'train' mode -- verify.
            dec_unit_list = []
            name = 'basic_lstm'
            for i in range(self.num_layers):
                new_name = name + "_layers_" + str(i)
                dec_unit_list.append(
                    BasicLSTMUnit(new_name, self.hidden_size, dtype='float32'))

            def decoder_step(current_in, pre_hidden_array, pre_cell_array):
                """Run one time step through the stacked LSTM units.

                Each layer receives the previous layer's new hidden state as
                input. Returns the top layer's hidden state together with the
                per-layer lists of new hidden and cell states.
                """
                new_hidden_array = []
                new_cell_array = []

                step_in = current_in
                for i in range(self.num_layers):
                    pre_hidden = pre_hidden_array[i]
                    pre_cell = pre_cell_array[i]

                    new_hidden, new_cell = dec_unit_list[i](step_in,
                                                            pre_hidden,
                                                            pre_cell)

                    new_hidden_array.append(new_hidden)
                    new_cell_array.append(new_cell)

                    step_in = new_hidden

                return step_in, new_hidden_array, new_cell_array

            if mode == 'beam_search':
                # The loop is capped at twice the source length (second
                # dimension of self.src).
                max_src_seq_len = layers.shape(self.src)[1]
                max_length = max_src_seq_len * 2
                #max_length = layers.fill_constant( [1], dtype='int32', value = 10)
                # NOTE(review): pre_ids, full_ids and score are not read again
                # in this branch.
                pre_ids = layers.fill_constant([1, 1], dtype='int64', value=1)
                full_ids = layers.fill_constant([1, 1], dtype='int64', value=1)

                score = layers.fill_constant([1], dtype='float32', value=0.0)

                #eos_ids = layers.fill_constant( [1, 1], dtype='int64', value=2)

                pre_hidden_array = []
                pre_cell_array = []
                # pre_feed is only written (assigned inside the loop); nothing
                # visible here reads it.
                pre_feed = layers.fill_constant([beam_size, self.hidden_size],
                                                dtype='float32',
                                                value=0)
                # Repeat the encoder's final states beam_size times along the
                # first dimension so every beam starts from the same state.
                for i in range(self.num_layers):
                    pre_hidden_array.append(
                        layers.expand(enc_last_hidden[i], [beam_size, 1]))
                    pre_cell_array.append(
                        layers.expand(enc_last_cell[i], [beam_size, 1]))

                # Id compared against generated tokens to mark a hypothesis
                # as finished.
                eos_ids = layers.fill_constant([beam_size],
                                               dtype='int64',
                                               value=2)
                # Only the first beam slot starts at score 0; the rest start
                # at -INF (presumably so the identical initial beams do not
                # yield duplicate hypotheses -- confirm).
                init_score = np.zeros((beam_size)).astype('float32')
                init_score[1:] = -INF
                pre_score = layers.assign(init_score)
                #pre_score = layers.fill_constant( [1,], dtype='float32', value= 0.0)
                # Alive token history, one row per beam, seeded with id 1.
                tokens = layers.fill_constant([beam_size, 1],
                                              dtype='int64',
                                              value=1)

                # NOTE(review): enc_memory is not read again in this method.
                enc_memory = layers.expand(self.enc_output, [beam_size, 1, 1])

                # Last generated id per beam; fed to the embedding each step.
                pre_tokens = layers.fill_constant([beam_size, 1],
                                                  dtype='int64',
                                                  value=1)

                # Book-keeping for finished hypotheses: sequences, scores
                # (initially -INF) and a 0/1 "is finished" flag per slot.
                finished_seq = layers.fill_constant([beam_size, 1],
                                                    dtype='int64',
                                                    value=0)
                finished_scores = layers.fill_constant([beam_size],
                                                       dtype='float32',
                                                       value=-INF)
                finished_flag = layers.fill_constant([beam_size],
                                                     dtype='float32',
                                                     value=0.0)

                step_idx = layers.fill_constant(shape=[1],
                                                dtype='int32',
                                                value=0)
                cond = layers.less_than(x=step_idx,
                                        y=max_length)  # default force_cpu=True

                parent_idx = layers.fill_constant([1], dtype='int32', value=0)
                while_op = layers.While(cond)

                def compute_topk_scores_and_seq(sequences,
                                                scores,
                                                scores_to_gather,
                                                flags,
                                                beam_size,
                                                select_beam=None,
                                                generate_id=None):
                    """Keep the ``beam_size`` best entries according to ``scores``.

                    ``scores`` is flattened to one row, the top ``beam_size``
                    indices are taken, and ``sequences``, ``flags`` and
                    ``scores_to_gather`` are gathered at those indices.
                    ``select_beam`` and ``generate_id`` are gathered too when
                    given, otherwise passed through unchanged.

                    Returns (top_seq, topk_gather_scores, topk_flags,
                    topk_beam, topk_id).
                    """
                    scores = layers.reshape(scores, shape=[1, -1])
                    _, topk_indexs = layers.topk(scores, k=beam_size)

                    topk_indexs = layers.reshape(topk_indexs, shape=[-1])

                    # gather result

                    top_seq = layers.gather(sequences, topk_indexs)
                    topk_flags = layers.gather(flags, topk_indexs)
                    topk_gather_scores = layers.gather(scores_to_gather,
                                                       topk_indexs)

                    # NOTE(review): these are truthiness tests on graph
                    # variables; `is not None` looks like the intent -- confirm.
                    if select_beam:
                        topk_beam = layers.gather(select_beam, topk_indexs)
                    else:
                        topk_beam = select_beam

                    if generate_id:
                        topk_id = layers.gather(generate_id, topk_indexs)
                    else:
                        topk_id = generate_id
                    return top_seq, topk_gather_scores, topk_flags, topk_beam, topk_id

                def grow_alive(curr_seq, curr_scores, curr_log_probs,
                               curr_finished, select_beam, generate_id):
                    """Select the best still-unfinished candidates.

                    Finished candidates get ``-INF`` added to their score so
                    the top-k selection skips them; the gathered scores come
                    from ``curr_log_probs``.
                    """
                    curr_scores += curr_finished * -INF
                    return compute_topk_scores_and_seq(curr_seq,
                                                       curr_scores,
                                                       curr_log_probs,
                                                       curr_finished,
                                                       beam_size,
                                                       select_beam,
                                                       generate_id=generate_id)

                def grow_finished(finished_seq, finished_scores, finished_flag,
                                  curr_seq, curr_scores, curr_finished):
                    """Merge newly finished candidates into the finished set.

                    The stored finished sequences get one extra column of 1s
                    (presumably so their width matches ``curr_seq`` for the
                    row-wise concat -- confirm). Unfinished candidates get
                    ``-INF`` added to their score, then old and new entries
                    are concatenated and the top ``beam_size`` are kept.
                    """
                    finished_seq = layers.concat([
                        finished_seq,
                        layers.fill_constant(
                            [beam_size, 1], dtype='int64', value=1)
                    ],
                                                 axis=1)
                    curr_scores += (1.0 - curr_finished) * -INF
                    #layers.Print( curr_scores, message="curr scores")
                    curr_finished_seq = layers.concat([finished_seq, curr_seq],
                                                      axis=0)
                    curr_finished_scores = layers.concat(
                        [finished_scores, curr_scores], axis=0)
                    curr_finished_flags = layers.concat(
                        [finished_flag, curr_finished], axis=0)

                    return compute_topk_scores_and_seq(curr_finished_seq,
                                                       curr_finished_scores,
                                                       curr_finished_scores,
                                                       curr_finished_flags,
                                                       beam_size)

                def is_finished(alive_log_prob, finished_scores,
                                finished_in_finished):
                    """Return the "keep decoding" condition for the loop.

                    Despite the name, the returned variable is true while the
                    number of finished hypotheses is still below
                    ``beam_size``.

                    NOTE(review): the score-bound check (``bound_is_met``) is
                    computed but never used in the returned condition.
                    NOTE(review): ``alpha`` is not defined in this scope.
                    """

                    max_out_len = 200
                    max_length_penalty = layers.pow(
                        layers.fill_constant([1],
                                             dtype='float32',
                                             value=((5.0 + max_out_len) /
                                                    6.0)), alpha)

                    lower_bound_alive_score = layers.slice(
                        alive_log_prob, starts=[0], ends=[1],
                        axes=[0]) / max_length_penalty

                    # Smallest score among slots flagged as finished; slots
                    # not flagged get -INF added before the minimum is taken.
                    lowest_score_of_fininshed_in_finished = finished_scores * finished_in_finished
                    lowest_score_of_fininshed_in_finished += (
                        1.0 - finished_in_finished) * -INF
                    lowest_score_of_fininshed_in_finished = layers.reduce_min(
                        lowest_score_of_fininshed_in_finished)

                    met = layers.less_than(
                        lower_bound_alive_score,
                        lowest_score_of_fininshed_in_finished)
                    met = layers.cast(met, 'float32')
                    bound_is_met = layers.reduce_sum(met)

                    finished_eos_num = layers.reduce_sum(finished_in_finished)

                    finish_cond = layers.less_than(
                        finished_eos_num,
                        layers.fill_constant([1],
                                             dtype='float32',
                                             value=beam_size))

                    return finish_cond

                def grow_top_k(step_idx, alive_seq, alive_log_prob,
                               parant_idx):
                    """Expand every beam by one token and keep the best candidates.

                    Embeds ``alive_seq``, runs one decoder step using the
                    enclosing ``pre_hidden_array`` / ``pre_cell_array``,
                    adds the log-probabilities to ``alive_log_prob``, applies
                    the length penalty ``((step_idx + 6) / 6) ** alpha`` and
                    takes the top ``beam_size`` over the flattened scores.

                    Returns (full_tokens_list, topk_log_probs, topk_scores,
                    topk_finished, selected_beam, generate_id, dec_att_out,
                    new_hidden_array, new_cell_array).

                    NOTE(review): ``parant_idx`` is unused, and ``alpha`` is
                    not defined in this scope.
                    """
                    pre_ids = alive_seq

                    dec_step_emb = layers.embedding(
                        input=pre_ids,
                        size=[self.tar_vocab_size, self.hidden_size],
                        dtype='float32',
                        is_sparse=False,
                        param_attr=fluid.ParamAttr(
                            name='target_embedding',
                            initializer=fluid.initializer.UniformInitializer(
                                low=-self.init_scale, high=self.init_scale)))

                    dec_att_out, new_hidden_array, new_cell_array = decoder_step(
                        dec_step_emb, pre_hidden_array, pre_cell_array)

                    projection = layers.matmul(dec_att_out, softmax_weight)

                    # Despite the name, `logits` holds softmax probabilities;
                    # the log is taken on the next line.
                    logits = layers.softmax(projection)
                    current_log = layers.elementwise_add(x=layers.log(logits),
                                                         y=alive_log_prob,
                                                         axis=0)
                    base_1 = layers.cast(step_idx, 'float32') + 6.0
                    base_1 /= 6.0
                    length_penalty = layers.pow(base_1, alpha)

                    # NOTE(review): len_pen is computed but never used.
                    len_pen = layers.pow(
                        ((5. + layers.cast(step_idx + 1, 'float32')) / 6.),
                        alpha)

                    # Flatten all (beam, token) candidates into a single row
                    # so one top-k ranks them jointly.
                    current_log = layers.reshape(current_log, shape=[1, -1])

                    current_log = current_log / length_penalty
                    topk_scores, topk_indices = layers.topk(input=current_log,
                                                            k=beam_size)

                    topk_scores = layers.reshape(topk_scores, shape=[-1])

                    # Undo the length penalty to recover the raw log-probs.
                    topk_log_probs = topk_scores * length_penalty

                    # A flat index decomposes into the token id (remainder)
                    # and the source beam (quotient) w.r.t. the vocab size.
                    generate_id = layers.reshape(
                        topk_indices, shape=[-1]) % self.tar_vocab_size

                    selected_beam = layers.reshape(
                        topk_indices, shape=[-1]) // self.tar_vocab_size

                    topk_finished = layers.equal(generate_id, eos_ids)

                    topk_finished = layers.cast(topk_finished, 'float32')

                    generate_id = layers.reshape(generate_id, shape=[-1, 1])

                    # Append the new token to the history of its source beam
                    # (`tokens` comes from the enclosing scope).
                    pre_tokens_list = layers.gather(tokens, selected_beam)

                    full_tokens_list = layers.concat(
                        [pre_tokens_list, generate_id], axis=1)


                    return full_tokens_list, topk_log_probs, topk_scores, topk_finished, selected_beam, generate_id, \
                            dec_att_out, new_hidden_array, new_cell_array

                with while_op.block():
                    # 1) Expand beams, 2) pick the alive set, 3) update the
                    # finished set, 4) write the results back into the loop
                    # state variables.
                    topk_seq, topk_log_probs, topk_scores, topk_finished, topk_beam, topk_generate_id, attention_out, new_hidden_array, new_cell_array = \
                        grow_top_k(  step_idx, pre_tokens, pre_score, parent_idx)
                    alive_seq, alive_log_prob, _, alive_beam, alive_id = grow_alive(
                        topk_seq, topk_scores, topk_log_probs, topk_finished,
                        topk_beam, topk_generate_id)

                    finished_seq_2, finished_scores_2, finished_flags_2, _, _ = grow_finished(
                        finished_seq, finished_scores, finished_flag, topk_seq,
                        topk_scores, topk_finished)

                    finished_cond = is_finished(alive_log_prob,
                                                finished_scores_2,
                                                finished_flags_2)

                    layers.increment(x=step_idx, value=1.0, in_place=True)

                    layers.assign(alive_beam, parent_idx)
                    layers.assign(alive_id, pre_tokens)
                    layers.assign(alive_log_prob, pre_score)
                    layers.assign(alive_seq, tokens)
                    layers.assign(finished_seq_2, finished_seq)
                    layers.assign(finished_scores_2, finished_scores)
                    layers.assign(finished_flags_2, finished_flag)

                    # update init_hidden, init_cell, input_feed
                    # States are gathered by parent_idx (just assigned from
                    # alive_beam) so they follow the surviving beams.
                    new_feed = layers.gather(attention_out, parent_idx)
                    layers.assign(new_feed, pre_feed)
                    for i in range(self.num_layers):
                        new_hidden_var = layers.gather(new_hidden_array[i],
                                                       parent_idx)
                        layers.assign(new_hidden_var, pre_hidden_array[i])
                        new_cell_var = layers.gather(new_cell_array[i],
                                                     parent_idx)
                        layers.assign(new_cell_var, pre_cell_array[i])

                    # Continue only while under the length cap and while
                    # is_finished() still reports fewer than beam_size
                    # finished hypotheses.
                    length_cond = layers.less_than(x=step_idx, y=max_length)
                    layers.logical_and(x=length_cond,
                                       y=finished_cond,
                                       out=cond)

                tokens_with_eos = tokens

                # Rank the alive and finished hypotheses together and keep
                # the beam_size best.
                all_seq = layers.concat([tokens_with_eos, finished_seq],
                                        axis=0)
                all_score = layers.concat([pre_score, finished_scores], axis=0)
                _, topk_index = layers.topk(all_score, k=beam_size)
                topk_index = layers.reshape(topk_index, shape=[-1])
                final_seq = layers.gather(all_seq, topk_index)
                # NOTE(review): final_score is computed but not returned.
                final_score = layers.gather(all_score, topk_index)

                return final_seq
            elif mode == 'greedy_search':
                # Same length cap as beam search: twice the source length.
                max_src_seq_len = layers.shape(self.src)[1]
                max_length = max_src_seq_len * 2
                #max_length = layers.fill_constant( [1], dtype='int32', value = 10)
                # pre_ids: id fed to the embedding at the next step;
                # full_ids: every id chosen so far. Both are seeded with 1.
                pre_ids = layers.fill_constant([1, 1], dtype='int64', value=1)
                full_ids = layers.fill_constant([1, 1], dtype='int64', value=1)

                score = layers.fill_constant([1], dtype='float32', value=0.0)

                eos_ids = layers.fill_constant([1, 1], dtype='int64', value=2)

                pre_hidden_array = []
                pre_cell_array = []
                # pre_feed is only written inside the loop; nothing visible
                # here reads it.
                pre_feed = layers.fill_constant([1, self.hidden_size],
                                                dtype='float32',
                                                value=0)
                for i in range(self.num_layers):
                    pre_hidden_array.append(enc_last_hidden[i])
                    pre_cell_array.append(enc_last_cell[i])
                    #pre_hidden_array.append( layers.fill_constant( [1, hidden_size], dtype='float32', value=0)  )
                    #pre_cell_array.append( layers.fill_constant( [1, hidden_size], dtype='float32', value=0) )

                step_idx = layers.fill_constant(shape=[1],
                                                dtype='int32',
                                                value=0)
                cond = layers.less_than(x=step_idx,
                                        y=max_length)  # default force_cpu=True
                while_op = layers.While(cond)

                with while_op.block():

                    dec_step_emb = layers.embedding(
                        input=pre_ids,
                        size=[self.tar_vocab_size, self.hidden_size],
                        dtype='float32',
                        is_sparse=False,
                        param_attr=fluid.ParamAttr(
                            name='target_embedding',
                            initializer=fluid.initializer.UniformInitializer(
                                low=-self.init_scale, high=self.init_scale)))

                    dec_att_out, new_hidden_array, new_cell_array = decoder_step(
                        dec_step_emb, pre_hidden_array, pre_cell_array)

                    projection = layers.matmul(dec_att_out, softmax_weight)

                    # `logits` holds probabilities first, then log-probs.
                    logits = layers.softmax(projection)
                    logits = layers.log(logits)

                    current_log = layers.elementwise_add(logits, score, axis=0)

                    # Greedy choice: the single best token at this step.
                    topk_score, topk_indices = layers.topk(input=current_log,
                                                           k=1)

                    # Append the chosen id and carry the state to the next
                    # iteration.
                    new_ids = layers.concat([full_ids, topk_indices])
                    layers.assign(new_ids, full_ids)
                    #layers.Print( full_ids, message="ful ids")
                    layers.assign(topk_score, score)
                    layers.assign(topk_indices, pre_ids)
                    layers.assign(dec_att_out, pre_feed)
                    for i in range(self.num_layers):
                        layers.assign(new_hidden_array[i], pre_hidden_array[i])
                        layers.assign(new_cell_array[i], pre_cell_array[i])

                    layers.increment(x=step_idx, value=1.0, in_place=True)

                    # Despite the name, eos_met is true while the chosen id
                    # differs from eos_ids, i.e. while decoding should go on.
                    eos_met = layers.not_equal(topk_indices, eos_ids)
                    length_cond = layers.less_than(x=step_idx, y=max_length)
                    layers.logical_and(x=length_cond, y=eos_met, out=cond)

                return full_ids

            # NOTE(review): unreachable -- the enclosing elif guarantees one
            # of the two branches above runs, and both of them return.
            raise Exception("error")
        else:
            # NOTE(review): unsupported modes are only reported on stdout and
            # the method then returns None implicitly.
            print("mode not supprt", mode)