Пример #1
0
def main():
    """Fine-tune a machine reading comprehension model and evaluate it.

    Steps: parse command-line options, load the vocabulary, build a
    ``MachineReadingComprehension`` model, train it on ``--train_path``,
    evaluate on ``--dev_path`` after every epoch (saving to
    ``--output_model_path`` whenever the dev score improves and stopping at
    the first epoch that does not improve), then, if ``--test_path`` is
    given, reload the best checkpoint and evaluate it on the test set.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/cmrc_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path",
                        type=str,
                        required=True,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/bert_base_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=64,
                        help="Batch size.")
    parser.add_argument("--seq_length",
                        type=int,
                        default=100,
                        help="Sequence length.")
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument("--embedding",
                        choices=["bert", "word"],
                        default="bert",
                        help="Embedding type.")
    parser.add_argument("--encoder",
                        choices=["bert", "lstm", "gru",
                                 "cnn", "gatedcnn", "attn", "synt",
                                 "rcnn", "crnn", "gpt", "bilstm"],
                        default="bert",
                        help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument("--factorized_embedding_parameterization",
                        action="store_true",
                        help="Factorized embedding parameterization.")
    parser.add_argument("--parameter_sharing",
                        action="store_true",
                        help="Parameter sharing.")

    # Optimizer options.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=3e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")
    parser.add_argument(
        "--fp16",
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        "--fp16_opt_level",
        choices=["O0", "O1", "O2", "O3"],
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
        "See details at https://nvidia.github.io/apex/amp.html")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=3,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build machine reading comprehension model.
    model = MachineReadingComprehension(args)

    # Load or initialize parameters.
    load_or_initialize_parameters(args, model)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(args.device)
    args.model = model

    args.tokenizer = CharTokenizer(args)

    # Training phase.
    batch_size = args.batch_size
    print("Batch size: ", batch_size)
    trainset, _ = read_dataset(args, args.train_path)
    random.shuffle(trainset)
    instances_num = len(trainset)

    src = torch.LongTensor([sample[0] for sample in trainset])
    seg = torch.LongTensor([sample[1] for sample in trainset])
    start_position = torch.LongTensor([sample[2] for sample in trainset])
    end_position = torch.LongTensor([sample[3] for sample in trainset])

    args.train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("The number of training instances:", instances_num)

    optimizer, scheduler = build_optimizer(args, model)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
        # Make amp reachable through ``args``; presumably the train helper
        # uses args.amp for loss scaling when args.fp16 is set — confirm.
        args.amp = amp

    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
    # amp.initialize / DataParallel may hand back a different object; keep
    # args.model pointing at the model that is actually being trained.
    args.model = model

    total_loss = 0.
    # Start below any attainable score so the first epoch always writes a
    # checkpoint; otherwise the test phase below could load a missing or
    # stale file from an earlier run.
    best_result = float("-inf")

    print("Start training.")

    for epoch in range(1, args.epochs_num + 1):
        model.train()

        for i, (src_batch, seg_batch, start_position_batch,
                end_position_batch) in enumerate(
                    batch_loader(batch_size, src, seg, start_position,
                                 end_position)):
            loss = train(args, model, optimizer, scheduler, src_batch,
                         seg_batch, start_position_batch, end_position_batch)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.

        # Early stopping: keep the best model, stop at the first epoch
        # whose dev score does not improve.
        result = evaluate(args, *read_dataset(args, args.dev_path))
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path)
        else:
            break

    # Evaluation phase.
    if args.test_path is not None:
        print("Test set evaluation.")
        if torch.cuda.device_count() > 1:
            model.module.load_state_dict(torch.load(args.output_model_path))
        else:
            model.load_state_dict(torch.load(args.output_model_path))
        evaluate(args, *read_dataset(args, args.test_path))
Пример #2
0
def main():
    """Fine-tune a multiple-choice model.

    Parses command-line options, builds the tokenizer selected by
    ``--tokenizer`` and a ``MultipleChoice`` model, trains it on
    ``--train_path`` / ``--train_answer_path`` and, after every epoch,
    evaluates on ``--dev_path`` / ``--dev_answer_path``, saving the model to
    ``--output_model_path`` whenever the first element of the evaluation
    result improves.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/multichoice_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path",
                        default=None,
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--spm_model_path",
                        default=None,
                        type=str,
                        help="Path of the sentence piece model.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--train_answer_path",
                        type=str,
                        required=True,
                        help="Path of the answers for trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--dev_answer_path",
                        type=str,
                        required=True,
                        help="Path of the answers for devset.")
    parser.add_argument("--config_path",
                        default="./models/bert_base_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=4,
                        help="Batch size.")
    parser.add_argument("--seq_length",
                        type=int,
                        default=128,
                        help="Sequence length.")
    parser.add_argument("--embedding",
                        choices=["bert", "word"],
                        default="bert",
                        help="Embedding type.")
    parser.add_argument("--encoder",
                        choices=["bert", "lstm", "gru",
                                 "cnn", "gatedcnn", "attn", "synt",
                                 "rcnn", "crnn", "gpt", "bilstm"],
                        default="bert",
                        help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument("--factorized_embedding_parameterization",
                        action="store_true",
                        help="Factorized embedding parameterization.")
    parser.add_argument("--parameter_sharing",
                        action="store_true",
                        help="Parameter sharing.")
    parser.add_argument(
        "--max_choices_num",
        default=10,
        type=int,
        help=
        "The maximum number of candidate answers, shorter than this will be padded."
    )

    # Tokenizer options.
    # Adjacent string literals are concatenated, so each sentence carries
    # its own trailing space.
    parser.add_argument(
        "--tokenizer",
        choices=["bert", "char", "space"],
        default="bert",
        help="Specify the tokenizer. "
        "Original Google BERT uses bert tokenizer on Chinese corpus. "
        "Char tokenizer segments sentences into characters. "
        "Space tokenizer segments sentences into words according to space.")

    # Optimizer options.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")
    parser.add_argument(
        "--fp16",
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit."
    )
    parser.add_argument(
        "--fp16_opt_level",
        choices=["O0", "O1", "O2", "O3"],
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
        "See details at https://nvidia.github.io/apex/amp.html")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.2, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=8,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    args = parser.parse_args()
    # Every choice is treated as one label slot.
    args.labels_num = args.max_choices_num

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Build tokenizer: "bert" -> BertTokenizer, "char" -> CharTokenizer, ...
    args.tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

    # Build multiple choice model.
    model = MultipleChoice(args)

    # Load or initialize parameters.
    load_or_initialize_parameters(args, model)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(args.device)

    # Training phase.
    trainset = read_dataset(args, args.train_path, args.train_answer_path)
    random.shuffle(trainset)
    instances_num = len(trainset)
    batch_size = args.batch_size

    src = torch.LongTensor([example[0] for example in trainset])
    tgt = torch.LongTensor([example[1] for example in trainset])
    seg = torch.LongTensor([example[2] for example in trainset])

    args.train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    optimizer, scheduler = build_optimizer(args, model)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
        args.amp = amp

    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
    args.model = model

    total_loss = 0.
    # Start below any attainable score so at least the first epoch's model
    # is always written to --output_model_path.
    best_result = float("-inf")

    print("Start training.")

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        for i, (src_batch, tgt_batch, seg_batch,
                _) in enumerate(batch_loader(batch_size, src, tgt, seg)):

            loss = train_model(args, model, optimizer, scheduler, src_batch,
                               tgt_batch, seg_batch)
            total_loss += loss.item()

            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.

        # evaluate() returns a sequence; its first element is the score
        # used for model selection.
        result = evaluate(
            args, read_dataset(args, args.dev_path, args.dev_answer_path))
        if result[0] > best_result:
            best_result = result[0]
            save_model(model, args.output_model_path)
Пример #3
0
def main():
    """Fine-tune a BERT-based sequence tagger and evaluate it at entity level.

    Collects tagging labels from ``--train_path``, builds the encoder with
    ``build_model`` and wraps it in ``BertTagger``, trains with ``BertAdam``,
    reports entity-level precision/recall/F1 on ``--dev_path`` after every
    epoch (saving the model when F1 improves and stopping at the first epoch
    that does not improve), then reloads the best checkpoint and evaluates it
    on ``--test_path``.

    Data files have one header line followed by whitespace-separated
    ``token label`` lines; any line that does not split into exactly two
    fields ends the current sentence.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path", default=None, type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path", default="./models/tagger_model.bin", type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path", default="./models/google_vocab.txt", type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path", type=str, required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path", type=str, required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path", default="./models/google_config.json", type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch_size.")
    parser.add_argument("--seq_length", default=128, type=int,
                        help="Sequence length.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru",
                                              "cnn", "gatedcnn", "attn",
                                              "rcnn", "crnn", "gpt", "bilstm"],
                        default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional", action="store_true", help="Specific to recurrent model.")

    # Subword options.
    parser.add_argument("--subword_type", choices=["none", "char"], default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path", type=str, default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder", choices=["avg", "lstm", "gru", "cnn"], default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num", type=int, default=2, help="The number of subencoder layers.")

    # Optimizer options.
    parser.add_argument("--learning_rate", type=float, default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup", type=float, default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout.")
    parser.add_argument("--epochs_num", type=int, default=3,
                        help="Number of epochs.")
    parser.add_argument("--report_steps", type=int, default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7,
                        help="Random seed.")

    args = parser.parse_args()

    # Load the hyperparameters of the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Find tagging labels. ID 0 is reserved for padding and ID 1 for the
    # non-entity tag "O"; every other label seen in the training file gets
    # the next free ID, so any ID > 1 denotes an entity label.
    labels_map = {"NULL": 0, "O": 1}
    with open(args.train_path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id == 0:
                continue  # Header line.
            columns = line.strip().split()
            if len(columns) != 2:
                continue
            if columns[1] not in labels_map:
                labels_map[columns[1]] = len(labels_map)

    print("Labels: ", labels_map)
    args.labels_num = len(labels_map)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build bert model.
    # A pseudo target is added.
    args.target = "bert"
    model = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model.load_state_dict(torch.load(args.pretrained_model_path), strict=False)
    else:
        # Initialize with normal distribution; parameters whose names contain
        # 'gamma' or 'beta' keep their constructor-time values.
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)

    # Build sequence labeling model.
    model = BertTagger(args, model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    # Dataset loader: consecutive slices of batch_size rows; the last slice
    # holds the remainder.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids):
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i*batch_size: (i+1)*batch_size, :]
            label_ids_batch = label_ids[i*batch_size: (i+1)*batch_size, :]
            mask_ids_batch = mask_ids[i*batch_size: (i+1)*batch_size, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num//batch_size*batch_size:, :]
            label_ids_batch = label_ids[instances_num//batch_size*batch_size:, :]
            mask_ids_batch = mask_ids[instances_num//batch_size*batch_size:, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch

    def encode_sentence(tokens, labels):
        """Return [token_ids, label_ids, mask], each truncated or zero-padded to args.seq_length."""
        token_ids = [vocab.get(t) for t in tokens][:args.seq_length]
        label_ids = [labels_map[l] for l in labels][:args.seq_length]
        mask = [1] * len(token_ids)
        padding = [0] * (args.seq_length - len(token_ids))
        return [token_ids + padding, label_ids + padding, mask + padding]

    # Read dataset.
    def read_dataset(path):
        """Read a tagging file into a list of encoded sentences.

        The final sentence is kept even when the file does not end with a
        separator line, and empty sentences (e.g. from consecutive blank
        lines) are skipped rather than turned into all-padding instances.
        """
        dataset = []
        with open(path, mode="r", encoding="utf-8") as f:
            tokens, labels = [], []
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue  # Header line.
                columns = line.strip().split()
                if len(columns) != 2:
                    # Sentence boundary.
                    if tokens:
                        dataset.append(encode_sentence(tokens, labels))
                    tokens, labels = [], []
                    continue
                tokens.append(columns[0])
                labels.append(columns[1])
            # Flush the last sentence when there is no trailing separator.
            if tokens:
                dataset.append(encode_sentence(tokens, labels))

        return dataset

    def entity_spans(tag_ids):
        """Return inclusive (start, end) spans of maximal runs of entity ids (> 1)."""
        spans = []
        j, length = 0, len(tag_ids)
        while j < length:
            if tag_ids[j] > 1:
                start = j
                while j + 1 < length and tag_ids[j + 1] > 1:
                    j += 1
                spans.append((start, j))
            j += 1
        return spans

    # Evaluation function.
    def evaluate(args, is_test):
        """Print entity-level precision, recall and F1 on the dev or test set; return F1.

        A score whose denominator is zero (no predicted / gold entities, or
        p + r == 0) is reported as 0.0 instead of raising ZeroDivisionError.
        """
        if is_test:
            dataset = read_dataset(args.test_path)
        else:
            dataset = read_dataset(args.dev_path)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])

        instances_num = input_ids.size(0)
        batch_size = args.batch_size

        if is_test:
            print("Batch size: ", batch_size)
            print("The number of test instances:", instances_num)

        correct = 0
        gold_entities_num = 0
        pred_entities_num = 0

        model.eval()

        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            for input_ids_batch, label_ids_batch, mask_ids_batch in batch_loader(batch_size, input_ids, label_ids, mask_ids):
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                _, _, pred, gold = model(input_ids_batch, label_ids_batch, mask_ids_batch)

                # pred / gold are indexed as 1-D sequences of label ids; copy
                # them to Python once per batch instead of one .item() per position.
                gold_ids = gold.tolist()
                pred_ids = pred.tolist()

                gold_spans = entity_spans(gold_ids)
                pred_spans = entity_spans(pred_ids)

                # Gold.
                gold_entities_num += len(gold_spans)

                # Predict: a predicted entity that starts on a padding
                # position (gold id 0) is not counted, except at position 0.
                pred_entities_num += sum(
                    1 for start, _ in pred_spans if start == 0 or gold_ids[start] != 0)

                # Correct: identical span boundaries and identical ids inside.
                gold_span_set = set(gold_spans)
                for start, end in pred_spans:
                    if (start, end) in gold_span_set and \
                            pred_ids[start:end + 1] == gold_ids[start:end + 1]:
                        correct += 1

        print("Report precision, recall, and f1:")
        p = correct / pred_entities_num if pred_entities_num else 0.0
        r = correct / gold_entities_num if gold_entities_num else 0.0
        f1 = 2 * p * r / (p + r) if p + r > 0 else 0.0
        print("{:.3f}, {:.3f}, {:.3f}".format(p, r, f1))

        return f1

    # Training phase.
    print("Start training.")
    instances = read_dataset(args.train_path)

    input_ids = torch.LongTensor([ins[0] for ins in instances])
    label_ids = torch.LongTensor([ins[1] for ins in instances])
    mask_ids = torch.LongTensor([ins[2] for ins in instances])

    instances_num = input_ids.size(0)
    batch_size = args.batch_size
    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    # No weight decay for biases and parameters named gamma/beta.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup, t_total=train_steps)

    total_loss = 0.
    # Start below any attainable F1 so the first epoch always writes a
    # checkpoint; otherwise the evaluation phase could load a missing or
    # stale file.
    best_f1 = float("-inf")

    for epoch in range(1, args.epochs_num+1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids)):
            model.zero_grad()

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)

            loss, _, _, _ = model(input_ids_batch, label_ids_batch, mask_ids_batch)
            if torch.cuda.device_count() > 1:
                # DataParallel returns one loss per replica.
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(epoch, i+1, total_loss / args.report_steps))
                total_loss = 0.

            loss.backward()
            optimizer.step()

        # Early stopping on dev F1.
        f1 = evaluate(args, False)
        if f1 > best_f1:
            best_f1 = f1
            save_model(model, args.output_model_path)
        else:
            break

    # Evaluation phase.
    print("Start evaluation.")

    if torch.cuda.device_count() > 1:
        model.module.load_state_dict(torch.load(args.output_model_path))
    else:
        model.load_state_dict(torch.load(args.output_model_path))

    evaluate(args, True)
Пример #4
0
def main():
    """Fine-tune a knowledge-enhanced BERT sentence classifier and evaluate it.

    Parses command-line options, counts the labels in the training file,
    builds the BERT encoder plus a ``BertClassifier`` head, injects
    knowledge-graph information into every sentence via
    ``add_knowledge_worker``, and trains with ``BertAdam``. After every epoch
    it evaluates on the dev set; when the dev score improves, it saves the
    checkpoint and also reports test-set results. At the end it reloads the
    best checkpoint, runs a final test-set evaluation, and prints per-epoch
    losses and dev scores.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/classifier_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path",
                        default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path",
                        type=str,
                        required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/google_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch size.")
    parser.add_argument("--seq_length",
                        type=int,
                        default=256,
                        help="Sequence length.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                                   "cnn", "gatedcnn", "attn", \
                                                   "rcnn", "crnn", "gpt", "bilstm"], \
                                                   default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument("--pooling",
                        choices=["mean", "max", "first", "last"],
                        default="first",
                        help="Pooling type.")

    # Subword options.
    parser.add_argument("--subword_type",
                        choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path",
                        type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder",
                        choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num",
                        type=int,
                        default=2,
                        help="The number of subencoder layers.")

    # Tokenizer options.
    # Adjacent string literals are concatenated, so each sentence ends with a
    # space to keep the rendered help text readable.
    parser.add_argument(
        "--tokenizer",
        choices=["bert", "char", "word", "space"],
        default="bert",
        help="Specify the tokenizer. "
        "Original Google BERT uses bert tokenizer on Chinese corpus. "
        "Char tokenizer segments sentences into characters. "
        "Word tokenizer supports online word segmentation based on jieba segmentor. "
        "Space tokenizer segments sentences into words according to space.")

    # Optimizer options.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=5,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    # Evaluation options.
    parser.add_argument("--mean_reciprocal_rank",
                        action="store_true",
                        help="Evaluation metrics for DBQA dataset.")

    # kg
    parser.add_argument("--kg_name", required=True, help="KG name or path")
    parser.add_argument("--workers_num",
                        type=int,
                        default=1,
                        help="number of process for loading dataset")
    parser.add_argument("--no_vm",
                        action="store_true",
                        help="Disable the visible_matrix")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Count the number of labels. The first line of the training file is a
    # tab-separated header that maps column names to column indices; the
    # same `columns` mapping is reused later by read_dataset.
    labels_set = set()
    columns = {}
    with open(args.train_path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            line = line.strip().split("\t")
            if line_id == 0:
                for i, column_name in enumerate(line):
                    columns[column_name] = i
                continue
            try:
                label = int(line[columns["label"]])
            except (KeyError, IndexError, ValueError):
                # Best effort: skip rows with no "label" column, too few
                # fields, or a non-integer label (previously a bare
                # `except: pass`, which also hid unrelated errors).
                continue
            labels_set.add(label)
    args.labels_num = len(labels_set)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build bert model.
    # A pseudo target is added.
    args.target = "bert"
    model = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model.load_state_dict(torch.load(args.pretrained_model_path),
                              strict=False)
    else:
        # Initialize with normal distribution (layer-norm gamma/beta are
        # left at their defaults).
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)

    # Build classification model.
    model = BertClassifier(args, model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms):
        """Yield consecutive aligned slices of the five inputs.

        Full batches of ``batch_size`` come first, then one final partial
        batch if the instance count is not a multiple of ``batch_size``.
        ``vms`` is a plain list (sliced, not tensor-indexed).
        """
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i * batch_size:(i + 1) * batch_size, :]
            label_ids_batch = label_ids[i * batch_size:(i + 1) * batch_size]
            mask_ids_batch = mask_ids[i * batch_size:(i + 1) * batch_size, :]
            pos_ids_batch = pos_ids[i * batch_size:(i + 1) * batch_size, :]
            vms_batch = vms[i * batch_size:(i + 1) * batch_size]
            yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num // batch_size *
                                        batch_size:, :]
            label_ids_batch = label_ids[instances_num // batch_size *
                                        batch_size:]
            mask_ids_batch = mask_ids[instances_num // batch_size *
                                      batch_size:, :]
            pos_ids_batch = pos_ids[instances_num // batch_size *
                                    batch_size:, :]
            vms_batch = vms[instances_num // batch_size * batch_size:]

            yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch

    # Build knowledge graph.
    if args.kg_name == 'none':
        spo_files = []
    else:
        spo_files = [args.kg_name]
    kg = KnowledgeGraph(spo_files=spo_files, predicate=True)

    def read_dataset(path, workers_num=1):
        """Read ``path`` (skipping its header line) and inject knowledge.

        Each remaining line is passed to ``add_knowledge_worker``. When
        ``workers_num > 1``, the lines are split into blocks processed by a
        multiprocessing pool, and the results are concatenated in order.

        Returns the list of samples produced by ``add_knowledge_worker``.
        NOTE(review): callers index samples as [0]=input ids, [1]=label,
        [2]=mask, [3]=positions, [4]=visible matrix, [-1]=question id (MRR
        mode) — confirm against add_knowledge_worker.
        """
        print("Loading sentences from {}".format(path))
        sentences = []
        with open(path, mode='r', encoding="utf-8") as f:
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                sentences.append(line)
        sentence_num = len(sentences)

        print(
            "There are {} sentence in total. We use {} processes to inject knowledge into sentences."
            .format(sentence_num, workers_num))
        if workers_num > 1:
            params = []
            sentence_per_block = int(sentence_num / workers_num) + 1
            for i in range(workers_num):
                params.append((i, sentences[i * sentence_per_block:(i + 1) *
                                            sentence_per_block], columns, kg,
                               vocab, args))
            pool = Pool(workers_num)
            try:
                res = pool.map(add_knowledge_worker, params)
            finally:
                # Always release worker processes, even if map() raises.
                pool.close()
                pool.join()
            dataset = [sample for block in res for sample in block]
        else:
            params = (0, sentences, columns, kg, vocab, args)
            dataset = add_knowledge_worker(params)

        return dataset

    # Evaluation function.
    def evaluate(args, is_test, metrics='Acc'):
        """Evaluate the current model on the dev set or, if ``is_test``, the test set.

        Without ``--mean_reciprocal_rank``, prints per-label precision, recall,
        and F1, plus accuracy. It returns label-1 F1 when ``metrics == 'f1'``
        and accuracy otherwise. With ``--mean_reciprocal_rank``, it groups
        candidates by question id, prints MRR, and returns it.
        """
        if is_test:
            dataset = read_dataset(args.test_path,
                                   workers_num=args.workers_num)
        else:
            dataset = read_dataset(args.dev_path, workers_num=args.workers_num)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])
        pos_ids = torch.LongTensor([example[3] for example in dataset])
        vms = [example[4] for example in dataset]

        batch_size = args.batch_size
        instances_num = input_ids.size()[0]
        if is_test:
            print("The number of evaluation instances: ", instances_num)

        correct = 0
        # Confusion matrix, indexed as confusion[predicted, gold].
        confusion = torch.zeros(args.labels_num,
                                args.labels_num,
                                dtype=torch.long)

        model.eval()

        if not args.mean_reciprocal_rank:
            for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                    pos_ids_batch, vms_batch) in enumerate(
                        batch_loader(batch_size, input_ids, label_ids,
                                     mask_ids, pos_ids, vms)):

                vms_batch = torch.LongTensor(vms_batch)

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vms_batch = vms_batch.to(device)

                with torch.no_grad():
                    try:
                        loss, logits = model(input_ids_batch, label_ids_batch,
                                             mask_ids_batch, pos_ids_batch,
                                             vms_batch)
                    except Exception:
                        # Dump the offending batch for debugging, then
                        # re-raise. Swallowing the error left `logits` unbound
                        # (first batch) or stale from the previous batch,
                        # which silently corrupted the metrics.
                        print(input_ids_batch)
                        print(input_ids_batch.size())
                        print(vms_batch)
                        print(vms_batch.size())
                        raise

                logits = nn.Softmax(dim=1)(logits)
                pred = torch.argmax(logits, dim=1)
                gold = label_ids_batch
                for j in range(pred.size()[0]):
                    confusion[pred[j], gold[j]] += 1
                correct += torch.sum(pred == gold).item()

            if is_test:
                print("Confusion matrix:")
                print(confusion)
                print("Report precision, recall, and f1:")

            # Defined up front so metrics='f1' works even with < 2 labels.
            label_1_f1 = 0.0
            for i in range(confusion.size()[0]):
                true_positive = confusion[i, i].item()
                predicted_total = confusion[i, :].sum().item()
                gold_total = confusion[:, i].sum().item()
                # A label that is never predicted or never present would
                # otherwise raise ZeroDivisionError; report 0 instead.
                p = true_positive / predicted_total if predicted_total else 0.0
                r = true_positive / gold_total if gold_total else 0.0
                f1 = 2 * p * r / (p + r) if (p + r) else 0.0
                if i == 1:
                    label_1_f1 = f1
                print("Label {}: {:.3f}, {:.3f}, {:.3f}".format(i, p, r, f1))
            print("Acc. (Correct/Total): {:.4f} ({}/{}) ".format(
                correct / len(dataset), correct, len(dataset)))
            if metrics == 'Acc':
                return correct / len(dataset)
            elif metrics == 'f1':
                return label_1_f1
            else:
                return correct / len(dataset)
        else:
            for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                    pos_ids_batch, vms_batch) in enumerate(
                        batch_loader(batch_size, input_ids, label_ids,
                                     mask_ids, pos_ids, vms)):

                vms_batch = torch.LongTensor(vms_batch)

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vms_batch = vms_batch.to(device)

                with torch.no_grad():
                    loss, logits = model(input_ids_batch, label_ids_batch,
                                         mask_ids_batch, pos_ids_batch,
                                         vms_batch)
                logits = nn.Softmax(dim=1)(logits)
                if i == 0:
                    logits_all = logits
                if i >= 1:
                    logits_all = torch.cat((logits_all, logits), 0)

            # For every question id (last sample field), record the
            # within-question positions of candidates labelled 1.
            order = -1
            gold = []
            for i in range(len(dataset)):
                qid = dataset[i][-1]
                label = dataset[i][1]
                if qid == order:
                    j += 1
                    if label == 1:
                        gold.append((qid, j))
                else:
                    order = qid
                    j = 0
                    if label == 1:
                        gold.append((qid, j))

            # Group the gold positions per question.
            label_order = []
            order = -1
            for i in range(len(gold)):
                if gold[i][0] == order:
                    templist.append(gold[i][1])
                elif gold[i][0] != order:
                    order = gold[i][0]
                    if i > 0:
                        label_order.append(templist)
                    templist = []
                    templist.append(gold[i][1])
            label_order.append(templist)

            # Group the class-1 probabilities per question.
            order = -1
            score_list = []
            for i in range(len(logits_all)):
                score = float(logits_all[i][1])
                qid = int(dataset[i][-1])
                if qid == order:
                    templist.append(score)
                else:
                    order = qid
                    if i > 0:
                        score_list.append(templist)
                    templist = []
                    templist.append(score)
            score_list.append(templist)

            # Reciprocal rank of the best-ranked gold candidate per question.
            rank = []
            pred = []
            print(len(score_list))
            print(len(label_order))
            for i in range(len(score_list)):
                if len(label_order[i]) == 1:
                    if label_order[i][0] < len(score_list[i]):
                        true_score = score_list[i][label_order[i][0]]
                        score_list[i].sort(reverse=True)
                        for j in range(len(score_list[i])):
                            if score_list[i][j] == true_score:
                                rank.append(1 / (j + 1))
                    else:
                        rank.append(0)

                else:
                    true_rank = len(score_list[i])
                    for k in range(len(label_order[i])):
                        if label_order[i][k] < len(score_list[i]):
                            true_score = score_list[i][label_order[i][k]]
                            temp = sorted(score_list[i], reverse=True)
                            for j in range(len(temp)):
                                if temp[j] == true_score:
                                    if j < true_rank:
                                        true_rank = j
                    if true_rank < len(score_list[i]):
                        rank.append(1 / (true_rank + 1))
                    else:
                        rank.append(0)
            MRR = sum(rank) / len(rank)
            print("MRR", MRR)
            return MRR

    # Training phase.
    print("Start training.")
    trainset = read_dataset(args.train_path, workers_num=args.workers_num)
    print("Shuffling dataset")
    random.shuffle(trainset)
    instances_num = len(trainset)
    batch_size = args.batch_size

    print("Trans data to tensor.")
    print("input_ids")
    input_ids = torch.LongTensor([example[0] for example in trainset])
    print("label_ids")
    label_ids = torch.LongTensor([example[1] for example in trainset])
    print("mask_ids")
    mask_ids = torch.LongTensor([example[2] for example in trainset])
    print("pos_ids")
    pos_ids = torch.LongTensor([example[3] for example in trainset])
    print("vms")
    vms = [example[4] for example in trainset]

    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    # No weight decay for biases and layer-norm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup,
                         t_total=train_steps)

    result = 0.0
    best_result = 0.0

    all_acc = []
    all_loss = []

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        # Reset per epoch so the first report of an epoch does not include
        # the un-reported tail of the previous epoch.
        total_loss = 0.
        # Sum of every step's loss in this epoch (previously only the
        # report-step windows were added, dropping the epoch's tail).
        epoch_loss = 0.
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                pos_ids_batch, vms_batch) in enumerate(
                    batch_loader(batch_size, input_ids, label_ids, mask_ids,
                                 pos_ids, vms)):
            model.zero_grad()

            vms_batch = torch.LongTensor(vms_batch)

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            vms_batch = vms_batch.to(device)

            loss, _ = model(input_ids_batch,
                            label_ids_batch,
                            mask_ids_batch,
                            pos=pos_ids_batch,
                            vm=vms_batch)
            if torch.cuda.device_count() > 1:
                # DataParallel returns one loss per GPU.
                loss = torch.mean(loss)
            loss_value = loss.item()
            total_loss += loss_value
            epoch_loss += loss_value
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, i + 1, total_loss / args.report_steps))
                sys.stdout.flush()
                total_loss = 0.
            loss.backward()
            optimizer.step()

        all_loss.append(epoch_loss)
        print("Start evaluation on dev dataset.")
        result = evaluate(args, False)
        all_acc.append(result)
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path)
        else:
            continue

        # Only reached when the dev score improved this epoch.
        print("Start evaluation on test dataset.")
        evaluate(args, True)

    # Evaluation phase.
    print("Final evaluation on the test dataset.")

    if torch.cuda.device_count() > 1:
        model.module.load_state_dict(torch.load(args.output_model_path))
    else:
        model.load_state_dict(torch.load(args.output_model_path))
    evaluate(args, True)

    print('all_loss:', all_loss)
    print('all_acc:', all_acc)
