Ejemplo n.º 1
0
def infer(args):
    """Tag the ``infer`` split with a restored BiGRU-CRF via ``paddle.Model``.

    Every sentence is decoded and rendered as a line of ``(char, tag)`` pairs.
    All lines are written to ``results.txt`` in the current working directory,
    and the first ten are echoed to stdout.

    Args:
        args: parsed command-line namespace; this function reads ``device``,
            ``data_dir``, ``batch_size``, ``emb_dim``, ``hidden_size`` and
            ``init_checkpoint``.
    """
    paddle.set_device(args.device)

    # Batching: pad word ids within a batch and stack the true lengths.
    infer_dataset = LacDataset(args.data_dir, mode='infer')
    collate = Tuple(
        Pad(axis=0, pad_val=0, dtype='int64'),  # word_ids
        Stack(dtype='int64'),  # length
    )
    batchify_fn = lambda samples, fn=collate: fn(samples)

    # Keep dataset order and the trailing partial batch so every sentence
    # receives a prediction.
    infer_loader = paddle.io.DataLoader(
        dataset=infer_dataset,
        batch_sampler=paddle.io.BatchSampler(
            dataset=infer_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            drop_last=False),
        return_list=True,
        collate_fn=batchify_fn)

    network = BiGruCrf(args.emb_dim, args.hidden_size,
                       infer_dataset.vocab_size, infer_dataset.num_labels)
    input_specs = [
        InputSpec(shape=(-1, ), dtype="int64", name='inputs'),
        InputSpec(shape=(-1, ), dtype="int64", name='lengths'),
    ]
    model = paddle.Model(network, inputs=input_specs)
    model.prepare()

    model.load(args.init_checkpoint)
    # predict() yields three per-batch output lists; the emissions are unused.
    _, batch_lengths, batch_decodes = model.predict(
        test_data=infer_loader, batch_size=args.batch_size)

    # Flatten the per-batch outputs into per-sentence arrays.
    flat_lengths = np.array(
        [n for batch in batch_lengths for n in batch]).reshape([-1])
    flat_preds = np.array([p for batch in batch_decodes for p in batch])

    results = parse_lac_result(infer_dataset.word_ids, flat_preds,
                               flat_lengths, infer_dataset.word_vocab,
                               infer_dataset.label_vocab)

    sent_tags = [
        ''.join('(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags))
        for sent, tags in results
    ]

    file_path = "results.txt"
    with open(file_path, "w", encoding="utf8") as fout:
        fout.write("\n".join(sent_tags))

    print(
        "The results have been saved in the file: %s, some examples are shown below: "
        % file_path)
    print("\n".join(sent_tags[:10]))
Ejemplo n.º 2
0
def evaluate(args):
    """Evaluate a BiGRU-CRF checkpoint on ``test.tsv`` and print chunk P/R/F1.

    Args:
        args: parsed command-line namespace; this function reads ``device``,
            ``data_dir``, ``max_seq_len``, ``batch_size``, ``emb_dim``,
            ``hidden_size`` and ``init_checkpoint``.
    """
    paddle.set_device(args.device)

    # create dataset.
    test_ds = load_dataset(datafiles=(os.path.join(args.data_dir, 'test.tsv')))
    word_vocab = load_vocab(os.path.join(args.data_dir, 'word.dic'))
    label_vocab = load_vocab(os.path.join(args.data_dir, 'tag.dic'))
    # q2b.dic is a character-normalization table passed to convert_example
    # (the name suggests full-width -> half-width; confirm against the data).
    normlize_vocab = load_vocab(os.path.join(args.data_dir, 'q2b.dic'))

    trans_func = partial(
        convert_example,
        max_seq_len=args.max_seq_len,
        word_vocab=word_vocab,
        label_vocab=label_vocab,
        normlize_vocab=normlize_vocab)
    test_ds.map(trans_func)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=0, dtype='int64'),  # word_ids
        Stack(dtype='int64'),  # length
        Pad(axis=0, pad_val=0, dtype='int64'),  # label_ids
    ): fn(samples)

    # Keep order and the final partial batch so every example is scored.
    test_sampler = paddle.io.BatchSampler(
        dataset=test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False)
    test_loader = paddle.io.DataLoader(
        dataset=test_ds,
        batch_sampler=test_sampler,
        return_list=True,
        collate_fn=batchify_fn)

    # Define the model network and metric evaluator
    model = BiGruCrf(args.emb_dim, args.hidden_size,
                     len(word_vocab), len(label_vocab))
    chunk_evaluator = ChunkEvaluator(label_list=label_vocab.keys(), suffix=True)

    # Load the model parameters
    model_dict = paddle.load(args.init_checkpoint)
    model.load_dict(model_dict)

    model.eval()
    chunk_evaluator.reset()
    # Inference only: no_grad avoids building an autograd graph per batch.
    with paddle.no_grad():
        for token_ids, length, labels in test_loader:
            preds = model(token_ids, length)
            num_infer_chunks, num_label_chunks, num_correct_chunks = chunk_evaluator.compute(
                length, preds, labels)
            chunk_evaluator.update(num_infer_chunks.numpy(),
                                   num_label_chunks.numpy(),
                                   num_correct_chunks.numpy())
    # update() accumulates counts across batches, so a single accumulate()
    # after the loop gives the final scores. Calling it here also keeps
    # precision/recall/f1 bound when the loader yields no batches.
    precision, recall, f1_score = chunk_evaluator.accumulate()
    print("eval precision: %f, recall: %f, f1: %f" %
          (precision, recall, f1_score))
