コード例 #1
0
def _build_arg_parser():
    """Create the command-line parser for the LUKE tagger training script.

    Returns:
        argparse.ArgumentParser: parser covering path, model, subword,
        training, optimizer and knowledge-graph (KG) options.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/tagger_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--output_encoder",
                        default="./luke-models/",
                        type=str,
                        help="Path of the output luke model.")
    parser.add_argument("--suffix_file_encoder",
                        default="encoder",
                        type=str,
                        help="output file suffix luke model.")
    parser.add_argument("--vocab_path",
                        default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path",
                        type=str,
                        required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/google_config.json",
                        type=str,
                        help="Path of the config file.")
    parser.add_argument("--output_file_prefix",
                        type=str,
                        required=True,
                        help="Prefix for file output.")
    parser.add_argument("--log_file",
                        default='app.log',
                        help="Path of the log file (overwritten on each run).")

    # Model options.
    parser.add_argument("--seq_length",
                        default=256,
                        type=int,
                        help="Sequence length.")
    parser.add_argument("--classifier",
                        choices=["mlp", "lstm", "lstm_crf", "lstm_ncrf"],
                        default="mlp",
                        help="Classifier type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument('--freeze_encoder_weights',
                        action='store_true',
                        help="Enable to freeze the encoder weights.")

    # Subword options.
    parser.add_argument("--subword_type",
                        choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path",
                        type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder",
                        choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num",
                        type=int,
                        default=2,
                        help="The number of subencoder layers.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=0,
                        help="Number of epochs.")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=2,
                        help="Number of steps to accumulate the gradient.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=200,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=35, help="Random seed.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch_size.")
    parser.add_argument("--num_train_steps",
                        type=int,
                        default=20000,
                        help="Max steps to be trained.")
    parser.add_argument("--patience",
                        type=int,
                        default=8000,
                        help="Specific steps to wait until stops training.")

    # Optimizer options.
    parser.add_argument("--learning_rate", default=1e-5, type=float)
    parser.add_argument("--lr_schedule",
                        default="warmup_linear",
                        type=str,
                        choices=["warmup_linear", "warmup_constant"])
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument("--max_grad_norm", default=0.0, type=float)
    parser.add_argument("--adam_b1", default=0.9, type=float)
    parser.add_argument("--adam_b2", default=0.999, type=float)
    parser.add_argument("--adam_eps", default=1e-8, type=float)
    parser.add_argument("--adam_correct_bias", action='store_true')
    parser.add_argument("--warmup_proportion", default=0.006, type=float)
    parser.add_argument("--freeze_proportions", default=0.0, type=float)
    parser.add_argument("--wandb",
                        action='store_true',
                        help="Enable wandb logging")

    # Knowledge-graph options.
    parser.add_argument("--kg_name", type=str, help="KG name or path")
    parser.add_argument("--use_kg",
                        action='store_true',
                        help="Enable the use of KG.")
    parser.add_argument("--padding",
                        action='store_true',
                        help="Enable padding.")
    parser.add_argument(
        "--truncate",
        action='store_true',
        help="Enable truncation if length is more than seq length.")
    parser.add_argument("--shuffle",
                        action='store_true',
                        help="Enable shuffling during training.")
    parser.add_argument("--dry_run",
                        action='store_true',
                        help="Dry run to test the implementation.")
    parser.add_argument(
        "--voting_choicer",
        action='store_true',
        help="Enable the Voting choicer to select the entity type.")
    parser.add_argument("--eval_kg_tag",
                        action='store_true',
                        help="Enable to include [ENT] tag in evaluation.")
    parser.add_argument("--use_subword_tag",
                        action='store_true',
                        help="Enable to use separate tag for subword splits.")
    parser.add_argument("--debug", action='store_true', help="Enable debug.")
    parser.add_argument("--reverse_order",
                        action='store_true',
                        help="Reverse the feature selection order.")
    parser.add_argument("--max_entities",
                        default=2,
                        type=int,
                        help="Number of KG features.")
    parser.add_argument("--eval_range_with_types",
                        action='store_true',
                        help="Enable to eval range with types.")
    return parser


def _build_labels_map(paths):
    """Collect the tagging label vocabulary from tab-separated data files.

    The first line of every file is skipped (header). For every other line
    the first tab-separated field is read as a space-separated label
    sequence. Special labels keep the fixed ids 0-4. When a label starting
    with ``B`` or ``S`` is first seen, its inside counterpart
    ``I<infix><tag>`` (e.g. ``B-PER`` -> ``I-PER``) is registered right
    before it if missing, so inside tags always have an id.

    Args:
        paths: iterable of data-file paths, read in order.

    Returns:
        tuple: ``(labels_map, begin_ids)`` -- the label -> id mapping and
        the ids assigned to ``B``/``S``-prefixed labels, in order of first
        appearance.
    """
    labels_map = {"[PAD]": 0, "[ENT]": 1, "[X]": 2, "[CLS]": 3, "[SEP]": 4}
    begin_ids = []
    for path in paths:
        with open(path, mode="r", encoding="utf-8") as f:
            next(f, None)  # Skip the header line.
            for line in f:
                for label in line.strip().split("\t")[0].split():
                    if label in labels_map:
                        continue
                    if label.startswith(("B", "S")):
                        # A bare "B"/"S" has no infix/type to derive an
                        # inside tag from.
                        if len(label) > 1:
                            inner_tag = f'I{label[1]}{label[2:]}'
                            if inner_tag not in labels_map:
                                labels_map[inner_tag] = len(labels_map)
                        # Record the id this label is about to receive
                        # (after any inner tag was inserted above).
                        begin_ids.append(len(labels_map))
                    labels_map[label] = len(labels_map)
    return labels_map, begin_ids


def main():
    """Train a LUKE-based sequence tagger and evaluate it on dev/test data.

    Steps:
      1. Parse command-line options (see ``_build_arg_parser``) and merge
         the config-file hyperparameters via ``load_hyperparam``.
      2. Build the label vocabulary from the train, dev and test files.
      3. Load the pretrained LUKE archive, the knowledge graph and the
         tagger head selected with ``--classifier``.
      4. Train with gradient accumulation. Every ``--report_steps``
         optimizer steps, evaluate on dev and test and keep the checkpoint
         with the best dev F1. Stop after ``--num_train_steps`` steps, or
         once ``--patience`` steps pass without a dev improvement.
      5. Reload the best checkpoint saved in this run and run a final test
         evaluation. It appends predictions and gold labels to
         ``<output_file_prefix>_predictions.txt`` / ``_gold.txt``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Load the hyperparameters of the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    logging.basicConfig(filename=args.log_file, filemode='w', format=fmt)

    # Find tagging labels.
    labels_map, begin_ids = _build_labels_map(
        (args.train_path, args.dev_path, args.test_path))
    idx_to_label = {idx: label for label, idx in labels_map.items()}

    print(begin_ids)
    print("Labels: ", labels_map)
    args.labels_num = len(labels_map)

    # Build knowledge graph ('none' disables it).
    kg_file = [] if args.kg_name == 'none' else args.kg_name

    # Load Luke model.
    model_archive = ModelArchive.load(args.pretrained_model_path)
    tokenizer = model_archive.tokenizer

    # Load the pretrained model
    encoder = LukeModel(model_archive.config)
    encoder.load_state_dict(model_archive.state_dict, strict=False)

    kg = KnowledgeGraph(kg_file=kg_file, tokenizer=tokenizer)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    # Build sequence labeling model.
    classifiers = {
        "mlp": LukeTaggerMLP,
        "lstm": LukeTaggerLSTM,
        "lstm_crf": LukeTaggerLSTMCRF,
        "lstm_ncrf": LukeTaggerLSTMNCRF
    }
    logger.info(f'The selected classifier is:{classifiers[args.classifier]}')
    model = classifiers[args.classifier](args, encoder)
    # nn.DataParallel only forwards forward(); custom methods such as
    # score/freeze/unfreeze and load_state_dict of the tagger must be
    # called on the unwrapped module.
    core_model = model
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)
    # Module.to() moves parameters in place, so core_model follows too.
    model = model.to(device)

    def read_dataset(path):
        """Read a tagging file and convert each sentence into model inputs.

        The first line of ``path`` is a header and is skipped. Every other
        line must hold ``labels<TAB>tokens``; an optional third column is
        ignored, and lines with any other number of fields are reported and
        skipped. Each sentence is expanded with KG knowledge through
        ``kg.add_knowledge_with_vm``, and its labels are re-aligned to the
        expanded token sequence using the per-token ``tag`` codes:
        0 = original word, 1 = injected KG token (label ``[ENT]``),
        2 = subword continuation (``[X]`` or an inside tag, cf.
        ``--use_subword_tag``), anything else = padding.

        Returns:
            list: one ``[token_ids, label_ids, mask, pos, vm, tag,
            word_segment_ids]`` entry per accepted line (at most 100 with
            ``--dry_run``).
        """
        dataset = []
        with open(path, mode="r", encoding="utf8") as f:
            f.readline()  # Skip the header line.
            for line_id, line in enumerate(f):
                fields = line.strip().split("\t")
                if len(fields) not in (2, 3):
                    print(
                        f'The data is not in accepted format at line no:{line_id}.. Ignored'
                    )
                    continue
                labels, tokens = fields[0], fields[1]

                tokens, pos, vm, tag = kg.add_knowledge_with_vm(
                    args, [tokens], [labels])
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")
                tag = tag[0]

                num_tokens = sum(1 for tok in tokens
                                 if tok != tokenizer.pad_token)
                num_pad = len(tokens) - num_tokens

                labels = ([config.CLS_TOKEN] + labels.split(" ") +
                          [config.SEP_TOKEN])
                new_labels = []
                j = 0  # Index of the next original (tag 0) label.
                joiner = '-'
                # Reset per sentence so a leading subword never inherits
                # the previous sentence's entity type.
                prev_label = 'O'
                for i, token in enumerate(tokens):
                    if tag[i] == 0 and token != tokenizer.pad_token:
                        cur_type = labels[j]
                        j += 1
                        if cur_type == 'O':
                            prev_label = cur_type
                        else:
                            try:
                                joiner = cur_type[1]
                            except IndexError:
                                # Single-character labels carry no type.
                                logger.info(
                                    f'The label:{cur_type} is converted to O')
                                prev_label = 'O'
                                new_labels.append('O')
                                continue
                            prev_label = cur_type[2:]
                        new_labels.append(cur_type)
                    elif tag[i] == 1 and token != tokenizer.pad_token:
                        # Entity token added from the knowledge graph.
                        new_labels.append('[ENT]')
                    elif tag[i] == 2:
                        # Subword piece: continue the previous word's label.
                        if prev_label == 'O':
                            new_labels.append('O')
                        elif args.use_subword_tag:
                            new_labels.append('[X]')
                        else:
                            new_labels.append(f'I{joiner}' + prev_label)
                    else:
                        new_labels.append(PAD_TOKEN)

                new_labels = [labels_map[label] for label in new_labels]

                # NOTE(review): assumes pad tokens, if any, trail the
                # sequence -- TODO confirm against add_knowledge_with_vm.
                mask = [1] * num_tokens + [0] * num_pad
                word_segment_ids = [0] * len(tokens)

                tokens = tokenizer.convert_tokens_to_ids(tokens)
                assert len(tokens) == len(new_labels), \
                    "The length of token and label is not matching"

                dataset.append(
                    [tokens, new_labels, mask, pos, vm, tag, word_segment_ids])

                # Dry run: stop after 100 instances.
                if args.dry_run and len(dataset) >= 100:
                    break

        return dataset

    def evaluate(args, is_test, final=False):
        """Evaluate the model on the test (``is_test``) or dev set.

        Special [CLS]/[SEP] positions are dropped and KG labels are filtered
        via ``filter_kg_labels`` before scoring with seqeval and the span
        metrics. With ``final=True`` the label sequences are also appended,
        one sentence per line, to the prediction/gold output files.

        Returns:
            dict: f1/precision/recall (seqeval) and their ``_span`` variants.

        Raises:
            RuntimeError: if a predicted and a gold sequence differ in length.
        """
        dataset = read_dataset(args.test_path if is_test else args.dev_path)

        instances_num = len(dataset)
        batch_size = args.batch_size

        if is_test:
            logger.info(f"Batch size:{batch_size}")
            print(f"The number of test instances:{instances_num}")

        true_labels_all = []
        predicted_labels_all = []
        model.eval()

        test_batcher = Batcher(args,
                               dataset,
                               token_pad=tokenizer.pad_token_id,
                               label_pad=labels_map[PAD_TOKEN])

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for (input_ids_batch, label_ids_batch, mask_ids_batch,
                 pos_ids_batch, vm_ids_batch,
                 segment_ids_batch) in test_batcher:

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vm_ids_batch = vm_ids_batch.long().to(device)
                segment_ids_batch = segment_ids_batch.long().to(device)

                pred, _, _ = model(input_ids_batch,
                                   segment_ids_batch,
                                   mask_ids_batch,
                                   label_ids_batch,
                                   pos_ids_batch,
                                   vm_ids_batch,
                                   use_kg=args.use_kg)

                for pred_sample, gold_sample, mask in zip(
                        pred, label_ids_batch, mask_ids_batch):
                    pred_labels = [
                        idx_to_label.get(key) for key in pred_sample.tolist()
                    ]
                    gold_labels = [
                        idx_to_label.get(key) for key in gold_sample.tolist()
                    ]

                    num_labels = int(mask.sum())

                    # Exclude the [CLS], and [SEP] tokens
                    pred_labels = pred_labels[1:num_labels - 1]
                    true_labels = gold_labels[1:num_labels - 1]

                    pred_labels = [p.replace('_NOKG', '') for p in pred_labels]
                    true_labels = [t.replace('_NOKG', '') for t in true_labels]

                    true_labels, pred_labels = filter_kg_labels(
                        true_labels, pred_labels)

                    pred_labels = [p.replace('_', '-') for p in pred_labels]
                    true_labels = [t.replace('_', '-') for t in true_labels]

                    biluo_tags_predicted = get_bio(pred_labels)
                    biluo_tags_true = get_bio(true_labels)

                    if len(biluo_tags_predicted) != len(biluo_tags_true):
                        message = ('The length of the predicted labels is not '
                                   'same as that of true labels..')
                        logger.error(message)
                        raise RuntimeError(message)

                    predicted_labels_all.append(biluo_tags_predicted)
                    true_labels_all.append(biluo_tags_true)

        if final:
            # Files are appended to; end with a newline so that separate
            # runs do not merge their last and first lines.
            with open(f'{args.output_file_prefix}_predictions.txt', 'a',
                      encoding='utf-8') as p, \
                    open(f'{args.output_file_prefix}_gold.txt', 'a',
                         encoding='utf-8') as g:
                p.write(''.join(' '.join(l) + '\n'
                                for l in predicted_labels_all))
                g.write(''.join(' '.join(l) + '\n' for l in true_labels_all))

        return dict(
            f1=seqeval.metrics.f1_score(true_labels_all, predicted_labels_all),
            precision=seqeval.metrics.precision_score(true_labels_all,
                                                      predicted_labels_all),
            recall=seqeval.metrics.recall_score(true_labels_all,
                                                predicted_labels_all),
            f1_span=f1_score_span(true_labels_all, predicted_labels_all),
            precision_span=precision_score_span(true_labels_all,
                                                predicted_labels_all),
            recall_span=recall_score_span(true_labels_all,
                                          predicted_labels_all),
        )

    # Training phase.
    logger.info("Start training.")
    instances = read_dataset(args.train_path)

    instances_num = len(instances)
    batch_size = args.batch_size
    if instances_num == 0:
        # Without any batch the step counter never advances and the
        # training loop below would spin forever.
        raise ValueError(
            f"No training instances could be read from {args.train_path}.")

    if args.epochs_num:
        args.num_train_steps = int(
            instances_num * args.epochs_num / batch_size) + 1

    unfreeze_steps = 0
    model_frozen = False
    if args.freeze_proportions != 0.0:
        unfreeze_steps = int(
            args.num_train_steps * args.freeze_proportions) + 1
        logger.info(
            f'Two phase training is enabled with model unfreeze at:{unfreeze_steps}'
        )
        # freeze the model
        core_model.freeze()
        model_frozen = True

    logger.info(f"Batch size:{batch_size}")
    logger.info(f"The number of training instances:{instances_num}")

    train_batcher = Batcher(args,
                            instances,
                            token_pad=tokenizer.pad_token_id,
                            label_pad=labels_map[PAD_TOKEN])

    optimizer = create_optimizer(args, model)
    scheduler = create_scheduler(args, optimizer)
    total_loss = 0.
    best_f1 = 0.0

    def maybe_no_sync(micro_step):
        """Skip gradient sync on non-final micro-batches when supported."""
        if (hasattr(model, "no_sync")
                and (micro_step + 1) % args.gradient_accumulation_steps != 0):
            return model.no_sync()
        return contextlib.ExitStack()

    # YOU MUST LOG INTO WANDB WITH YOUR OWN ACCOUNT
    if args.wandb:
        import wandb
        wandb.init(project="kbert_pretrain")
        print(f'new args{args}')
    else:
        wandb = None

    global_steps = 0  # Optimizer updates performed.
    early_stop_steps = 0  # Optimizer updates since the last dev improvement.
    epoch = 0
    # Backward passes since training started. Counting across epochs keeps
    # accumulation groups intact at epoch boundaries and guarantees progress
    # even when an epoch has fewer batches than gradient_accumulation_steps.
    micro_steps = 0
    checkpoint_saved = False

    def training_stop_message():
        """Return the reason to stop training, or None to continue."""
        if global_steps >= args.num_train_steps:
            return 'The training is completed!'
        if early_stop_steps >= args.patience:
            return 'The early stopping is triggered!'
        return None

    stop_message = None
    with tqdm(total=args.num_train_steps) as pbar:
        while True:
            model.train()
            for (input_ids_batch, label_ids_batch, mask_ids_batch,
                 pos_ids_batch, vm_ids_batch,
                 segment_ids_batch) in train_batcher:

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vm_ids_batch = vm_ids_batch.long().to(device)
                segment_ids_batch = segment_ids_batch.long().to(device)

                # Called on the unwrapped module: DataParallel does not
                # expose `score`, so this step runs on a single device.
                loss, _ = core_model.score(input_ids_batch,
                                           segment_ids_batch,
                                           mask_ids_batch,
                                           label_ids_batch,
                                           pos_ids_batch,
                                           vm_ids_batch,
                                           use_kg=args.use_kg)

                if torch.cuda.device_count() > 1:
                    loss = torch.mean(loss)

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                with maybe_no_sync(micro_steps):
                    loss.backward()
                micro_steps += 1

                total_loss += loss.item()
                if micro_steps % args.gradient_accumulation_steps == 0:
                    if args.max_grad_norm != 0.0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    optimizer.zero_grad()

                    pbar.set_description("epoch: %d loss: %.7f" %
                                         (epoch, loss.item()))
                    pbar.update()
                    global_steps += 1

                    if global_steps % args.report_steps == 0:
                        avg_loss = total_loss / args.report_steps
                        logger.info("Epoch id: {}, Global Steps:{}, Avg loss: "
                                    "{:.10f}".format(epoch, global_steps,
                                                     avg_loss))

                        # Evaluation phase.
                        logger.info("Start evaluate on dev dataset.")
                        results = evaluate(args, False)
                        logger.info(results)

                        logger.info("Start evaluation on test dataset.")
                        results_test = evaluate(args, True)
                        logger.info(results_test)

                        if args.wandb:
                            wandb.log({
                                "steps": global_steps,
                                "train Loss": avg_loss,
                                "valid_acc": results['f1'],
                                "test_acc": results_test['f1'],
                                "learning_rate": args.learning_rate,
                                "batch_size": args.batch_size,
                                "lr_schedule": args.lr_schedule,
                                "weight_decay": args.weight_decay,
                                "max_grad_norm": args.max_grad_norm,
                            })

                        # Keep only the checkpoint with the best dev F1.
                        if results['f1'] > best_f1:
                            best_f1 = results['f1']
                            early_stop_steps = 0
                            save_model(model, args.output_model_path)
                            save_encoder(args,
                                         encoder,
                                         suffix=args.suffix_file_encoder)
                            checkpoint_saved = True
                        else:
                            early_stop_steps += args.report_steps

                        # evaluate() switched the model to eval mode.
                        model.train()
                        total_loss = 0.

                if model_frozen and global_steps >= unfreeze_steps:
                    # Second phase: train the encoder as well.
                    logger.info('The encoder is unfrozen for training.')
                    core_model.unfreeze()
                    model_frozen = False

                stop_message = training_stop_message()
                if stop_message is not None:
                    break

            if stop_message is not None:
                logger.info(stop_message)
                break

            epoch += 1

    # Final evaluation phase.
    logger.info("Final evaluation on test dataset.")
    if checkpoint_saved:
        core_model.load_state_dict(
            torch.load(args.output_model_path, map_location=device))
    else:
        # Loading here would pick up a stale file from an earlier run (or
        # fail if none exists), so evaluate the in-memory weights instead.
        logger.warning(
            'No checkpoint was saved during this run; evaluating the current '
            'model instead of loading %s.', args.output_model_path)
    results_final = evaluate(args, True, final=True)
    logger.info(results_final)