Пример #5
0
def main():
    """Fine-tune a NerTagger sequence-labeling model and evaluate it.

    Labels are read from the JSON file given by ``--label2id_path``. A
    ``[PAD]`` label is appended, and the ids of all labels whose name starts
    with "B" are collected into ``args.begin_ids``. The model is trained on
    ``--train_path``, and a checkpoint is saved every time dev-set F1
    improves. When ``--test_path`` is given, that checkpoint is reloaded and
    evaluated on the test set.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    finetune_opts(parser)
    parser.add_argument("--label2id_path", type=str, required=True,
                        help="Path of the label2id file.")

    # Merge the config-file hyperparameters into the parsed options.
    args = load_hyperparam(parser.parse_args())
    set_seed(args.seed)

    with open(args.label2id_path, mode="r", encoding="utf-8") as label_file:
        label_to_id = json.load(label_file)
        print("Labels: ", label_to_id)
        label_to_id["[PAD]"] = len(label_to_id)
        args.begin_ids = [
            label_id for name, label_id in label_to_id.items()
            if name.startswith("B")
        ]

    args.l2i = label_to_id
    args.labels_num = len(label_to_id)
    args.tokenizer = SpaceTokenizer(args)

    tagger = NerTagger(args)
    load_or_initialize_parameters(args, tagger)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tagger = tagger.to(args.device)

    # Training data: fields 0/1/2 of each instance become src/tgt/seg.
    instances = read_dataset(args, args.train_path)
    src, tgt, seg = (torch.LongTensor([ins[field] for ins in instances])
                     for field in range(3))

    batch_size = args.batch_size
    instances_num = src.size(0)
    args.train_steps = int(instances_num * args.epochs_num / batch_size) + 1
    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    optimizer, scheduler = build_optimizer(args, tagger)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        tagger, optimizer = amp.initialize(tagger, optimizer,
                                           opt_level=args.fp16_opt_level)

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("{} GPUs are available. Let's use them.".format(gpu_count))
        tagger = torch.nn.DataParallel(tagger)
    args.model = tagger

    running_loss = 0.0
    best_f1 = 0.0

    print("Start training.")
    for epoch in range(1, args.epochs_num + 1):
        tagger.train()
        batches = batch_loader(batch_size, src, tgt, seg)
        for step, (src_batch, tgt_batch, seg_batch) in enumerate(batches, start=1):
            step_loss = train(args, tagger, optimizer, scheduler, src_batch,
                              tgt_batch, seg_batch)
            running_loss += step_loss.item()
            if step % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, step, running_loss / args.report_steps))
                running_loss = 0.0

        dev_f1 = evaluate(args, read_dataset(args, args.dev_path))
        if dev_f1 > best_f1:
            best_f1 = dev_f1
            save_model(tagger, args.output_model_path)

    # Evaluation phase.
    if args.test_path is None:
        return
    print("Test set evaluation.")
    # DataParallel keeps the real model under `.module`.
    target = tagger.module if gpu_count > 1 else tagger
    target.load_state_dict(torch.load(args.output_model_path))
    evaluate(args, read_dataset(args, args.test_path))
Пример #6
0
def worker(proc_id, gpu_ranks, args, model):
    """Set up and run one training process.

    Builds the data loader for ``args.target``, moves the model to its GPU,
    creates an AdamW optimizer with a linear warmup schedule, optionally
    joins the distributed process group, and then calls ``train_<target>``.

    Args:
        proc_id: The id of GPU for single GPU mode;
                 The id of process (and GPU) for multiprocessing distributed mode.
        gpu_ranks: List of ranks of each process.
        args: Parsed options. The schedule length is taken from
            ``args.total_steps``; if that attribute is absent, a module-level
            ``train_steps`` is used as a fallback.
        model: The model to train.

    Raises:
        ValueError: if no schedule length can be found, or if there is no
            ``train_<target>`` function for ``args.target``.
    """
    set_seed(args.seed)

    # The original code referenced an undefined local `train_steps`
    # (NameError). Resolve the schedule length explicitly and fail fast.
    train_steps = getattr(args, "total_steps", None)
    if train_steps is None:
        train_steps = globals().get("train_steps")
    if train_steps is None:
        raise ValueError(
            "args.total_steps must be set to build the learning-rate schedule.")

    if args.dist_train:
        rank = gpu_ranks[proc_id]
        gpu_id = proc_id
    elif args.single_gpu:
        rank = None
        gpu_id = proc_id
    else:
        rank = None
        gpu_id = None

    # Data loader class is looked up by target name, e.g. "bert" -> BertDataLoader.
    if args.dist_train:
        train_loader = globals()[args.target.capitalize() + "DataLoader"](
            args, args.dataset_path, args.batch_size, rank, args.world_size,
            True)
    else:
        train_loader = globals()[args.target.capitalize() + "DataLoader"](
            args, args.dataset_path, args.batch_size, 0, 1, True)

    if gpu_id is not None:
        torch.cuda.set_device(gpu_id)
        model.cuda(gpu_id)

    # Build optimizer (no weight decay for biases and layer-norm parameters).
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=train_steps * args.warmup,
                                     t_total=train_steps)

    if args.dist_train:
        # Initialize multiprocessing distributed training environment.
        dist.init_process_group(backend=args.backend,
                                init_method=args.master_ip,
                                world_size=args.world_size,
                                rank=rank)
        model = DistributedDataParallel(model, device_ids=[gpu_id])
        print("Worker %d is training ... " % rank)
    else:
        print("Worker is training ...")

    train_fn = globals().get("train_" + args.target)
    if train_fn is None:
        # Previously this called None and failed with an opaque TypeError.
        raise ValueError(
            "No training function 'train_{}' for target {!r}.".format(
                args.target, args.target))
    train_fn(args, gpu_id, rank, train_loader, model, optimizer, scheduler)
Пример #7
0
def main():
    """Fine-tune a sentence Classifier and evaluate it.

    Adds pooling, tokenizer, soft-target, and adversarial-training options to
    the shared fine-tuning options. The model is trained on a shuffled
    training set, and a checkpoint is saved whenever the dev-set result
    improves. When ``--test_path`` is given, the best checkpoint is reloaded
    and evaluated on the test set.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    finetune_opts(parser)
    parser.add_argument("--pooling", choices=["mean", "max", "first", "last"],
                        default="first", help="Pooling type.")
    tokenizer_opts(parser)
    parser.add_argument("--soft_targets", action='store_true',
                        help="Train model with logits.")
    parser.add_argument("--soft_alpha", type=float, default=0.5,
                        help="Weight of the soft targets loss.")
    adv_opts(parser)

    # Merge the config-file hyperparameters into the parsed options.
    args = load_hyperparam(parser.parse_args())
    set_seed(args.seed)

    args.labels_num = count_labels_num(args.train_path)
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    classifier = Classifier(args)
    load_or_initialize_parameters(args, classifier)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classifier = classifier.to(args.device)

    # Training data: shuffle, then fields 0/1/2 become src/tgt/seg.
    trainset = read_dataset(args, args.train_path)
    random.shuffle(trainset)
    batch_size = args.batch_size
    instances_num = len(trainset)
    src, tgt, seg = (torch.LongTensor([example[field] for example in trainset])
                     for field in range(3))

    args.train_steps = int(instances_num * args.epochs_num / batch_size) + 1
    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    optimizer, scheduler = build_optimizer(args, classifier)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        classifier, optimizer = amp.initialize(classifier, optimizer,
                                               opt_level=args.fp16_opt_level)
        args.amp = amp

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("{} GPUs are available. Let's use them.".format(gpu_count))
        classifier = torch.nn.DataParallel(classifier)
    args.model = classifier

    if args.use_adv:
        args.adv_method = str2adv[args.adv_type](classifier)

    running_loss = 0.0
    best_result = 0.0

    print("Start training.")
    for epoch in range(1, args.epochs_num + 1):
        classifier.train()
        batches = batch_loader(batch_size, src, tgt, seg)
        for step, (src_batch, tgt_batch, seg_batch, _) in enumerate(batches, start=1):
            step_loss = train_model(args, classifier, optimizer, scheduler,
                                    src_batch, tgt_batch, seg_batch)
            running_loss += step_loss.item()
            if step % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, step, running_loss / args.report_steps))
                running_loss = 0.0

        dev_result = evaluate(args, read_dataset(args, args.dev_path))
        if dev_result > best_result:
            best_result = dev_result
            save_model(classifier, args.output_model_path)

    # Evaluation phase.
    if args.test_path is None:
        return
    print("Test set evaluation.")
    # DataParallel keeps the real model under `.module`.
    target = args.model.module if gpu_count > 1 else args.model
    target.load_state_dict(torch.load(args.output_model_path))
    evaluate(args, read_dataset(args, args.test_path))