Ejemplo n.º 3
0
def evaluate(args):
    """Evaluate a BiGRU-CRF checkpoint on the ``test`` split via ``paddle.Model``.

    Metrics come from a suffix-style ``ChunkEvaluator`` and are reported by
    ``Model.evaluate`` (``verbose=2``, every 100 steps).

    Args:
        args: parsed command-line namespace; this function reads ``use_gpu``,
            ``data_dir``, ``batch_size``, ``emb_dim``, ``hidden_size`` and
            ``init_checkpoint``.
    """
    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
    paddle.set_device("gpu" if args.use_gpu else "cpu")

    # create dataset.
    test_dataset = LacDataset(args.data_dir, mode='test')
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=0),  # word_ids
        Stack(),  # length
        Pad(axis=0, pad_val=0),  # label_ids
    ): fn(samples)

    # Keep order and the final partial batch so every test sentence is scored.
    test_sampler = paddle.io.BatchSampler(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          drop_last=False)
    test_loader = paddle.io.DataLoader(dataset=test_dataset,
                                       batch_sampler=test_sampler,
                                       places=place,
                                       return_list=True,
                                       collate_fn=batchify_fn)

    # Define the model network and metric evaluator
    network = BiGruCrf(args.emb_dim, args.hidden_size, test_dataset.vocab_size,
                       test_dataset.num_labels)
    # Token ids and lengths are integer indices. Declare them int64: an
    # embedding lookup only accepts int32/int64 ids, and the inference path
    # builds the same network with int64 specs.
    inputs = InputSpec(shape=(-1, ), dtype="int64", name='inputs')
    lengths = InputSpec(shape=(-1, ), dtype="int64", name='lengths')
    model = paddle.Model(network, inputs=[inputs, lengths])
    chunk_evaluator = ChunkEvaluator(
        label_list=test_dataset.label_vocab.keys(), suffix=True)
    model.prepare(None, None, chunk_evaluator)

    # Load the model and start evaluating
    model.load(args.init_checkpoint)
    model.evaluate(
        eval_data=test_loader,
        batch_size=args.batch_size,
        log_freq=100,
        verbose=2,
    )