コード例 #2
0
ファイル: check_kg.py プロジェクト: patelrajnath/K-BERT
from brain.knowgraph_english import KnowledgeGraph

# Smoke-test the English knowledge graph: inject KG knowledge into one
# sentence and print the resulting tags, positions, tokens and visible matrix.
# Raw string: "\D" and "\e" are invalid escape sequences in a normal literal
# (SyntaxWarning on Python 3.12+); the path value is unchanged.
vocab_file = r"D:\Downloads\ent_vocab_custom"
kg = KnowledgeGraph(kg_file=vocab_file, predicate=True)
text = "Delhi is the capital of India ."
tokens, pos, vm, tag = kg.add_knowledge_with_vm([text],
                                                add_pad=False,
                                                max_length=16)
print(tag)
print(pos)
print(tokens)
print(vm)
コード例 #3
0
def main():
    """Fine-tune a LUKE encoder (optionally with KG injection) for tagging.

    Parses command-line options, builds the label map from the training file,
    loads the pretrained LUKE archive, and trains ``LukeTagger`` for
    ``--epochs_num`` epochs. After each epoch it evaluates on the dev and test
    sets and checkpoints the model/encoder whenever dev F1 improves. Finally it
    reloads the best checkpoint and runs a test evaluation that also writes
    ``<output_file_prefix>_{predictions,gold,text}.txt``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/tagger_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--output_encoder",
                        default="./luke-models/",
                        type=str,
                        help="Path of the output luke model.")
    parser.add_argument("--suffix_file_encoder",
                        default="encoder",
                        type=str,
                        help="output file suffix luke model.")
    parser.add_argument("--vocab_path",
                        default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path",
                        type=str,
                        required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/google_config.json",
                        type=str,
                        help="Path of the config file.")
    parser.add_argument("--output_file_prefix",
                        type=str,
                        required=True,
                        help="Prefix for file output.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=2,
                        help="Batch_size.")
    parser.add_argument("--seq_length",
                        default=256,
                        type=int,
                        help="Sequence length.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                              "cnn", "gatedcnn", "attn", \
                                              "rcnn", "crnn", "gpt", "bilstm"], \
                        default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")

    # Subword options.
    parser.add_argument("--subword_type",
                        choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path",
                        type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder",
                        choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num",
                        type=int,
                        default=2,
                        help="The number of subencoder layers.")

    # Optimizer options.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=5,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=2,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    # kg
    parser.add_argument("--kg_name", required=True, help="KG name or path")
    parser.add_argument("--use_kg",
                        action='store_true',
                        help="Enable the use of KG.")
    parser.add_argument("--dry_run",
                        action='store_true',
                        help="Dry run to test the implementation.")
    parser.add_argument(
        "--voting_choicer",
        action='store_true',
        help="Enable the Voting choicer to select the entity type.")
    parser.add_argument("--eval_kg_tag",
                        action='store_true',
                        help="Enable to include [ENT] tag in evaluation.")
    parser.add_argument("--use_subword_tag",
                        action='store_true',
                        help="Enable to use separate tag for subword splits.")
    parser.add_argument("--debug", action='store_true', help="Enable debug.")
    parser.add_argument("--reverse_order",
                        action='store_true',
                        help="Reverse the feature selection order.")
    parser.add_argument("--max_entities",
                        default=2,
                        type=int,
                        help="Number of KG features.")
    parser.add_argument("--eval_range_with_types",
                        action='store_true',
                        help="Enable to eval range with types.")

    args = parser.parse_args()

    # Load the hyperparameters of the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    labels_map = {"[PAD]": 0, "[ENT]": 1, "[X]": 2, "[CLS]": 3, "[SEP]": 4}
    # Ids of labels that open an entity span (labels starting with "B" or "S").
    begin_ids = []

    # Find tagging labels. The first line of the training file is a header;
    # the labels are the space-separated first tab field of every other line.
    with open(args.train_path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id == 0:
                continue
            labels = line.strip().split("\t")[0].split()
            for l in labels:
                if l not in labels_map:
                    if l.startswith("B") or l.startswith("S"):
                        begin_ids.append(len(labels_map))
                    labels_map[l] = len(labels_map)

    idx_to_label = {labels_map[key]: key for key in labels_map}

    print(begin_ids)
    print("Labels: ", labels_map)
    args.labels_num = len(labels_map)

    # Build knowledge graph.
    if args.kg_name == 'none':
        kg_file = []
    else:
        kg_file = args.kg_name

    # Load Luke model.
    model_archive = ModelArchive.load(args.pretrained_model_path)
    tokenizer = model_archive.tokenizer

    # Handling space character in roberta tokenizer: map the byte-level
    # unicode symbols back to raw bytes when reconstructing text.
    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}

    # Load the pretrained model
    encoder = LukeModel(model_archive.config)
    encoder.load_state_dict(model_archive.state_dict, strict=False)

    # Build sequence labeling model.
    model = LukeTagger(args, encoder)
    kg = KnowledgeGraph(kg_file=kg_file, tokenizer=tokenizer)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    # Datset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids,
                     vm_ids, tag_ids, segment_ids):
        """Yield aligned row slices of ``batch_size`` from every tensor.

        Full batches come first; a final smaller batch holds any remainder.
        """
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i * batch_size:(i + 1) * batch_size, :]
            label_ids_batch = label_ids[i * batch_size:(i + 1) * batch_size, :]
            mask_ids_batch = mask_ids[i * batch_size:(i + 1) * batch_size, :]
            pos_ids_batch = pos_ids[i * batch_size:(i + 1) * batch_size, :]
            vm_ids_batch = vm_ids[i * batch_size:(i + 1) * batch_size, :, :]
            tag_ids_batch = tag_ids[i * batch_size:(i + 1) * batch_size, :]
            segment_ids_batch = segment_ids[i * batch_size:(i + 1) *
                                            batch_size, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vm_ids_batch, tag_ids_batch, segment_ids_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num // batch_size *
                                        batch_size:, :]
            label_ids_batch = label_ids[instances_num // batch_size *
                                        batch_size:, :]
            mask_ids_batch = mask_ids[instances_num // batch_size *
                                      batch_size:, :]
            pos_ids_batch = pos_ids[instances_num // batch_size *
                                    batch_size:, :]
            vm_ids_batch = vm_ids[instances_num // batch_size *
                                  batch_size:, :, :]
            tag_ids_batch = tag_ids[instances_num // batch_size *
                                    batch_size:, :]
            segment_ids_batch = segment_ids[instances_num // batch_size *
                                            batch_size:, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vm_ids_batch, tag_ids_batch, segment_ids_batch

    # Read dataset.
    def read_dataset(path):
        """Read a tab-separated tagging file into model-ready feature lists.

        The first line is skipped as a header. Every other line must have 2
        or 3 tab-separated fields: space-separated labels, then the text
        (a third field is ignored). Lines with any other field count are
        reported and skipped. Returns a list of
        ``[token_ids, label_ids, mask, pos, vm, tag, segment_ids]`` entries,
        stopping after 100 entries when ``--dry_run`` is set.
        """
        dataset = []
        count = 0
        with open(path, mode="r", encoding="utf8") as f:
            f.readline()
            tokens, labels = [], []
            for line_id, line in enumerate(f):
                fields = line.strip().split("\t")
                if len(fields) == 2:
                    labels, tokens = fields
                elif len(fields) == 3:
                    labels, tokens, _ = fields
                else:
                    print(
                        f'The data is not in accepted format at line no:{line_id}.. Ignored'
                    )
                    continue

                tokens, pos, vm, tag = \
                    kg.add_knowledge_with_vm([tokens], [labels],
                                             use_kg=args.use_kg,
                                             max_length=args.seq_length,
                                             max_entities=args.max_entities,
                                             reverse_order=args.reverse_order)
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")
                tag = tag[0]

                non_pad_tokens = [
                    tok for tok in tokens if tok != tokenizer.pad_token
                ]
                num_tokens = len(non_pad_tokens)
                num_pad = len(tokens) - num_tokens

                labels = [config.CLS_TOKEN
                          ] + labels.split(" ") + [config.SEP_TOKEN]
                # Align labels with the KG-expanded token sequence, judging
                # each position by its tag: tag 0 consumes the next original
                # label, tag 1 is an injected entity, tag 2 continues the
                # previous label, and anything else is padding.
                new_labels = []
                j = 0
                joiner = '-'
                for i in range(len(tokens)):
                    if tag[i] == 0 and tokens[i] != tokenizer.pad_token:
                        cur_type = labels[j]
                        new_labels.append(cur_type)
                        if cur_type != 'O':
                            joiner = cur_type[1]
                            prev_label = cur_type[2:]
                        else:
                            prev_label = cur_type
                        j += 1
                    elif tag[i] == 1 and tokens[
                            i] != tokenizer.pad_token:  # an injected KG entity
                        new_labels.append('[ENT]')
                    elif tag[i] == 2:
                        if prev_label == 'O':
                            new_labels.append('O')
                        else:
                            if args.use_subword_tag:
                                new_labels.append('[X]')
                            else:
                                new_labels.append(f'I{joiner}' + prev_label)
                    else:
                        new_labels.append(PAD_TOKEN)

                new_labels = [labels_map[l] for l in new_labels]

                mask = [1] * (num_tokens) + [0] * num_pad
                word_segment_ids = [0] * (len(tokens))

                tokens = tokenizer.convert_tokens_to_ids(tokens)
                assert len(tokens) == len(new_labels), \
                    "The length of token and label is not matching"

                dataset.append(
                    [tokens, new_labels, mask, pos, vm, tag, word_segment_ids])

                # Enable dry rune
                if args.dry_run:
                    count += 1
                    if count == 100:
                        break

        return dataset

    # Evaluation function.
    def evaluate(args, is_test, final=False):
        """Evaluate on the test (``is_test``) or dev set and return span F1.

        An entity span opens at a begin label and spans are matched exactly
        on their inclusive ``(start, end)`` positions. With
        ``--eval_range_with_types`` a typed precision/recall/F1 is printed as
        well. When ``final`` is set, per-sequence predictions, gold labels
        and decoded text are written to
        ``<output_file_prefix>_{predictions,gold,text}.txt``. Returns 0 when
        the F1 computation divides by zero.
        """
        if is_test:
            dataset = read_dataset(args.test_path)
        else:
            dataset = read_dataset(args.dev_path)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])
        pos_ids = torch.LongTensor([sample[3] for sample in dataset])
        vm_ids = torch.BoolTensor([sample[4] for sample in dataset])
        tag_ids = torch.LongTensor([sample[5] for sample in dataset])
        segment_ids = torch.LongTensor([sample[6] for sample in dataset])

        instances_num = input_ids.size(0)
        batch_size = args.batch_size

        if is_test:
            print("Batch size: ", batch_size)
            print("The number of test instances:", instances_num)

        correct = 0
        correct_with_type = 0
        gold_entities_num = 0
        pred_entities_num = 0

        confusion = torch.zeros(len(labels_map),
                                len(labels_map),
                                dtype=torch.long)

        model.eval()

        if final:
            # The per-batch writes below append, so start from empty files;
            # otherwise output from earlier runs with this prefix accumulates.
            for output_name in ('predictions', 'gold', 'text'):
                with open(f'{args.output_file_prefix}_{output_name}.txt',
                          'w',
                          encoding='utf-8'):
                    pass

        for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                pos_ids_batch, vm_ids_batch, tag_ids_batch,
                segment_ids_batch) in enumerate(
                    batch_loader(batch_size, input_ids, label_ids, mask_ids,
                                 pos_ids, vm_ids, tag_ids, segment_ids)):

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            tag_ids_batch = tag_ids_batch.to(device)
            vm_ids_batch = vm_ids_batch.long().to(device)
            segment_ids_batch = segment_ids_batch.long().to(device)

            # Inference only: don't build an autograd graph during evaluation.
            with torch.no_grad():
                _, _, pred, gold, _ = model(input_ids_batch,
                                            segment_ids_batch,
                                            mask_ids_batch,
                                            label_ids_batch,
                                            pos_ids_batch,
                                            vm_ids_batch,
                                            use_kg=args.use_kg)

            if final:
                # Explicit UTF-8: the decoded text may contain non-ASCII
                # characters that the platform default encoding rejects.
                with open(f'{args.output_file_prefix}_predictions.txt', 'a',
                          encoding='utf-8') as pred_file, \
                        open(f'{args.output_file_prefix}_gold.txt', 'a',
                             encoding='utf-8') as gold_file, \
                        open(f'{args.output_file_prefix}_text.txt', 'a',
                             encoding='utf-8') as text_file:
                    predicted_labels = [
                        idx_to_label.get(key) for key in pred.tolist()
                    ]
                    gold_labels = [
                        idx_to_label.get(key) for key in gold.tolist()
                    ]

                    num_tokens = len(predicted_labels)
                    mask_ids_batch = mask_ids_batch.view(-1, num_tokens)
                    masks = mask_ids_batch.tolist()[0]
                    input_ids_batch = input_ids_batch.view(-1, num_tokens)
                    tokens = input_ids_batch.tolist()[0]

                    # Split the flattened batch back into seq_length chunks.
                    for start_idx in range(0, num_tokens, args.seq_length):
                        pred_sample = predicted_labels[start_idx:start_idx +
                                                       args.seq_length]
                        gold_sample = gold_labels[start_idx:start_idx +
                                                  args.seq_length]
                        mask = masks[start_idx:start_idx + args.seq_length]
                        num_labels = sum(mask)

                        token_sample = tokens[start_idx:start_idx +
                                              args.seq_length]
                        token_sample = token_sample[:num_labels]
                        text = ''.join(
                            tokenizer.convert_ids_to_tokens(token_sample))
                        text = bytearray([byte_decoder[c]
                                          for c in text]).decode('utf-8')

                        pred_file.write(' '.join(pred_sample[:num_labels]) +
                                        '\n')
                        gold_file.write(' '.join(gold_sample[:num_labels]) +
                                        '\n')
                        text_file.write(text + '\n')

            for j in range(gold.size()[0]):
                if gold[j].item() in begin_ids:
                    gold_entities_num += 1

            for j in range(pred.size()[0]):
                if pred[j].item(
                ) in begin_ids and gold[j].item() != labels_map["[PAD]"]:
                    pred_entities_num += 1

            pred_entities_pos = []
            pred_entities_pos_with_type = []
            gold_entities_pos = []
            gold_entities_pos_with_type = []
            start, end = 0, 0

            # Gold spans: from a begin label up to (inclusive) the position
            # before the next [PAD], "O" or begin label; [X]/[ENT] positions
            # are skipped over.
            for j in range(gold.size()[0]):
                if gold[j].item() in begin_ids:
                    start = j
                    for k in range(j + 1, gold.size()[0]):
                        if gold[k].item() == labels_map['[X]'] or gold[k].item(
                        ) == labels_map['[ENT]']:
                            continue

                        if gold[k].item(
                        ) == labels_map["[PAD]"] or gold[k].item(
                        ) == labels_map["O"] or gold[k].item() in begin_ids:
                            end = k - 1
                            break
                    else:
                        end = gold.size()[0] - 1
                    if args.eval_range_with_types:
                        ent_type_gold = idx_to_label.get(gold[start].item())
                        ent_type_gold = ent_type_gold.replace('_NOKG', '')
                        gold_entities_pos_with_type.append(
                            (start, end, ent_type_gold))

                    gold_entities_pos.append((start, end))

            # Predicted spans, built the same way from pred.
            for j in range(pred.size()[0]):
                if pred[j].item() in begin_ids and gold[j].item() != labels_map["[PAD]"] and gold[j].item() != \
                        labels_map["[ENT]"] and gold[j].item() != labels_map["[X]"]:
                    start = j
                    for k in range(j + 1, pred.size()[0]):

                        if pred[k].item() == labels_map['[X]'] or gold[k].item(
                        ) == labels_map['[ENT]']:
                            continue

                        if pred[k].item(
                        ) == labels_map["[PAD]"] or pred[k].item(
                        ) == labels_map["O"] or pred[k].item() in begin_ids:
                            end = k - 1
                            break
                    else:
                        end = pred.size()[0] - 1

                    if args.eval_range_with_types:
                        # Get all the labels in the inclusive range
                        # [start, end] (a single label when start == end).
                        entity_types = [
                            idx_to_label.get(l.item())
                            for l in pred[start:end + 1]
                        ]

                        # Run voting choicer
                        final_entity_type = voting_choicer(entity_types)
                        final_entity_type = final_entity_type.replace(
                            '_NOKG', '')

                        if final:
                            logger.info(
                                f'Predicted: {" ".join(entity_types)}, Selected: {final_entity_type}'
                            )
                        if args.voting_choicer:
                            # Convert back to label id and add in the tuple
                            pred_entities_pos_with_type.append(
                                (start, end, final_entity_type))
                        else:
                            # Use the first prediction
                            ent_type_pred = idx_to_label.get(
                                pred[start].item())
                            ent_type_pred = ent_type_pred.replace('_NOKG', '')
                            pred_entities_pos_with_type.append(
                                (start, end, ent_type_pred))

                    pred_entities_pos.append((start, end))

            correct += sum(1 for entity in pred_entities_pos
                           if entity in gold_entities_pos)
            if args.eval_range_with_types:
                correct_with_type += sum(
                    1 for entity in pred_entities_pos_with_type
                    if entity in gold_entities_pos_with_type)

        try:
            print("Report precision, recall, and f1:")
            p = correct / pred_entities_num
            r = correct / gold_entities_num
            f1 = 2 * p * r / (p + r)
            print("{:.3f}, {:.3f}, {:.3f}".format(p, r, f1))

            if args.eval_range_with_types:
                try:
                    print(
                        "Report accuracy with type, precision, recall, and f1:"
                    )
                    p_with_type = correct_with_type / pred_entities_num
                    r_with_type = correct_with_type / gold_entities_num
                    f1_with_type = 2 * p_with_type * r_with_type / (
                        p_with_type + r_with_type)
                    print("{:.3f}, {:.3f}, {:.3f}".format(
                        p_with_type, r_with_type, f1_with_type))
                except ZeroDivisionError:
                    # No typed matches: the typed report is skipped, but the
                    # untyped F1 is still returned.
                    pass
            return f1
        except ZeroDivisionError:
            return 0

    # Training phase.
    print("Start training.")
    instances = read_dataset(args.train_path)

    input_ids = torch.LongTensor([ins[0] for ins in instances])
    label_ids = torch.LongTensor([ins[1] for ins in instances])
    mask_ids = torch.LongTensor([ins[2] for ins in instances])
    pos_ids = torch.LongTensor([ins[3] for ins in instances])
    vm_ids = torch.BoolTensor([ins[4] for ins in instances])
    tag_ids = torch.LongTensor([ins[5] for ins in instances])
    segment_ids = torch.LongTensor([ins[6] for ins in instances])

    instances_num = input_ids.size(0)
    batch_size = args.batch_size
    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    train_batcher = Batcher(batch_size, input_ids, label_ids, mask_ids,
                            pos_ids, vm_ids, tag_ids, segment_ids)

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    # Exclude bias and normalization parameters from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup,
                         t_total=train_steps)

    total_loss = 0.
    f1 = 0.0
    # Start below any reachable F1 (evaluate returns >= 0) so the first epoch
    # always writes a checkpoint; otherwise a run whose dev F1 stays 0 would
    # make the final torch.load fail or load a stale model from an older run.
    best_f1 = float("-inf")

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                pos_ids_batch, vm_ids_batch, tag_ids_batch,
                segment_ids_batch) in enumerate(train_batcher):
            model.zero_grad()

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            tag_ids_batch = tag_ids_batch.to(device)
            vm_ids_batch = vm_ids_batch.long().to(device)
            segment_ids_batch = segment_ids_batch.long().to(device)

            loss, _, _, _, _ = model(input_ids_batch,
                                     segment_ids_batch,
                                     mask_ids_batch,
                                     label_ids_batch,
                                     pos_ids_batch,
                                     vm_ids_batch,
                                     use_kg=args.use_kg)

            # DataParallel returns one loss per GPU; reduce to a scalar.
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.

            loss.backward()
            optimizer.step()

        # Evaluation phase.
        print("Start evaluate on dev dataset.")
        f1 = evaluate(args, False)
        print("Start evaluation on test dataset.")
        evaluate(args, True)

        if f1 > best_f1:
            best_f1 = f1
            save_model(model, args.output_model_path)
            save_encoder(args, encoder, suffix=args.suffix_file_encoder)

    # Evaluation phase.
    print("Final evaluation on test dataset.")

    if torch.cuda.device_count() > 1:
        model.module.load_state_dict(torch.load(args.output_model_path))
    else:
        model.load_state_dict(torch.load(args.output_model_path))

    evaluate(args, True, final=True)
コード例 #4
0
def main():
    """Fine-tune a LUKE-based sequence tagger with optional KG injection.

    Parses command-line options, builds the label vocabulary from the
    train/dev/test files, loads the pretrained LUKE encoder and the
    knowledge graph, trains the selected tagger, evaluates on the dev and
    test sets after every epoch, keeps the checkpoint with the best dev F1
    and finally re-evaluates that checkpoint on the test set, appending
    predictions and gold labels to ``<output_file_prefix>_predictions.txt``
    and ``<output_file_prefix>_gold.txt``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/tagger_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--output_encoder",
                        default="./luke-models/",
                        type=str,
                        help="Path of the output luke model.")
    parser.add_argument("--suffix_file_encoder",
                        default="encoder",
                        type=str,
                        help="output file suffix luke model.")
    parser.add_argument("--vocab_path",
                        default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path",
                        type=str,
                        required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/google_config.json",
                        type=str,
                        help="Path of the config file.")
    parser.add_argument("--output_file_prefix",
                        type=str,
                        required=True,
                        help="Prefix for file output.")
    parser.add_argument("--log_file", default='app.log')

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=2,
                        help="Batch_size.")
    parser.add_argument("--seq_length",
                        default=256,
                        type=int,
                        help="Sequence length.")
    parser.add_argument("--classifier",
                        choices=["mlp", "lstm", "lstm_crf", "lstm_ncrf"],
                        default="mlp",
                        help="Classifier type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument('--freeze_encoder_weights',
                        action='store_true',
                        help="Enable to freeze the encoder weigths.")

    # Subword options.
    parser.add_argument("--subword_type",
                        choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path",
                        type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder",
                        choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num",
                        type=int,
                        default=2,
                        help="The number of subencoder layers.")

    # Optimizer options.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--schedule_lr",
                        action='store_true',
                        help="Enable to use lr scheduler.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=5,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=2,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=35, help="Random seed.")

    # kg
    parser.add_argument("--kg_name", required=True, help="KG name or path")
    parser.add_argument("--use_kg",
                        action='store_true',
                        help="Enable the use of KG.")
    parser.add_argument("--dry_run",
                        action='store_true',
                        help="Dry run to test the implementation.")
    parser.add_argument(
        "--voting_choicer",
        action='store_true',
        help="Enable the Voting choicer to select the entity type.")
    parser.add_argument("--eval_kg_tag",
                        action='store_true',
                        help="Enable to include [ENT] tag in evaluation.")
    parser.add_argument("--use_subword_tag",
                        action='store_true',
                        help="Enable to use separate tag for subword splits.")
    parser.add_argument("--debug", action='store_true', help="Enable debug.")
    parser.add_argument("--reverse_order",
                        action='store_true',
                        help="Reverse the feature selection order.")
    parser.add_argument("--max_entities",
                        default=2,
                        type=int,
                        help="Number of KG features.")
    parser.add_argument("--eval_range_with_types",
                        action='store_true',
                        help="Enable to eval range with types.")

    args = parser.parse_args()

    # Load the hyperparameters of the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    logging.basicConfig(filename=args.log_file, filemode='w', format=fmt)

    # Special labels get fixed ids; dataset labels are appended after them.
    labels_map = {"[PAD]": 0, "[ENT]": 1, "[X]": 2, "[CLS]": 3, "[SEP]": 4}
    # Ids of begin/single ("B*"/"S*") labels; only reported below.
    begin_ids = []

    # Collect every tagging label from all three splits. The first line of
    # each file is a header; labels are the space-separated first
    # tab-separated field.
    for file in (args.train_path, args.dev_path, args.test_path):
        with open(file, mode="r", encoding="utf-8") as f:
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                labels = line.strip().split("\t")[0].split()
                for label in labels:
                    if label in labels_map:
                        continue
                    # A begin/single label such as "B-PER" also registers
                    # its inside label ("I-PER") so subword continuations can
                    # be labelled even if that inside label never occurs in
                    # the data. One-character labels carry no joiner/type, so
                    # they are mapped as-is (indexing label[1] would fail).
                    is_begin = label.startswith(("B", "S")) and len(label) > 1
                    if is_begin:
                        infix = label[1]
                        tag = label[2:]
                        inner_tag = f'I{infix}{tag}'
                        if inner_tag not in labels_map:
                            labels_map[inner_tag] = len(labels_map)

                    labels_map[label] = len(labels_map)
                    if is_begin:
                        # Record the id of the begin label itself (not of the
                        # inside label that may have just been inserted).
                        begin_ids.append(labels_map[label])

    idx_to_label = {idx: label for label, idx in labels_map.items()}

    print(begin_ids)
    print("Labels: ", labels_map)
    args.labels_num = len(labels_map)

    # Build knowledge graph ('none' disables it with an empty KG).
    kg_file = [] if args.kg_name == 'none' else args.kg_name

    # Load Luke model.
    model_archive = ModelArchive.load(args.pretrained_model_path)
    tokenizer = model_archive.tokenizer

    # Load the pretrained model
    encoder = LukeModel(model_archive.config)
    encoder.load_state_dict(model_archive.state_dict, strict=False)

    kg = KnowledgeGraph(kg_file=kg_file, tokenizer=tokenizer)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    # Build sequence labeling model.
    classifiers = {
        "mlp": LukeTaggerMLP,
        "lstm": LukeTaggerLSTM,
        "lstm_crf": LukeTaggerLSTMCRF,
        "lstm_ncrf": LukeTaggerLSTMNCRF
    }
    logger.info(f'The selected classifier is:{classifiers[args.classifier]}')
    model = classifiers[args.classifier](args, encoder)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model = model.to(device)

    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids,
                     vm_ids, tag_ids, segment_ids):
        """Yield consecutive mini-batches along dim 0.

        Every batch holds ``batch_size`` rows except possibly the last one,
        which holds the remainder.
        """
        instances_num = input_ids.size()[0]
        for start in range(0, instances_num, batch_size):
            end = start + batch_size
            yield (input_ids[start:end, :], label_ids[start:end, :],
                   mask_ids[start:end, :], pos_ids[start:end, :],
                   vm_ids[start:end, :, :], tag_ids[start:end, :],
                   segment_ids[start:end, :])

    # Read dataset.
    def read_dataset(path):
        """Read a tab-separated tagging file and convert it to model inputs.

        The first line is a header and is skipped. Every other line holds
        ``labels<TAB>tokens`` (an optional third field is ignored); lines in
        any other shape are reported and skipped. Labels are re-aligned to
        the KG-expanded token sequence returned by
        ``kg.add_knowledge_with_vm``.

        Returns:
            A list of ``[token_ids, label_ids, mask, pos, vm, tag,
            segment_ids]`` entries. With ``--dry_run`` at most 100 entries.

        Raises:
            ValueError: if the aligned labels do not match the token count.
        """
        dataset = []
        count = 0
        with open(path, mode="r", encoding="utf8") as f:
            f.readline()  # Skip the header line.
            # start=2: report 1-based line numbers of the file itself.
            for line_id, line in enumerate(f, start=2):
                fields = line.strip().split("\t")
                if len(fields) == 2:
                    labels, tokens = fields
                elif len(fields) == 3:
                    labels, tokens, _ = fields
                else:
                    print(
                        f'The data is not in accepted format at line no:{line_id}.. Ignored'
                    )
                    continue

                tokens, pos, vm, tag = \
                    kg.add_knowledge_with_vm([tokens], [labels],
                                             use_kg=args.use_kg,
                                             max_length=args.seq_length,
                                             max_entities=args.max_entities,
                                             reverse_order=args.reverse_order)
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")
                tag = tag[0]

                non_pad_tokens = [
                    tok for tok in tokens if tok != tokenizer.pad_token
                ]
                num_tokens = len(non_pad_tokens)
                num_pad = len(tokens) - num_tokens

                labels = [config.CLS_TOKEN
                          ] + labels.split(" ") + [config.SEP_TOKEN]
                new_labels = []
                j = 0
                joiner = '-'
                # Reset per line so a subword token can never reuse the
                # previous sentence's label.
                prev_label = 'O'
                # tag[i] semantics as used below: 0 = original token (takes
                # the next source label), 1 = KG-injected token, 2 = subword
                # continuation of the previous token.
                for i in range(len(tokens)):
                    if tag[i] == 0 and tokens[i] != tokenizer.pad_token:
                        cur_type = labels[j]
                        if cur_type != 'O':
                            try:
                                joiner = cur_type[1]
                            except IndexError:
                                # Label too short to carry "<prefix><joiner>".
                                logger.info(
                                    f'The label:{cur_type} is converted to O')
                                prev_label = 'O'
                                j += 1
                                new_labels.append('O')
                                continue
                            prev_label = cur_type[2:]
                        else:
                            prev_label = cur_type

                        new_labels.append(cur_type)
                        j += 1

                    elif tag[i] == 1 and tokens[
                            i] != tokenizer.pad_token:  # KG-injected entity
                        new_labels.append('[ENT]')
                    elif tag[i] == 2:
                        if prev_label == 'O':
                            new_labels.append('O')
                        elif args.use_subword_tag:
                            new_labels.append('[X]')
                        else:
                            new_labels.append(f'I{joiner}' + prev_label)
                    else:
                        new_labels.append(PAD_TOKEN)

                new_labels = [labels_map[label] for label in new_labels]

                mask = [1] * num_tokens + [0] * num_pad
                word_segment_ids = [0] * len(tokens)

                tokens = tokenizer.convert_tokens_to_ids(tokens)
                # Explicit check: `assert` would be stripped under `python -O`.
                if len(tokens) != len(new_labels):
                    raise ValueError(
                        "The length of token and label is not matching")

                dataset.append(
                    [tokens, new_labels, mask, pos, vm, tag, word_segment_ids])

                # Enable dry run
                if args.dry_run:
                    count += 1
                    if count == 100:
                        break

        return dataset

    # Evaluation function.
    def evaluate(args, is_test, final=False):
        """Evaluate the current model on the dev or test split.

        Args:
            args: parsed command-line options.
            is_test: evaluate on ``args.test_path`` if True, else on
                ``args.dev_path``.
            final: if True, also append the predicted and gold tag
                sequences to ``<output_file_prefix>_predictions.txt`` and
                ``<output_file_prefix>_gold.txt``.

        Returns:
            dict with seqeval ``f1``/``precision``/``recall`` and the span
            variants ``f1_span``/``precision_span``/``recall_span``.

        Raises:
            RuntimeError: if predicted and gold sequences differ in length.
        """
        if is_test:
            dataset = read_dataset(args.test_path)
        else:
            dataset = read_dataset(args.dev_path)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])
        pos_ids = torch.LongTensor([sample[3] for sample in dataset])
        vm_ids = torch.BoolTensor([sample[4] for sample in dataset])
        tag_ids = torch.LongTensor([sample[5] for sample in dataset])
        segment_ids = torch.LongTensor([sample[6] for sample in dataset])

        instances_num = input_ids.size(0)
        batch_size = args.batch_size

        if is_test:
            logger.info(f"Batch size:{batch_size}")
            print(f"The number of test instances:{instances_num}")

        true_labels_all = []
        predicted_labels_all = []
        model.eval()

        # Inference only: skip building autograd graphs to save memory/time.
        with torch.no_grad():
            for (input_ids_batch, label_ids_batch, mask_ids_batch,
                 pos_ids_batch, vm_ids_batch, tag_ids_batch,
                 segment_ids_batch) in batch_loader(batch_size, input_ids,
                                                    label_ids, mask_ids,
                                                    pos_ids, vm_ids, tag_ids,
                                                    segment_ids):

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                tag_ids_batch = tag_ids_batch.to(device)
                vm_ids_batch = vm_ids_batch.long().to(device)
                segment_ids_batch = segment_ids_batch.long().to(device)

                pred = model(input_ids_batch,
                             segment_ids_batch,
                             mask_ids_batch,
                             label_ids_batch,
                             pos_ids_batch,
                             vm_ids_batch,
                             use_kg=args.use_kg)

                for pred_sample, gold_sample, mask in zip(
                        pred, label_ids_batch, mask_ids_batch):

                    pred_labels = [
                        idx_to_label.get(key) for key in pred_sample.tolist()
                    ]
                    gold_labels = [
                        idx_to_label.get(key) for key in gold_sample.tolist()
                    ]

                    num_labels = int(mask.sum())

                    # Exclude the [CLS], and [SEP] tokens
                    pred_labels = pred_labels[1:num_labels - 1]
                    true_labels = gold_labels[1:num_labels - 1]

                    pred_labels = [p.replace('_NOKG', '') for p in pred_labels]
                    true_labels = [t.replace('_NOKG', '') for t in true_labels]

                    true_labels, pred_labels = filter_kg_labels(
                        true_labels, pred_labels)

                    pred_labels = [p.replace('_', '-') for p in pred_labels]
                    true_labels = [t.replace('_', '-') for t in true_labels]

                    biluo_tags_predicted = get_bio(pred_labels)
                    biluo_tags_true = get_bio(true_labels)

                    if len(biluo_tags_predicted) != len(biluo_tags_true):
                        message = ('The length of the predicted labels is '
                                   'not same as that of true labels..')
                        logger.error(message)
                        # Previously `exit()`, which ended the run with
                        # status 0 and hid the failure.
                        raise RuntimeError(message)

                    predicted_labels_all.append(biluo_tags_predicted)
                    true_labels_all.append(biluo_tags_true)

        if final:
            with open(f'{args.output_file_prefix}_predictions.txt', 'a') as p, \
                    open(f'{args.output_file_prefix}_gold.txt', 'a') as g:
                # Trailing newline keeps later appends on their own lines.
                p.write('\n'.join([' '.join(l) for l in predicted_labels_all])
                        + '\n')
                g.write('\n'.join([' '.join(l) for l in true_labels_all])
                        + '\n')

        return dict(
            f1=seqeval.metrics.f1_score(true_labels_all, predicted_labels_all),
            precision=seqeval.metrics.precision_score(true_labels_all,
                                                      predicted_labels_all),
            recall=seqeval.metrics.recall_score(true_labels_all,
                                                predicted_labels_all),
            f1_span=f1_score_span(true_labels_all, predicted_labels_all),
            precision_span=precision_score_span(true_labels_all,
                                                predicted_labels_all),
            recall_span=recall_score_span(true_labels_all,
                                          predicted_labels_all),
        )

    # Training phase.
    logger.info("Start training.")
    instances = read_dataset(args.train_path)

    input_ids = torch.LongTensor([ins[0] for ins in instances])
    label_ids = torch.LongTensor([ins[1] for ins in instances])
    mask_ids = torch.LongTensor([ins[2] for ins in instances])
    pos_ids = torch.LongTensor([ins[3] for ins in instances])
    vm_ids = torch.BoolTensor([ins[4] for ins in instances])
    tag_ids = torch.LongTensor([ins[5] for ins in instances])
    segment_ids = torch.LongTensor([ins[6] for ins in instances])

    instances_num = input_ids.size(0)
    batch_size = args.batch_size
    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    train_batcher = Batcher(batch_size, input_ids, label_ids, mask_ids,
                            pos_ids, vm_ids, tag_ids, segment_ids)

    logger.info(f"Batch size:{batch_size}")
    logger.info(f"The number of training instances:{instances_num}")

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup,
                         t_total=train_steps)
    # T_max is measured in epochs, so this scheduler is stepped once per
    # epoch (see the end of the epoch loop), not once per batch.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                               T_max=args.epochs_num)

    # nn.DataParallel only forwards `forward`; custom methods such as
    # `score` must be reached through the wrapped module.
    score_fn = (model.module.score
                if isinstance(model, nn.DataParallel) else model.score)

    total_loss = 0.
    # -inf guarantees the first epoch's checkpoint is saved, so the final
    # reload below never fails or picks up a stale file from another run.
    best_f1 = float('-inf')

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                pos_ids_batch, vm_ids_batch, tag_ids_batch,
                segment_ids_batch) in enumerate(train_batcher):
            model.zero_grad()

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            tag_ids_batch = tag_ids_batch.to(device)
            vm_ids_batch = vm_ids_batch.long().to(device)
            segment_ids_batch = segment_ids_batch.long().to(device)

            loss = score_fn(input_ids_batch,
                            segment_ids_batch,
                            mask_ids_batch,
                            label_ids_batch,
                            pos_ids_batch,
                            vm_ids_batch,
                            use_kg=args.use_kg)

            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()

            if (i + 1) % args.report_steps == 0:
                logger.info(
                    "Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                    format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.
            loss.backward()
            optimizer.step()

        if args.schedule_lr:
            # Update learning rate (once per epoch, matching T_max).
            scheduler.step()

        # Evaluation phase.
        logger.info("Start evaluate on dev dataset.")
        results = evaluate(args, False)
        logger.info(results)

        logger.info("Start evaluation on test dataset.")
        results_test = evaluate(args, True)
        logger.info(results_test)

        if results['f1'] > best_f1:
            best_f1 = results['f1']
            save_model(model, args.output_model_path)
            save_encoder(args, encoder, suffix=args.suffix_file_encoder)

    # Evaluation phase.
    logger.info("Final evaluation on test dataset.")
    if torch.cuda.device_count() > 1:
        model.module.load_state_dict(torch.load(args.output_model_path))
    else:
        model.load_state_dict(torch.load(args.output_model_path))
    results_final = evaluate(args, True, final=True)
    logger.info(results_final)