Пример #8
0
def main():
    """Fine-tune a LUKE-based sequence tagger inside a wandb hyper-parameter sweep.

    Parses the command line, builds the label vocabulary from the train/dev/
    test files, loads the pretrained LUKE encoder and the knowledge graph,
    constructs the tagger selected by ``--classifier`` and finally launches a
    wandb sweep whose trials run the nested ``train`` function.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/tagger_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--output_encoder",
                        default="./luke-models/",
                        type=str,
                        help="Path of the output luke model.")
    parser.add_argument("--suffix_file_encoder",
                        default="encoder",
                        type=str,
                        help="output file suffix luke model.")
    parser.add_argument("--vocab_path",
                        default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path",
                        type=str,
                        required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/google_config.json",
                        type=str,
                        help="Path of the config file.")
    parser.add_argument("--output_file_prefix",
                        type=str,
                        required=True,
                        help="Prefix for file output.")
    parser.add_argument("--log_file", default='app.log')

    # Model options.
    parser.add_argument("--seq_length",
                        default=256,
                        type=int,
                        help="Sequence length.")
    parser.add_argument("--classifier",
                        choices=["mlp", "lstm", "lstm_crf", "lstm_ncrf"],
                        default="mlp",
                        help="Classifier type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument('--freeze_encoder_weights',
                        action='store_true',
                        help="Enable to freeze the encoder weigths.")

    # Subword options.
    parser.add_argument("--subword_type",
                        choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path",
                        type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder",
                        choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num",
                        type=int,
                        default=2,
                        help="The number of subencoder layers.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=0,
                        help="Number of epochs.")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=2,
                        help="Number of steps to accumulate the gradient.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=200,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=35, help="Random seed.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch_size.")
    parser.add_argument("--num_train_steps",
                        type=int,
                        default=20000,
                        help="Max steps to be trained.")
    parser.add_argument("--patience",
                        type=int,
                        default=8000,
                        help="Specific steps to wait until stops training.")

    # Optimizer options.
    parser.add_argument("--learning_rate", default=1e-5, type=float)
    parser.add_argument("--lr_schedule",
                        default="warmup_linear",
                        type=str,
                        choices=["warmup_linear", "warmup_constant"])
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument("--max_grad_norm", default=0.0, type=float)
    parser.add_argument("--adam_b1", default=0.9, type=float)
    parser.add_argument("--adam_b2", default=0.999, type=float)
    parser.add_argument("--adam_eps", default=1e-8, type=float)
    parser.add_argument("--adam_correct_bias", action='store_true')
    parser.add_argument("--warmup_proportion", default=0.006, type=float)
    parser.add_argument("--freeze_proportions", default=0.0, type=float)

    # kg
    parser.add_argument("--kg_name", required=True, help="KG name or path")
    parser.add_argument("--use_kg",
                        action='store_true',
                        help="Enable the use of KG.")
    parser.add_argument("--padding",
                        action='store_true',
                        help="Enable padding.")
    parser.add_argument("--dry_run",
                        action='store_true',
                        help="Dry run to test the implementation.")
    parser.add_argument(
        "--voting_choicer",
        action='store_true',
        help="Enable the Voting choicer to select the entity type.")
    parser.add_argument("--eval_kg_tag",
                        action='store_true',
                        help="Enable to include [ENT] tag in evaluation.")
    parser.add_argument("--use_subword_tag",
                        action='store_true',
                        help="Enable to use separate tag for subword splits.")
    parser.add_argument("--debug", action='store_true', help="Enable debug.")
    parser.add_argument("--reverse_order",
                        action='store_true',
                        help="Reverse the feature selection order.")
    parser.add_argument("--max_entities",
                        default=2,
                        type=int,
                        help="Number of KG features.")
    parser.add_argument("--eval_range_with_types",
                        action='store_true',
                        help="Enable to eval range with types.")

    args = parser.parse_args()

    # Load the hyperparameters of the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    logging.basicConfig(filename=args.log_file, filemode='w', format=fmt)

    # Special labels occupy the first ids; dataset labels are appended in the
    # order they are first seen.
    labels_map = {"[PAD]": 0, "[ENT]": 1, "[X]": 2, "[CLS]": 3, "[SEP]": 4}
    # Ids of span-opening labels (those starting with "B" or "S").
    begin_ids = []

    # Collect tagging labels from the first tab-separated column of every
    # split; line 0 of each file is a header and is skipped.
    for path in (args.train_path, args.dev_path, args.test_path):
        with open(path, mode="r", encoding="utf-8") as f:
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                for label in line.strip().split("\t")[0].split():
                    if label in labels_map:
                        continue
                    is_begin = label.startswith("B") or label.startswith("S")
                    # read_dataset() labels continuation word pieces of a
                    # tagged word as "I<joiner><type>", so make sure that
                    # label exists even if the data never contains it. It is
                    # inserted before the begin label itself, which keeps the
                    # id assignment unchanged. Labels shorter than two chars
                    # have no joiner/type and get no inner tag.
                    if is_begin and len(label) > 1:
                        inner_tag = f'I{label[1]}{label[2:]}'
                        if inner_tag not in labels_map:
                            labels_map[inner_tag] = len(labels_map)

                    labels_map[label] = len(labels_map)
                    if is_begin:
                        # Record after insertion so this is the begin label's
                        # own id, not the id its inner tag may have just taken.
                        begin_ids.append(labels_map[label])

    idx_to_label = {idx: label for label, idx in labels_map.items()}

    print(begin_ids)
    print("Labels: ", labels_map)
    args.labels_num = len(labels_map)

    # Build knowledge graph ('none' disables it by passing an empty source).
    kg_file = [] if args.kg_name == 'none' else args.kg_name

    # Load Luke model.
    model_archive = ModelArchive.load(args.pretrained_model_path)
    tokenizer = model_archive.tokenizer

    # Load the pretrained model
    encoder = LukeModel(model_archive.config)
    encoder.load_state_dict(model_archive.state_dict, strict=False)

    kg = KnowledgeGraph(kg_file=kg_file, tokenizer=tokenizer)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    # Build sequence labeling model.
    classifiers = {
        "mlp": LukeTaggerMLP,
        "lstm": LukeTaggerLSTM,
        "lstm_crf": LukeTaggerLSTMCRF,
        "lstm_ncrf": LukeTaggerLSTMNCRF
    }
    logger.info(f'The selected classifier is:{classifiers[args.classifier]}')
    # NOTE(review): this one model instance is shared by every wandb sweep
    # trial; train() does not re-initialise it, so later trials start from the
    # weights the previous trial left behind.
    model = classifiers[args.classifier](args, encoder)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model = model.to(device)

    # Read dataset.
    def read_dataset(path):
        """Load one labelled split and turn it into model-ready instances.

        The first line of the file is a header and is skipped. Every other
        line must hold two or three tab-separated fields: space-separated
        labels, space-separated tokens, and an optional third field that is
        ignored. Lines with any other field count are reported and skipped.

        Each line goes through ``kg.add_knowledge_with_vm``. Its returned
        ``tag`` sequence is used to give every output token exactly one label:

        * tag 0 on a non-pad token: the next label of the line.
          Non-"O" labels shorter than two characters are logged and mapped
          to "O".
        * tag 1 on a non-pad token (a token of an injected KG entity):
          "[ENT]".
        * tag 2: "O" if the preceding labelled token was "O"; otherwise
          "[X]" with ``--use_subword_tag``, else "I<joiner><type>".
        * anything else: ``PAD_TOKEN``.

        Returns:
            A list of ``[token_ids, label_ids, mask, pos, vm, tag,
            word_segment_ids]`` entries. With ``--dry_run`` at most 100
            entries are produced.

        Raises:
            ValueError: if the aligned label sequence does not match the
                token sequence in length.
        """
        dataset = []
        count = 0
        with open(path, mode="r", encoding="utf8") as f:
            f.readline()  # skip the header line
            for line_id, line in enumerate(f):
                fields = line.strip().split("\t")
                if len(fields) not in (2, 3):
                    print(
                        f'The data is not in accepted format at line no:{line_id}.. Ignored'
                    )
                    continue
                labels, tokens = fields[0], fields[1]

                tokens, pos, vm, tag = \
                    kg.add_knowledge_with_vm([tokens], [labels],
                                             use_kg=args.use_kg,
                                             max_length=args.seq_length,
                                             max_entities=args.max_entities,
                                             reverse_order=args.reverse_order,
                                             padding=args.padding)
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")
                tag = tag[0]

                num_tokens = sum(1 for tok in tokens
                                 if tok != tokenizer.pad_token)
                num_pad = len(tokens) - num_tokens

                labels = [config.CLS_TOKEN
                          ] + labels.split(" ") + [config.SEP_TOKEN]
                new_labels = []
                j = 0  # index of the next unconsumed original label
                joiner = '-'
                prev_label = 'O'
                for i, token in enumerate(tokens):
                    is_pad = token == tokenizer.pad_token
                    if tag[i] == 0 and not is_pad:
                        cur_type = labels[j]
                        j += 1
                        if cur_type == 'O':
                            prev_label = cur_type
                        elif len(cur_type) < 2:
                            # No "<prefix><joiner><type>" structure to split.
                            logger.info(
                                f'The label:{cur_type} is converted to O')
                            prev_label = 'O'
                            cur_type = 'O'
                        else:
                            joiner = cur_type[1]
                            prev_label = cur_type[2:]
                        new_labels.append(cur_type)
                    elif tag[i] == 1 and not is_pad:
                        # Token belongs to an entity injected from the KG.
                        new_labels.append('[ENT]')
                    elif tag[i] == 2:
                        if prev_label == 'O':
                            new_labels.append('O')
                        elif args.use_subword_tag:
                            new_labels.append('[X]')
                        else:
                            new_labels.append(f'I{joiner}' + prev_label)
                    else:
                        new_labels.append(PAD_TOKEN)

                new_labels = [labels_map[label] for label in new_labels]

                # NOTE(review): assumes pad tokens (if --padding produced
                # any) all trail the real tokens — confirm against
                # KnowledgeGraph.add_knowledge_with_vm.
                mask = [1] * num_tokens + [0] * num_pad
                word_segment_ids = [0] * len(tokens)

                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                if len(token_ids) != len(new_labels):
                    raise ValueError(
                        "The length of token and label is not matching")

                dataset.append([
                    token_ids, new_labels, mask, pos, vm, tag, word_segment_ids
                ])

                # Enable dry run: stop after the first 100 instances.
                if args.dry_run:
                    count += 1
                    if count == 100:
                        break

        return dataset

    # Evaluation function.
    def evaluate(args, is_test, final=False):
        """Evaluate the current ``model`` on the dev or the test split.

        Args:
            args: Parsed options. Reads ``test_path``/``dev_path``,
                ``batch_size``, ``use_kg`` and ``output_file_prefix``.
            is_test: Evaluate ``args.test_path`` when true, otherwise
                ``args.dev_path``.
            final: When true, append the predicted and gold tag sequences
                (one sentence per line) to
                ``<output_file_prefix>_predictions.txt`` and
                ``<output_file_prefix>_gold.txt``.

        Returns:
            dict with seqeval ``f1``/``precision``/``recall`` and the span
            metrics ``f1_span``/``precision_span``/``recall_span``.

        Raises:
            RuntimeError: if ``get_bio`` yields predicted and gold sequences
                of different length for a sentence.
        """
        dataset = read_dataset(args.test_path if is_test else args.dev_path)
        batch_size = args.batch_size

        if is_test:
            logger.info(f"Batch size:{batch_size}")
            print(f"The number of test instances:{len(dataset)}")

        true_labels_all = []
        predicted_labels_all = []
        model.eval()

        test_batcher = Batcher(batch_size,
                               dataset,
                               shuffle=False,
                               token_pad=tokenizer.pad_token_id,
                               label_pad=labels_map[PAD_TOKEN])

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for (input_ids_batch, label_ids_batch, mask_ids_batch,
                 pos_ids_batch, vm_ids_batch,
                 segment_ids_batch) in test_batcher:

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vm_ids_batch = vm_ids_batch.long().to(device)
                segment_ids_batch = segment_ids_batch.long().to(device)

                pred = model(input_ids_batch,
                             segment_ids_batch,
                             mask_ids_batch,
                             label_ids_batch,
                             pos_ids_batch,
                             vm_ids_batch,
                             use_kg=args.use_kg)

                for pred_sample, gold_sample, mask in zip(
                        pred, label_ids_batch, mask_ids_batch):
                    # Number of unmasked positions in this sentence.
                    num_labels = int(mask.sum())

                    pred_labels = [
                        idx_to_label.get(key) for key in pred_sample.tolist()
                    ]
                    gold_labels = [
                        idx_to_label.get(key) for key in gold_sample.tolist()
                    ]

                    # Exclude the [CLS], and [SEP] tokens
                    pred_labels = pred_labels[1:num_labels - 1]
                    true_labels = gold_labels[1:num_labels - 1]

                    pred_labels = [p.replace('_NOKG', '') for p in pred_labels]
                    true_labels = [t.replace('_NOKG', '') for t in true_labels]

                    true_labels, pred_labels = filter_kg_labels(
                        true_labels, pred_labels)

                    pred_labels = [p.replace('_', '-') for p in pred_labels]
                    true_labels = [t.replace('_', '-') for t in true_labels]

                    biluo_tags_predicted = get_bio(pred_labels)
                    biluo_tags_true = get_bio(true_labels)

                    if len(biluo_tags_predicted) != len(biluo_tags_true):
                        message = ('The length of the predicted labels is not '
                                   'same as that of true labels..')
                        logger.error(message)
                        raise RuntimeError(message)

                    predicted_labels_all.append(biluo_tags_predicted)
                    true_labels_all.append(biluo_tags_true)

        if final:
            with open(f'{args.output_file_prefix}_predictions.txt', 'a',
                      encoding='utf-8') as p, \
                    open(f'{args.output_file_prefix}_gold.txt', 'a',
                         encoding='utf-8') as g:
                # Newline-terminate every sentence so output appended by a
                # later call starts on its own line.
                p.write(''.join(' '.join(seq) + '\n'
                                for seq in predicted_labels_all))
                g.write(''.join(' '.join(seq) + '\n'
                                for seq in true_labels_all))

        return dict(
            f1=seqeval.metrics.f1_score(true_labels_all, predicted_labels_all),
            precision=seqeval.metrics.precision_score(true_labels_all,
                                                      predicted_labels_all),
            recall=seqeval.metrics.recall_score(true_labels_all,
                                                predicted_labels_all),
            f1_span=f1_score_span(true_labels_all, predicted_labels_all),
            precision_span=precision_score_span(true_labels_all,
                                                predicted_labels_all),
            recall_span=recall_score_span(true_labels_all,
                                          predicted_labels_all),
        )

    def maybe_no_sync(step):
        """Pick the context manager to wrap ``loss.backward()`` in.

        Returns ``model.no_sync()`` when the model provides it and ``step``
        does not complete a gradient-accumulation cycle. In every other case
        returns an empty ``contextlib.ExitStack``, which does nothing on
        enter or exit.

        NOTE(review): neither a plain tagger nor ``nn.DataParallel`` appears
        to define ``no_sync`` (DistributedDataParallel does). If so, this
        always takes the no-op path here — confirm.
        """
        if not hasattr(model, "no_sync"):
            return contextlib.ExitStack()
        if (step + 1) % args.gradient_accumulation_steps == 0:
            return contextlib.ExitStack()
        return model.no_sync()

    def train():
        """Run one wandb sweep trial: train, evaluate periodically, then test.

        The sweep's choices for learning_rate, batch_size, lr_schedule,
        weight_decay and max_grad_norm are copied from ``wandb.config`` into
        ``args``.

        Training runs until ``args.num_train_steps`` optimizer steps have been
        taken, or until dev F1 has not improved for ``args.patience`` steps.
        Every ``args.report_steps`` optimizer steps the function:

        * evaluates the dev and test sets;
        * saves the model/encoder when dev F1 improves;
        * logs metrics to wandb.

        At the end, the saved checkpoint is reloaded and evaluated on the test
        set, with predictions written to disk.

        NOTE(review): ``model`` is built once in ``main`` and is not
        re-initialised here, so each trial continues from the weights the
        previous trial left behind.
        """
        wandb.init()
        sweep_cfg = wandb.config

        # Override the command-line values with this trial's sweep choices.
        args.learning_rate = sweep_cfg.learning_rate
        args.batch_size = sweep_cfg.batch_size
        args.lr_schedule = sweep_cfg.lr_schedule
        args.weight_decay = sweep_cfg.weight_decay
        args.max_grad_norm = sweep_cfg.max_grad_norm

        # Training phase.
        logger.info("Start training.")
        instances = read_dataset(args.train_path)

        instances_num = len(instances)
        batch_size = args.batch_size

        if args.epochs_num:
            # num_train_steps counts optimizer steps, and every optimizer
            # step consumes `gradient_accumulation_steps` batches.
            accumulation_steps = max(args.gradient_accumulation_steps, 1)
            args.num_train_steps = int(
                instances_num * args.epochs_num /
                (batch_size * accumulation_steps)) + 1

        # nn.DataParallel does not forward custom attributes, so the tagger's
        # own methods (score/freeze/unfreeze) must be called on the inner module.
        core_model = (model.module
                      if isinstance(model, nn.DataParallel) else model)

        unfreeze_steps = 0
        model_frozen = False
        if args.freeze_proportions != 0.0:
            unfreeze_steps = int(
                args.num_train_steps * args.freeze_proportions) + 1
            logger.info(
                f'Two phase training is enabled with model unfreeze at:{unfreeze_steps}'
            )
            # freeze the model
            core_model.freeze()
            model_frozen = True

        logger.info(f"Batch size:{batch_size}")
        logger.info(f"The number of training instances:{instances_num}")

        train_batcher = Batcher(batch_size,
                                instances,
                                shuffle=True,
                                token_pad=tokenizer.pad_token_id,
                                label_pad=labels_map[PAD_TOKEN])

        optimizer = create_optimizer(args, model)
        scheduler = create_scheduler(args, optimizer)
        total_loss = 0.
        # Start below any reachable F1 so the first evaluation of every trial
        # writes a checkpoint. Otherwise a trial whose dev F1 never exceeds 0
        # would reload whatever an earlier trial left at output_model_path.
        best_f1 = float("-inf")

        global_steps = 0
        early_stop_steps = 0
        epoch = 0

        with tqdm(total=args.num_train_steps) as pbar:
            while True:
                model.train()
                for step, (input_ids_batch, label_ids_batch, mask_ids_batch,
                           pos_ids_batch, vm_ids_batch,
                           segment_ids_batch) in enumerate(train_batcher):

                    input_ids_batch = input_ids_batch.to(device)
                    label_ids_batch = label_ids_batch.to(device)
                    mask_ids_batch = mask_ids_batch.to(device)
                    pos_ids_batch = pos_ids_batch.to(device)
                    vm_ids_batch = vm_ids_batch.long().to(device)
                    segment_ids_batch = segment_ids_batch.long().to(device)

                    loss = core_model.score(input_ids_batch,
                                            segment_ids_batch,
                                            mask_ids_batch,
                                            label_ids_batch,
                                            pos_ids_batch,
                                            vm_ids_batch,
                                            use_kg=args.use_kg)

                    if torch.cuda.device_count() > 1:
                        loss = torch.mean(loss)

                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    with maybe_no_sync(step):
                        loss.backward()

                    total_loss += loss.item()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        if args.max_grad_norm != 0.0:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(), args.max_grad_norm)
                        optimizer.step()
                        scheduler.step()
                        model.zero_grad()
                        optimizer.zero_grad()

                        pbar.set_description("epoch: %d loss: %.7f" %
                                             (epoch, loss.item()))
                        pbar.update()
                        global_steps += 1

                        if global_steps % args.report_steps == 0:
                            avg_loss = total_loss / args.report_steps
                            logger.info(
                                "Epoch id: {}, Global Steps:{}, Avg loss: "
                                "{:.10f}".format(epoch, global_steps,
                                                 avg_loss))

                            # Evaluation phase.
                            logger.info("Start evaluate on dev dataset.")
                            results = evaluate(args, False)
                            logger.info(results)

                            logger.info("Start evaluation on test dataset.")
                            results_test = evaluate(args, True)
                            logger.info(results_test)

                            if results['f1'] > best_f1:
                                best_f1 = results['f1']
                                early_stop_steps = 0
                                save_model(model, args.output_model_path)
                                save_encoder(args,
                                             encoder,
                                             suffix=args.suffix_file_encoder)
                            else:
                                early_stop_steps += args.report_steps

                            # Log the loss and dev F1 for this report window.
                            wandb.log({
                                "steps": global_steps,
                                "train Loss": avg_loss,
                                "valid_acc": results['f1'],
                                "learning_rate": sweep_cfg.learning_rate,
                                "batch_size": sweep_cfg.batch_size,
                                "lr_schedule": sweep_cfg.lr_schedule,
                                "weight_decay": sweep_cfg.weight_decay,
                                "max_grad_norm": sweep_cfg.max_grad_norm,
                            })

                            total_loss = 0.

                            # evaluate() switched to eval mode; switch back.
                            model.train()

                    if model_frozen and global_steps >= unfreeze_steps:
                        # unfreeze the model and start training
                        core_model.unfreeze()
                        model_frozen = False

                    if global_steps >= args.num_train_steps:
                        # Training completed
                        break

                    if early_stop_steps >= args.patience:
                        # Early stopping
                        break

                if model_frozen and global_steps >= unfreeze_steps:
                    # unfreeze the model and start training
                    core_model.unfreeze()
                    model_frozen = False

                if global_steps >= args.num_train_steps:
                    # Training completed
                    break

                if early_stop_steps >= args.patience:
                    # Early stopping
                    break

                epoch += 1

        # A trial that never reached an evaluation saved nothing; save its
        # final weights so the reload below uses this trial's model.
        if best_f1 == float("-inf"):
            save_model(model, args.output_model_path)
            save_encoder(args, encoder, suffix=args.suffix_file_encoder)

        # Evaluation phase: reload the saved checkpoint and score the test set.
        core_model.load_state_dict(torch.load(args.output_model_path))
        results_final = evaluate(args, True, final=True)
        logger.info(results_final)

    # WandB sweep: every trial calls train(), which reads its hyper-parameters
    # from wandb.config. The metric name must match the "valid_acc" key that
    # train() logs.
    search_space = {
        'learning_rate': [1e-5, 2e-5, 4e-5, 6e-5],
        'batch_size': [16, 32, 64, 128],
        'lr_schedule': ["warmup_linear", "warmup_constant"],
        'weight_decay': [0.01, 0.02, 0.03],
        'max_grad_norm': [1.0, 0.0],
    }
    sweep_config = {
        'method': 'grid',  # grid, random
        'metric': {'name': 'valid_acc', 'goal': 'maximize'},
        'parameters': {
            name: {'values': choices}
            for name, choices in search_space.items()
        },
    }

    sweep_id = wandb.sweep(sweep_config, project="kbert_nlu")
    # Hand control to the wandb agent, which runs train() once per trial.
    wandb.agent(sweep_id, function=train)
Пример #9
0
def main():
    """Fine-tune a BERT-style encoder for extractive QA on SQuAD-style JSON data.

    Parses command-line options, builds the model, trains on ``--train_path``,
    keeps the checkpoint with the best dev score (average of F1 and EM) in
    ``--output_model_path`` and, when ``--test_path`` is given, reloads that
    checkpoint and evaluates it on the test file.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path", default=None, type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path", default="./models/QA_model.bin", type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path", default="./models/google_vocab.txt", type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path", type=str, required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path", type=str, required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str,
                        help="Path of the testset.")
    parser.add_argument("--config_path", default="./models/google_config.json", type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size.")
    parser.add_argument("--seq_length", type=int, default=100,
                        help="Sequence length.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru",
                                              "cnn", "gatedcnn", "attn",
                                              "rcnn", "crnn", "gpt"],
                        default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional", action="store_true", help="Specific to recurrent model.")

    # Subword options.
    parser.add_argument("--subword_type", choices=["none", "char"], default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path", type=str, default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder", choices=["avg", "lstm", "gru", "cnn"], default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num", type=int, default=2, help="The number of subencoder layers.")

    # Tokenizer options.
    # Adjacent string literals are concatenated, so each sentence needs a trailing space.
    parser.add_argument("--tokenizer", choices=["bert", "char", "word", "space"], default="char",
                        help="Specify the tokenizer. "
                             "Original Google BERT uses bert tokenizer on Chinese corpus. "
                             "Char tokenizer segments sentences into characters. "
                             "Word tokenizer supports online word segmentation based on jieba segmentor. "
                             "Space tokenizer segments sentences into words according to space.")

    # Optimizer options.
    parser.add_argument("--learning_rate", type=float, default=3e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup", type=float, default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="Dropout.")
    parser.add_argument("--epochs_num", type=int, default=3,
                        help="Number of epochs.")
    parser.add_argument("--report_steps", type=int, default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7,
                        help="Random seed.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    args.target = "bert"
    bert_model = build_model(args)
    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model; strict=False tolerates missing/unexpected keys.
        bert_model.load_state_dict(torch.load(args.pretrained_model_path), strict=False)
    else:
        # Initialize with normal distribution; parameters whose names contain
        # 'gamma' or 'beta' keep their default initialization.
        for n, p in list(bert_model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)

    # Build QA model.
    model = BertQuestionAnswering(args, bert_model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    def batch_loader(batch_size, input_ids, mask_ids, start_positions, end_positions):
        """Yield consecutive ``batch_size``-row slices of the four tensors.

        Batches are produced in order; the final batch holds the remaining rows
        and may be smaller than ``batch_size``.
        """
        instances_num = input_ids.size()[0]
        for begin in range(0, instances_num, batch_size):
            end = begin + batch_size
            yield (input_ids[begin:end, :], mask_ids[begin:end, :],
                   start_positions[begin:end], end_positions[begin:end])

    # Build tokenizer: e.g. "char" resolves to the global name "CharTokenizer".
    tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

    def read_examples(path):
        """Read a SQuAD-style JSON file (data -> paragraphs -> qas -> answers).

        Returns a list of tuples
        ``(context, question, question_id, start_positions, end_positions, answers)``
        with one entry per question. The position lists hold character offsets
        into ``context``, one per annotated answer; end offsets are exclusive
        (``answer_start + len(text)``).
        """
        examples = []
        with open(path, 'r', encoding='utf-8') as fp:
            all_dict = json.load(fp)
        for data_dict in all_dict["data"]:
            for para_dict in data_dict["paragraphs"]:
                context = para_dict["context"]
                for qas_dict in para_dict["qas"]:
                    question = qas_dict["question"]
                    question_id = qas_dict["id"]
                    answers = []
                    start_positions = []
                    end_positions = []
                    for ans_dict in qas_dict["answers"]:
                        answer = ans_dict["text"]
                        start_position = ans_dict["answer_start"]
                        answers.append(answer)
                        start_positions.append(start_position)
                        end_positions.append(start_position + len(answer))
                    examples.append((context, question, question_id,
                                     start_positions, end_positions, answers))
        return examples

    def convert_examples_to_dataset(examples, args):
        """Split each context into overlapping windows and encode them.

        Each window is encoded as ``[CLS] question [SEP] context-window [SEP]``
        and zero-padded up to ``args.seq_length``. Answer positions are
        converted from character offsets in the context to sequence indices
        using character counts (``len(question)`` and window offsets).
        NOTE(review): this assumes the tokenizer yields one token per character,
        as the default char tokenizer is expected to do; confirm for others.

        Only the first annotated answer of each question is used. Windows that
        do not fully contain that answer are dropped. This also applies to
        dev/test data, so evaluation depends on gold answer offsets.

        Returns a list of tuples
        ``(tokens, mask, start_position, end_position, answers, question_id,
        q_len, doc_span_index, doc_span_start)``.

        Raises:
            ValueError: if a question leaves no room for context within
                ``args.seq_length``, or if ``args.doc_stride`` is not positive
                while a context needs more than one window.
        """
        # A context window: character offset of its start and its length.
        DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
        dataset = []
        print("The number of questions in the dataset", len(examples))
        for context, question, question_id, start_list, end_list, answers in examples:
            q_len = len(question)
            # TODO: only the first annotated answer is used as the target.
            start_position_true = start_list[0]
            end_position_true = end_list[0]

            # Room left for context after [CLS], the middle [SEP] and the final [SEP].
            max_context_length = args.seq_length - q_len - 3
            if max_context_length <= 0:
                # Without this guard the window loop below never terminates.
                raise ValueError(
                    "Question {} is too long ({} characters) for seq_length {}.".format(
                        question_id, q_len, args.seq_length))

            # Divide the context into windows advancing by at most doc_stride characters.
            doc_spans = []
            start_offset = 0
            while start_offset < len(context):
                length = min(len(context) - start_offset, max_context_length)
                doc_spans.append(DocSpan(start=start_offset, length=length))
                if start_offset + length == len(context):
                    break
                step = min(length, args.doc_stride)
                if step <= 0:
                    raise ValueError("doc_stride must be positive, got {}.".format(args.doc_stride))
                start_offset += step

            # Context tokens start right after [CLS] question [SEP].
            offset = q_len + 2
            # The question part is identical for every window of this example.
            tokens_a = [CLS_ID] + [vocab.get(t) for t in tokenizer.tokenize(question)] + [SEP_ID]

            for doc_span_index, doc_span in enumerate(doc_spans):
                doc_span_start = doc_span.start
                span_context = context[doc_span_start:doc_span_start + doc_span.length]
                # Convert context character offsets into positions in the token sequence.
                start_position = start_position_true - doc_span_start + offset
                end_position = end_position_true - doc_span_start + offset
                # The answers of some questions are not inside this window; ignore it.
                span_end = doc_span.length + offset
                if not (offset <= start_position <= span_end and offset <= end_position <= span_end):
                    continue

                tokens_b = [vocab.get(t) for t in tokenizer.tokenize(span_context)] + [SEP_ID]
                tokens = tokens_a + tokens_b
                # Segment ids: 1 for [CLS] question [SEP], 2 for the context part, 0 for padding.
                mask = [1] * len(tokens_a) + [2] * len(tokens_b)

                padding = args.seq_length - len(tokens)
                tokens += [0] * padding
                mask += [0] * padding

                dataset.append((tokens, mask, start_position, end_position, answers,
                                question_id, q_len, doc_span_index, doc_span_start))
        return dataset

    def evaluate(args, is_test):
        """Evaluate the current model on the dev set (or the test set if ``is_test``).

        Prints and returns the average of the F1 and EM scores (both in percent).
        """
        # Characters treated as punctuation by the F1/EM metrics.
        sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
                   ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
                   '「','」','(',')','-','~','『','』']

        def mixed_segmentation(in_str, rm_punc=False):
            """Lower-case and segment a string for F1 scoring.

            Each CJK character (U+4E00-U+9FA5) and each punctuation character
            becomes its own segment. Runs of other characters are split with
            ``nltk.word_tokenize``. With ``rm_punc`` punctuation is dropped.
            """
            in_str = str(in_str).lower().strip()
            segs_out = []
            temp_str = ""
            for char in in_str:
                if rm_punc and char in sp_char:
                    continue
                if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
                    if temp_str != "":
                        segs_out.extend(nltk.word_tokenize(temp_str))
                        temp_str = ""
                    segs_out.append(char)
                else:
                    temp_str += char

            # Handle the trailing non-CJK run.
            if temp_str != "":
                segs_out.extend(nltk.word_tokenize(temp_str))

            return segs_out

        def remove_punctuation(in_str):
            """Lower-case, strip and remove punctuation characters from a string."""
            in_str = str(in_str).lower().strip()
            return ''.join(char for char in in_str if char not in sp_char)

        def find_lcs(s1, s2):
            """Return the longest common contiguous run of s1 and s2 and its length."""
            m = [[0] * (len(s2) + 1) for _ in range(len(s1) + 1)]
            mmax = 0
            p = 0
            for i, item1 in enumerate(s1):
                for j, item2 in enumerate(s2):
                    if item1 == item2:
                        m[i + 1][j + 1] = m[i][j] + 1
                        if m[i + 1][j + 1] > mmax:
                            mmax = m[i + 1][j + 1]
                            p = i + 1
            return s1[p - mmax:p], mmax

        def calc_f1_score(answers, prediction):
            """Return the best segment-level F1 of the prediction over all gold answers."""
            prediction_segs = mixed_segmentation(prediction, rm_punc=True)
            f1_scores = []
            for ans in answers:
                ans_segs = mixed_segmentation(ans, rm_punc=True)
                _, lcs_len = find_lcs(ans_segs, prediction_segs)
                if lcs_len == 0:
                    f1_scores.append(0)
                else:
                    precision = 1.0 * lcs_len / len(prediction_segs)
                    recall = 1.0 * lcs_len / len(ans_segs)
                    f1_scores.append((2 * precision * recall) / (precision + recall))
            return max(f1_scores)

        def calc_em_score(answers, prediction):
            """Return 1 if the prediction equals any gold answer ignoring punctuation, else 0."""
            prediction_ = remove_punctuation(prediction)
            return int(any(remove_punctuation(ans) == prediction_ for ans in answers))

        def is_max_score(score_list):
            """Return (span_index, start, end) of the highest-scoring entry.

            ``score_list`` holds ``(span_index, start, end, score)`` tuples; on
            ties the later entry wins.
            """
            score_max = -100
            index_max = 0
            best_start_prediction = 0
            best_end_prediction = 0
            for span_index, start_prediction, end_prediction, score in score_list:
                if score_max <= score:
                    score_max = score
                    index_max = span_index
                    best_start_prediction = start_prediction
                    best_end_prediction = end_prediction
            return index_max, best_start_prediction, best_end_prediction

        examples = read_examples(args.test_path if is_test else args.dev_path)
        dataset = convert_examples_to_dataset(examples, args)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        mask_ids = torch.LongTensor([sample[1] for sample in dataset])
        start_positions = torch.LongTensor([sample[2] for sample in dataset])
        end_positions = torch.LongTensor([sample[3] for sample in dataset])

        batch_size = args.batch_size
        instances_num = input_ids.size()[0]

        if is_test:
            print("The number of evaluation instances: ", instances_num)
        model.eval()
        start_logits_all = []
        end_logits_all = []
        start_pred_all = []
        end_pred_all = []
        for input_ids_batch, mask_ids_batch, start_positions_batch, end_positions_batch in \
                batch_loader(batch_size, input_ids, mask_ids, start_positions, end_positions):
            input_ids_batch = input_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            start_positions_batch = start_positions_batch.to(device)
            end_positions_batch = end_positions_batch.to(device)

            with torch.no_grad():
                _, start_logits, end_logits = model(input_ids_batch, mask_ids_batch,
                                                    start_positions_batch, end_positions_batch)

            # Softmax over positions; keep each row's arg-max position and its probability.
            start_probs = nn.Softmax(dim=1)(start_logits)
            end_probs = nn.Softmax(dim=1)(end_logits)
            start_pred = torch.argmax(start_probs, dim=1)
            end_pred = torch.argmax(end_probs, dim=1)

            start_logits_all += start_probs.gather(1, start_pred.unsqueeze(1)).squeeze(1).tolist()
            end_logits_all += end_probs.gather(1, end_pred.unsqueeze(1)).squeeze(1).tolist()
            start_pred_all += start_pred.tolist()
            end_pred_all += end_pred.tolist()

        assert len(start_pred_all) == len(dataset)
        assert len(start_logits_all) == len(dataset)

        # Rows of one question are consecutive in ``dataset``. Per question, keep the
        # window whose mean of start and end probabilities is highest.
        pred_list = []
        group = []
        current_qid = None
        for i, sample in enumerate(dataset):
            qid, q_len, span_index, doc_span_start = sample[5], sample[6], sample[7], sample[8]
            score = (float(start_logits_all[i]) + float(end_logits_all[i])) / 2
            # Map sequence positions back to character offsets in the full context.
            pre_start_pred = start_pred_all[i] + doc_span_start - q_len - 2
            pre_end_pred = end_pred_all[i] + doc_span_start - q_len - 2
            if group and qid != current_qid:
                pred_list.append(is_max_score(group))
                group = []
            current_qid = qid
            group.append((span_index, pre_start_pred, pre_end_pred, score))
        if group:
            pred_list.append(is_max_score(group))

        # Fails if some question produced no window (e.g. its answer fits in no window).
        assert len(pred_list) == len(examples)

        # Score the predictions.
        f1 = 0
        em = 0
        total_count = len(examples)
        skip_count = 0
        for example, (_, start_prediction, end_prediction) in zip(examples, pred_list):
            answers = example[5]
            # An empty or reversed span is an invalid prediction and scores 0.
            if end_prediction <= start_prediction:
                skip_count += 1
                continue

            prediction = example[0][start_prediction:end_prediction]

            f1 += calc_f1_score(answers, prediction)
            em += calc_em_score(answers, prediction)

        f1_score = 100.0 * f1 / total_count
        em_score = 100.0 * em / total_count
        avg = (f1_score + em_score) * 0.5
        print("Avg: {:.4f},F1:{:.4f},EM:{:.4f},Total:{},Skip:{}".format(avg, f1_score, em_score, total_count, skip_count))
        return avg

    # Training phase
    print("Start training.")
    batch_size = args.batch_size
    print("Batch size: ", batch_size)
    examples = read_examples(args.train_path)
    trainset = convert_examples_to_dataset(examples, args)
    random.shuffle(trainset)
    instances_num = len(trainset)

    input_ids = torch.LongTensor([sample[0] for sample in trainset])
    mask_ids = torch.LongTensor([sample[1] for sample in trainset])
    start_positions = torch.LongTensor([sample[2] for sample in trainset])
    end_positions = torch.LongTensor([sample[3] for sample in trainset])

    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("The number of training instances:", instances_num)

    param_optimizer = list(model.named_parameters())
    # Parameters whose names contain these substrings are not weight-decayed.
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup, t_total=train_steps)

    # Start below any achievable score so the first epoch always writes a checkpoint
    # for the evaluation phase to load.
    best_result = float("-inf")

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        # Reset per epoch: step numbering restarts, so a partial report window
        # must not leak into the next epoch's first average.
        total_loss = 0.

        for i, (input_ids_batch, mask_ids_batch, start_positions_batch, end_positions_batch) in \
                enumerate(batch_loader(batch_size, input_ids, mask_ids, start_positions, end_positions)):
            model.zero_grad()
            input_ids_batch = input_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            start_positions_batch = start_positions_batch.to(device)
            end_positions_batch = end_positions_batch.to(device)

            loss, _, _ = model(input_ids_batch, mask_ids_batch, start_positions_batch, end_positions_batch)
            if torch.cuda.device_count() > 1:
                # DataParallel returns one loss per device.
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.
            loss.backward()
            optimizer.step()
        result = evaluate(args, False)
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path)
        else:
            # Early stopping: stop at the first epoch whose dev score does not improve.
            break

    # Evaluation phase.
    if args.test_path:
        print("Start evaluation.")

        if torch.cuda.device_count() > 1:
            model.module.load_state_dict(torch.load(args.output_model_path))
        else:
            model.load_state_dict(torch.load(args.output_model_path))

        evaluate(args, True)