Ejemplo n.º 4
0
def evaluate(args):
    """Evaluate a BiGRU-CRF checkpoint on the ``test`` split (IOB chunk metric).

    Args:
        args: parsed command-line namespace; this function reads ``use_gpu``,
            ``data_dir``, ``batch_size``, ``emb_dim``, ``hidden_size`` and
            ``init_checkpoint``.
    """
    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
    paddle.set_device("gpu" if args.use_gpu else "cpu")

    # create dataset.
    test_dataset = LacDataset(args.data_dir, mode='test')
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=0),  # word_ids
        Stack(),  # length
        Pad(axis=0, pad_val=0),  # label_ids
    ): fn(samples)

    # drop_last=False: evaluation must include the trailing partial batch.
    # Otherwise up to batch_size - 1 test sentences are silently excluded
    # from the reported metrics.
    test_sampler = paddle.io.BatchSampler(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          drop_last=False)
    test_loader = paddle.io.DataLoader(dataset=test_dataset,
                                       batch_sampler=test_sampler,
                                       places=place,
                                       return_list=True,
                                       collate_fn=batchify_fn)

    # Define the model network and metric evaluator
    network = BiGruCrf(args.emb_dim, args.hidden_size, test_dataset.vocab_size,
                       test_dataset.num_labels)
    model = paddle.Model(network)
    # Chunk-type count derived from the label count. This presumably assumes
    # B-/I- label pairs plus one extra tag; confirm against the tag set.
    chunk_evaluator = ChunkEvaluator(
        int(math.ceil((test_dataset.num_labels + 1) / 2.0)),
        "IOB")  # + 1 for SOS and EOS
    model.prepare(None, None, chunk_evaluator)

    # Load the model and start evaluating
    model.load(args.init_checkpoint)
    model.evaluate(
        eval_data=test_loader,
        batch_size=args.batch_size,
        log_freq=100,
        verbose=2,
    )
Ejemplo n.º 5
0
def main():
    """Convert a trained BiGRU-CRF dygraph checkpoint into a static-graph model.

    Reads the module-level ``args`` (``data_dir``, ``emb_dim``, ``hidden_size``,
    ``params_path``, ``output_path``). Vocabulary sizes come from ``word.dic``
    and ``tag.dic``.
    """
    word_vocab = load_vocab(os.path.join(args.data_dir, 'word.dic'))
    label_vocab = load_vocab(os.path.join(args.data_dir, 'tag.dic'))

    model = BiGruCrf(args.emb_dim, args.hidden_size, len(word_vocab),
                     len(label_vocab))
    model.set_dict(paddle.load(args.params_path))
    model.eval()

    # Leave batch size and sequence length dynamic in the exported program.
    export_spec = [
        InputSpec(shape=[None, None], dtype="int64", name='token_ids'),
        InputSpec(shape=[None], dtype="int64", name='length'),
    ]
    static_model = paddle.jit.to_static(model, input_spec=export_spec)
    paddle.jit.save(static_model, args.output_path)
Ejemplo n.º 6
0
def train(args):
    """Fit a BiGRU-CRF tagger with the high-level ``paddle.Model.fit`` API.

    The model is evaluated on the ``test`` split after every epoch and a
    checkpoint is written to ``args.model_save_dir`` each epoch.

    Args:
        args: parsed command-line namespace; this function reads ``n_gpu``,
            ``data_dir``, ``batch_size``, ``emb_dim``, ``hidden_size``,
            ``base_lr``, ``init_checkpoint``, ``verbose``, ``epochs`` and
            ``model_save_dir``.
    """
    paddle.set_device("gpu" if args.n_gpu else "cpu")

    train_dataset = LacDataset(args.data_dir, mode='train')
    test_dataset = LacDataset(args.data_dir, mode='test')

    collate = Tuple(
        Pad(axis=0, pad_val=0),  # word_ids
        Stack(),  # length
        Pad(axis=0, pad_val=0),  # label_ids
    )
    batchify_fn = lambda samples, fn=collate: fn(samples)

    def _make_loader(dataset, sampler):
        # Shared DataLoader construction for the train and test splits.
        return paddle.io.DataLoader(dataset=dataset,
                                    batch_sampler=sampler,
                                    return_list=True,
                                    collate_fn=batchify_fn)

    # Training batches: sharded per rank, shuffled, full-size only.
    train_loader = _make_loader(
        train_dataset,
        paddle.io.DistributedBatchSampler(dataset=train_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          drop_last=True))
    # Evaluation batches: in order, last partial batch kept.
    test_loader = _make_loader(
        test_dataset,
        paddle.io.BatchSampler(dataset=test_dataset,
                               batch_size=args.batch_size,
                               shuffle=False,
                               drop_last=False))

    network = BiGruCrf(args.emb_dim, args.hidden_size,
                       train_dataset.vocab_size, train_dataset.num_labels)
    model = paddle.Model(network)

    optimizer = paddle.optimizer.Adam(learning_rate=args.base_lr,
                                      parameters=model.parameters())
    model.prepare(
        optimizer,
        LinearChainCrfLoss(network.crf),
        ChunkEvaluator(label_list=train_dataset.label_vocab.keys(),
                       suffix=True))
    if args.init_checkpoint:
        model.load(args.init_checkpoint)

    if args.verbose:
        callbacks = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3)
    else:
        callbacks = None

    model.fit(train_data=train_loader,
              eval_data=test_loader,
              batch_size=args.batch_size,
              epochs=args.epochs,
              eval_freq=1,
              log_freq=10,
              save_dir=args.model_save_dir,
              save_freq=1,
              shuffle=True,
              callbacks=callbacks)
