Ejemplo n.º 1
0
def train(bert_config, data_reader):
    """Train a ``PretrainModelLayer`` in dygraph mode and return the last metrics.

    Seeds the default main and startup programs with ``SEED``, feeds the
    batches produced by ``data_reader.data_generator()`` through a
    ``DataLoader``, and applies one Adam update on ``total_loss`` per batch.
    Training stops after ``STEP_NUM`` steps or when the loader is exhausted,
    whichever happens first. Progress is printed every ``PRINT_STEP`` steps.

    Args:
        bert_config: Configuration forwarded to ``PretrainModelLayer`` as
            its ``config`` argument.
        data_reader: Object whose ``data_generator()`` result is registered
            as the loader's batch generator. Every batch must unpack into
            seven items: ``src_ids, pos_ids, sent_ids, input_mask,
            mask_label, mask_pos, labels``.

    Returns:
        A ``(loss, ppl)`` tuple taken from the last executed step, where
        ``loss`` is the mean of ``total_loss`` and ``ppl`` is the mean of
        ``exp(mask_lm_loss)``.

    Raises:
        ValueError: If the data loader yields no batch at all, so there is
            no metric to return.
    """
    with fluid.dygraph.guard(place):
        fluid.default_main_program().random_seed = SEED
        fluid.default_startup_program().random_seed = SEED

        data_loader = fluid.io.DataLoader.from_generator(capacity=50,
                                                         iterable=True)
        data_loader.set_batch_generator(data_reader.data_generator(),
                                        places=place)

        bert = PretrainModelLayer(config=bert_config,
                                  weight_sharing=False,
                                  use_fp16=False)

        optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())

        # Sentinels: stay None only when the loader produces no batch.
        loss = ppl = None
        # Wall-clock time of the previous progress print; set on step 0.
        last_print_time = None
        for step_idx, input_data in enumerate(data_loader()):
            src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = input_data
            next_sent_acc, mask_lm_loss, total_loss = bert(
                src_ids=src_ids,
                position_ids=pos_ids,
                sentence_ids=sent_ids,
                input_mask=input_mask,
                mask_label=mask_label,
                mask_pos=mask_pos,
                labels=labels)
            total_loss.backward()
            optimizer.minimize(total_loss)
            bert.clear_gradients()

            # Reduce each model output to a scalar for reporting/returning.
            acc = np.mean(next_sent_acc.numpy())
            loss = np.mean(total_loss.numpy())
            ppl = np.mean(np.exp(mask_lm_loss.numpy()))

            if step_idx % PRINT_STEP == 0:
                if step_idx == 0:
                    # No previous timestamp yet, so no speed can be reported.
                    print("Step: %d, loss: %f, ppl: %f, next_sent_acc: %f" %
                          (step_idx, loss, ppl, acc))
                else:
                    # PRINT_STEP steps have elapsed since the last print.
                    speed = PRINT_STEP / (time.time() - last_print_time)
                    print(
                        "Step: %d, loss: %f, ppl: %f, next_sent_acc: %f, speed: %.3f steps/s"
                        % (step_idx, loss, ppl, acc, speed))
                last_print_time = time.time()

            # step_idx is zero-based; stop once STEP_NUM steps have run.
            if step_idx + 1 == STEP_NUM:
                break

        if loss is None:
            raise ValueError(
                "data_reader.data_generator() yielded no batches; "
                "nothing was trained")
        return loss, ppl
Ejemplo n.º 2
0
def predict_dygraph(bert_config, data):
    """Run a saved ``PretrainModelLayer`` on ``data`` in dygraph mode.

    Calls ``program_translator.enable(False)`` first (presumably to turn
    off dygraph-to-static translation -- confirm against the translator's
    documentation), then rebuilds the model, restores the parameters stored
    at ``DY_STATE_DICT_SAVE_PATH``, switches the model to eval mode and
    performs a single forward pass.

    Args:
        bert_config: Configuration forwarded to ``PretrainModelLayer`` as
            its ``config`` argument.
        data: Iterable of exactly seven inputs, in the order ``src_ids,
            pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels``.
            Each one is converted with ``fluid.dygraph.to_variable``.

    Returns:
        A list holding ``.numpy()`` of every element the model returned.
    """
    program_translator.enable(False)
    with fluid.dygraph.guard(place):
        model = PretrainModelLayer(
            config=bert_config, weight_sharing=False, use_fp16=False)

        # load_dygraph returns a pair; only the first element is used here.
        state_dict, _ = fluid.dygraph.load_dygraph(DY_STATE_DICT_SAVE_PATH)
        model.set_dict(state_dict)
        model.eval()

        tensors = [fluid.dygraph.to_variable(item) for item in data]
        (src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos,
         labels) = tensors

        outputs = model(
            src_ids=src_ids,
            position_ids=pos_ids,
            sentence_ids=sent_ids,
            input_mask=input_mask,
            mask_label=mask_label,
            mask_pos=mask_pos,
            labels=labels)

        return [out.numpy() for out in outputs]