Ejemplo n.º 7
0
def train(args):
    """Train a BiGRU-CRF tagger with a hand-written dygraph loop.

    Every ``save_steps`` steps, and on the final step, rank 0 calls the
    module-level ``evaluate(model, chunk_evaluator, test_loader)`` helper
    (not shown here). It then writes ``model_<step>.pdparams`` to
    ``args.model_save_dir``.

    Args:
        args: parsed command-line namespace; this function reads ``device``,
            ``data_dir``, ``max_seq_len``, ``batch_size``, ``emb_dim``,
            ``hidden_size``, ``base_lr``, ``init_checkpoint``, ``epochs``,
            ``logging_steps``, ``save_steps`` and ``model_save_dir``.
    """
    paddle.set_device(args.device)

    train_ds, test_ds = load_dataset(
        datafiles=(os.path.join(args.data_dir, 'train.tsv'),
                   os.path.join(args.data_dir, 'test.tsv')))

    # word.dic -> token ids, tag.dic -> label ids, q2b.dic -> character
    # normalization table consumed by convert_example.
    word_vocab, label_vocab, normlize_vocab = (
        load_vocab(os.path.join(args.data_dir, name))
        for name in ('word.dic', 'tag.dic', 'q2b.dic'))

    trans_func = partial(convert_example,
                         max_seq_len=args.max_seq_len,
                         word_vocab=word_vocab,
                         label_vocab=label_vocab,
                         normlize_vocab=normlize_vocab)
    for ds in (train_ds, test_ds):
        ds.map(trans_func)

    collate = Tuple(
        Pad(axis=0, pad_val=0, dtype='int64'),  # word_ids
        Stack(dtype='int64'),  # length
        Pad(axis=0, pad_val=0, dtype='int64'),  # label_ids
    )
    batchify_fn = lambda samples, fn=collate: fn(samples)

    train_loader = paddle.io.DataLoader(
        dataset=train_ds,
        batch_sampler=paddle.io.DistributedBatchSampler(
            dataset=train_ds,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True),
        return_list=True,
        collate_fn=batchify_fn)
    test_loader = paddle.io.DataLoader(
        dataset=test_ds,
        batch_sampler=paddle.io.BatchSampler(dataset=test_ds,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             drop_last=False),
        return_list=True,
        collate_fn=batchify_fn)

    model = BiGruCrf(args.emb_dim, args.hidden_size, len(word_vocab),
                     len(label_vocab))
    optimizer = paddle.optimizer.Adam(learning_rate=args.base_lr,
                                      parameters=model.parameters())
    chunk_evaluator = ChunkEvaluator(label_list=label_vocab.keys(),
                                     suffix=True)

    if args.init_checkpoint:
        model.load_dict(paddle.load(args.init_checkpoint))

    last_step = args.epochs * len(train_loader)
    global_step = 0
    tic_train = time.time()
    for _epoch in range(args.epochs):
        for token_ids, length, label_ids in train_loader:
            global_step += 1
            # With labels supplied, the model returns the CRF loss.
            avg_loss = paddle.mean(model(token_ids, length, label_ids))
            if global_step % args.logging_steps == 0:
                elapsed = time.time() - tic_train
                print("global step %d / %d, loss: %f, speed: %.2f step/s" %
                      (global_step, last_step, avg_loss,
                       args.logging_steps / elapsed))
                tic_train = time.time()
            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            is_save_step = (global_step % args.save_steps == 0
                            or global_step == last_step)
            if is_save_step and paddle.distributed.get_rank() == 0:
                evaluate(model, chunk_evaluator, test_loader)
                paddle.save(
                    model.state_dict(),
                    os.path.join(args.model_save_dir,
                                 "model_%d.pdparams" % global_step))
Ejemplo n.º 8
0
def infer(args):
    """Tag ``infer.tsv`` with a restored BiGRU-CRF using a manual dygraph loop.

    Every sentence is rendered as a line of ``(char, tag)`` pairs. All lines
    are written to ``results.txt`` in the current working directory, and the
    first ten are printed.

    Args:
        args: parsed command-line namespace; this function reads ``device``,
            ``data_dir``, ``max_seq_len``, ``batch_size``, ``emb_dim``,
            ``hidden_size`` and ``init_checkpoint``.
    """
    paddle.set_device(args.device)

    # create dataset.
    infer_ds = load_dataset(datafiles=(os.path.join(args.data_dir,
                                                    'infer.tsv')))
    word_vocab = load_vocab(os.path.join(args.data_dir, 'word.dic'))
    label_vocab = load_vocab(os.path.join(args.data_dir, 'tag.dic'))
    # q2b.dic is a character-normalization table passed to convert_example
    # (the name suggests full-width -> half-width; confirm against the data).
    normlize_vocab = load_vocab(os.path.join(args.data_dir, 'q2b.dic'))

    trans_func = partial(
        convert_example,
        max_seq_len=args.max_seq_len,
        word_vocab=word_vocab,
        label_vocab=label_vocab,
        normlize_vocab=normlize_vocab)
    infer_ds.map(trans_func)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=0, dtype='int64'),  # word_ids
        Stack(dtype='int64'),  # length
    ): fn(samples)

    # Keep dataset order and the last partial batch so every sentence is tagged.
    infer_sampler = paddle.io.BatchSampler(
        dataset=infer_ds,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False)
    infer_loader = paddle.io.DataLoader(
        dataset=infer_ds,
        batch_sampler=infer_sampler,
        return_list=True,
        collate_fn=batchify_fn)

    # Define the model network
    model = BiGruCrf(args.emb_dim, args.hidden_size,
                     len(word_vocab), len(label_vocab))

    # Load the model parameters
    model_dict = paddle.load(args.init_checkpoint)
    model.load_dict(model_dict)

    model.eval()
    results = []
    # Pure inference: disable autograd bookkeeping for the decoding loop.
    with paddle.no_grad():
        for token_ids, length in infer_loader:
            preds = model(token_ids, length)
            results += parse_result(token_ids.numpy(),
                                    preds.numpy(),
                                    length.numpy(), word_vocab, label_vocab)

    sent_tags = []
    for sent, tags in results:
        sent_tag = ['(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags)]
        sent_tags.append(''.join(sent_tag))

    file_path = "results.txt"
    with open(file_path, "w", encoding="utf8") as fout:
        fout.write("\n".join(sent_tags))

    # Print some examples
    print(
        "The results have been saved in the file: %s, some examples are shown below: "
        % file_path)
    print("\n".join(sent_tags[:10]))
Ejemplo n.º 9
0
def train(args):
    """Train a BiGRU-CRF tagger with a manual loop and throughput logging.

    Every ``logging_steps`` steps the loop logs loss, average reader cost,
    average batch cost (reader + forward + backward + optimizer update),
    average samples and sequences/sec. Every ``save_steps`` steps, and on the
    final step, rank 0 writes ``model_<step>.pdparams``. When ``do_eval`` is
    set, it also calls the module-level ``evaluate`` helper (not shown; it is
    expected to return ``(precision, recall, f1)``) and keeps the best-F1
    weights in ``best_model.pdparams``.

    Args:
        args: parsed command-line namespace; this function reads ``device``,
            ``crf_lr``, ``emb_dim``, ``hidden_size``, ``base_lr``,
            ``init_checkpoint``, ``epochs``, ``batch_size``,
            ``logging_steps``, ``save_steps``, ``model_save_dir`` and
            ``do_eval``, plus whatever ``create_data_loader`` reads.
    """
    paddle.set_device(args.device)
    set_seed(102)
    trainer_num = paddle.distributed.get_world_size()
    if trainer_num > 1:
        paddle.distributed.init_parallel_env()
    rank = paddle.distributed.get_rank()

    word_vocab, label_vocab, train_loader, test_loader = create_data_loader(
        args)

    # Define the model network and its loss
    model = BiGruCrf(args.emb_dim,
                     args.hidden_size,
                     len(word_vocab),
                     len(label_vocab),
                     crf_lr=args.crf_lr)
    # Prepare optimizer, loss and metric evaluator
    optimizer = paddle.optimizer.Adam(learning_rate=args.base_lr,
                                      parameters=model.parameters())
    chunk_evaluator = ChunkEvaluator(label_list=label_vocab.keys(),
                                     suffix=True)

    if args.init_checkpoint:
        if os.path.exists(args.init_checkpoint):
            logger.info("Init checkpoint from %s" % args.init_checkpoint)
            model_dict = paddle.load(args.init_checkpoint)
            model.load_dict(model_dict)
        else:
            logger.info("Cannot init checkpoint from %s which doesn't exist" %
                        args.init_checkpoint)
    logger.info("Start training")
    global_step = 0
    last_step = args.epochs * len(train_loader)
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    reader_start = time.time()
    max_f1_score = -1
    for _epoch in range(args.epochs):
        for batch in train_loader:
            train_reader_cost += time.time() - reader_start
            global_step += 1
            token_ids, length, label_ids = batch
            train_start = time.time()
            loss = model(token_ids, length, label_ids)
            avg_loss = paddle.mean(loss)
            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            # Stop the clock only after the full step (forward + backward +
            # update). Previously only the forward pass was timed, so
            # backward/optimizer time was counted nowhere and ips was
            # overstated. Host-side timing; GPU work may still be in flight.
            train_run_cost += time.time() - train_start
            total_samples += args.batch_size
            if global_step % args.logging_steps == 0:
                logger.info(
                    "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
                    %
                    (global_step, last_step, avg_loss,
                     train_reader_cost / args.logging_steps,
                     (train_reader_cost + train_run_cost) / args.logging_steps,
                     total_samples / args.logging_steps, total_samples /
                     (train_reader_cost + train_run_cost)))
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0
            if global_step % args.save_steps == 0 or global_step == last_step:
                if rank == 0:
                    paddle.save(
                        model.state_dict(),
                        os.path.join(args.model_save_dir,
                                     "model_%d.pdparams" % global_step))
                    logger.info("Save %d steps model." % (global_step))
                    if args.do_eval:
                        precision, recall, f1_score = evaluate(
                            model, chunk_evaluator, test_loader)
                        if f1_score > max_f1_score:
                            max_f1_score = f1_score
                            paddle.save(
                                model.state_dict(),
                                os.path.join(args.model_save_dir,
                                             "best_model.pdparams"))
                            logger.info("Save best model.")

            # Restart the reader clock here so save/eval time is excluded.
            reader_start = time.time()
Ejemplo n.º 10
0
def train(args):
    """Fit a BiGRU-CRF tagger with the ``paddle.Model`` API (IOB chunk metric).

    The model is evaluated on the ``test`` split every epoch and a checkpoint
    is written to ``args.model_save_dir`` each epoch.

    Args:
        args: parsed command-line namespace; this function reads ``use_gpu``,
            ``data_dir``, ``batch_size``, ``emb_dim``, ``hidden_size``,
            ``base_lr``, ``init_checkpoint``, ``epochs`` and
            ``model_save_dir``.
    """
    if args.use_gpu:
        # One device per distributed worker.
        place = paddle.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
        paddle.set_device("gpu")
    else:
        place = paddle.CPUPlace()
        paddle.set_device("cpu")

    # create dataset.
    train_dataset = LacDataset(args.data_dir, mode='train')
    test_dataset = LacDataset(args.data_dir, mode='test')

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=0),  # word_ids
        Stack(),  # length
        Pad(axis=0, pad_val=0),  # label_ids
    ): fn(samples)

    # Create sampler for dataloader
    train_sampler = paddle.io.DistributedBatchSampler(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True)
    train_loader = paddle.io.DataLoader(
        dataset=train_dataset,
        batch_sampler=train_sampler,
        places=place,
        return_list=True,
        collate_fn=batchify_fn)

    # drop_last=False: per-epoch evaluation must cover the whole test set.
    # Dropping the trailing partial batch silently excludes up to
    # batch_size - 1 sentences from the reported metrics.
    test_sampler = paddle.io.BatchSampler(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False)
    test_loader = paddle.io.DataLoader(
        dataset=test_dataset,
        batch_sampler=test_sampler,
        places=place,
        return_list=True,
        collate_fn=batchify_fn)

    # Define the model network and its loss
    network = BiGruCrf(args.emb_dim, args.hidden_size, train_dataset.vocab_size,
                       train_dataset.num_labels)
    model = paddle.Model(network)

    # Prepare optimizer, loss and metric evaluator
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.base_lr, parameters=model.parameters())
    crf_loss = LinearChainCrfLoss(network.crf.transitions)
    # Chunk-type count derived from the label count. This presumably assumes
    # B-/I- label pairs plus one extra tag; confirm against the tag set.
    chunk_evaluator = ChunkEvaluator(
        int(math.ceil((train_dataset.num_labels + 1) / 2.0)),
        "IOB")  # + 1 for START and STOP
    model.prepare(optimizer, crf_loss, chunk_evaluator)
    if args.init_checkpoint:
        model.load(args.init_checkpoint)

    # Start training
    callback = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3)
    model.fit(train_data=train_loader,
              eval_data=test_loader,
              batch_size=args.batch_size,
              epochs=args.epochs,
              eval_freq=1,
              log_freq=10,
              save_dir=args.model_save_dir,
              save_freq=1,
              drop_last=True,
              shuffle=True,
              callbacks=callback)
Ejemplo n.º 11
0
def train(args):
    """Train a BiGRU-CRF tagger from ``train.tsv`` with throughput logging.

    Word ids are padded with ``word_vocab["[PAD]"]`` and label ids with
    ``label_vocab["O"]`` (falling back to 0 if either key is absent).

    Every ``logging_steps`` steps the loop logs loss, average reader cost,
    average batch cost (reader + forward + backward + optimizer update),
    average samples and sequences/sec. Every ``save_steps`` steps, and on the
    final step, rank 0 writes ``model_<step>.pdparams``. When ``do_eval`` is
    set, it also calls the module-level ``evaluate`` helper (not shown; it is
    expected to return ``(precision, recall, f1)``) and keeps the best-F1
    weights in ``best_model.pdparams``.

    Args:
        args: parsed command-line namespace; this function reads ``device``,
            ``data_dir``, ``max_seq_len``, ``batch_size``, ``emb_dim``,
            ``hidden_size``, ``crf_lr``, ``base_lr``, ``init_checkpoint``,
            ``epochs``, ``logging_steps``, ``save_steps``,
            ``model_save_dir`` and ``do_eval``.
    """
    paddle.set_device(args.device)

    trainer_num = paddle.distributed.get_world_size()
    if trainer_num > 1:
        paddle.distributed.init_parallel_env()
    rank = paddle.distributed.get_rank()
    # Create dataset.
    train_ds, test_ds = load_dataset(
        datafiles=(os.path.join(args.data_dir, 'train.tsv'),
                   os.path.join(args.data_dir, 'test.tsv')))

    word_vocab = load_vocab(os.path.join(args.data_dir, 'word.dic'))
    label_vocab = load_vocab(os.path.join(args.data_dir, 'tag.dic'))
    # q2b.dic is a character-normalization table passed to convert_example
    # (the name suggests full-width -> half-width; confirm against the data).
    normlize_vocab = load_vocab(os.path.join(args.data_dir, 'q2b.dic'))

    trans_func = partial(convert_example,
                         max_seq_len=args.max_seq_len,
                         word_vocab=word_vocab,
                         label_vocab=label_vocab,
                         normlize_vocab=normlize_vocab)
    train_ds.map(trans_func)
    test_ds.map(trans_func)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=word_vocab.get("[PAD]", 0), dtype='int64'
            ),  # word_ids
        Stack(dtype='int64'),  # length
        Pad(axis=0, pad_val=label_vocab.get("O", 0), dtype='int64'
            ),  # label_ids
    ): fn(samples)

    # Create sampler for dataloader
    train_sampler = paddle.io.DistributedBatchSampler(
        dataset=train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True)
    train_loader = paddle.io.DataLoader(dataset=train_ds,
                                        batch_sampler=train_sampler,
                                        return_list=True,
                                        collate_fn=batchify_fn)

    test_sampler = paddle.io.BatchSampler(dataset=test_ds,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          drop_last=False)
    test_loader = paddle.io.DataLoader(dataset=test_ds,
                                       batch_sampler=test_sampler,
                                       return_list=True,
                                       collate_fn=batchify_fn)

    # Define the model network and its loss
    model = BiGruCrf(args.emb_dim,
                     args.hidden_size,
                     len(word_vocab),
                     len(label_vocab),
                     crf_lr=args.crf_lr)
    # Prepare optimizer, loss and metric evaluator
    optimizer = paddle.optimizer.Adam(learning_rate=args.base_lr,
                                      parameters=model.parameters())
    chunk_evaluator = ChunkEvaluator(label_list=label_vocab.keys(),
                                     suffix=True)

    if args.init_checkpoint:
        if os.path.exists(args.init_checkpoint):
            logger.info("Init checkpoint from %s" % args.init_checkpoint)
            model_dict = paddle.load(args.init_checkpoint)
            model.load_dict(model_dict)
        else:
            logger.info("Cannot init checkpoint from %s which doesn't exist" %
                        args.init_checkpoint)
    logger.info("Start training")
    global_step = 0
    last_step = args.epochs * len(train_loader)
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    reader_start = time.time()
    max_f1_score = -1
    for _epoch in range(args.epochs):
        for batch in train_loader:
            train_reader_cost += time.time() - reader_start
            global_step += 1
            token_ids, length, label_ids = batch
            train_start = time.time()
            loss = model(token_ids, length, label_ids)
            avg_loss = paddle.mean(loss)
            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            # Stop the clock only after the full step (forward + backward +
            # update). Previously only the forward pass was timed, so
            # backward/optimizer time was counted nowhere and ips was
            # overstated. Host-side timing; GPU work may still be in flight.
            train_run_cost += time.time() - train_start
            total_samples += args.batch_size
            if global_step % args.logging_steps == 0:
                logger.info(
                    "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
                    %
                    (global_step, last_step, avg_loss,
                     train_reader_cost / args.logging_steps,
                     (train_reader_cost + train_run_cost) / args.logging_steps,
                     total_samples / args.logging_steps, total_samples /
                     (train_reader_cost + train_run_cost)))
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0
            if global_step % args.save_steps == 0 or global_step == last_step:
                if rank == 0:
                    paddle.save(
                        model.state_dict(),
                        os.path.join(args.model_save_dir,
                                     "model_%d.pdparams" % global_step))
                    logger.info("Save %d steps model." % (global_step))
                    if args.do_eval:
                        precision, recall, f1_score = evaluate(
                            model, chunk_evaluator, test_loader)
                        if f1_score > max_f1_score:
                            max_f1_score = f1_score
                            paddle.save(
                                model.state_dict(),
                                os.path.join(args.model_save_dir,
                                             "best_model.pdparams"))
                            logger.info("Save best model.")

            # Restart the reader clock here so save/eval time is excluded.
            reader_start = time.time()