def main():
    """Train and evaluate a LUKE-based sequence tagger with optional KG injection.

    Steps:
      1. Parse command-line options and merge the config-file hyperparameters
         (``load_hyperparam``).
      2. Build the tag vocabulary from the train, dev and test files.
      3. Load the pretrained LUKE encoder and wrap it in the tagger head
         selected by ``--classifier``.
      4. Train with gradient accumulation. Every ``--report_steps`` optimizer
         steps, evaluate on dev and test and checkpoint when dev F1 improves.
         Stop after ``--num_train_steps`` steps, or when
         ``--patience`` steps pass without improvement.
      5. Reload the saved checkpoint and run a final test evaluation that
         also writes ``<output_file_prefix>_predictions.txt`` and
         ``<output_file_prefix>_gold.txt``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path", default=None, type=str,
                        required=True,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/tagger_model.bin", type=str,
                        help="Path of the output model.")
    parser.add_argument("--output_encoder", default="./luke-models/",
                        type=str,
                        help="Path of the output luke model.")
    parser.add_argument("--suffix_file_encoder", default="encoder", type=str,
                        help="output file suffix luke model.")
    parser.add_argument("--vocab_path", default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path", type=str, required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path", type=str, required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/google_config.json", type=str,
                        help="Path of the config file.")
    parser.add_argument("--output_file_prefix", type=str, required=True,
                        help="Prefix for file output.")
    parser.add_argument("--log_file", default='app.log')

    # Model options.
    parser.add_argument("--seq_length", default=256, type=int,
                        help="Sequence length.")
    parser.add_argument("--classifier",
                        choices=["mlp", "lstm", "lstm_crf", "lstm_ncrf"],
                        default="mlp",
                        help="Classifier type.")
    parser.add_argument("--bidirectional", action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument('--freeze_encoder_weights', action='store_true',
                        help="Enable to freeze the encoder weigths.")

    # Subword options.
    parser.add_argument("--subword_type", choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path", type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder", choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num", type=int, default=2,
                        help="The number of subencoder layers.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout.")
    parser.add_argument("--epochs_num", type=int, default=0,
                        help="Number of epochs.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
                        help="Number of steps to accumulate the gradient.")
    parser.add_argument("--report_steps", type=int, default=200,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=35,
                        help="Random seed.")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch_size.")
    parser.add_argument("--num_train_steps", type=int, default=20000,
                        help="Max steps to be trained.")
    parser.add_argument("--patience", type=int, default=8000,
                        help="Specific steps to wait until stops training.")

    # Optimizer options.
    parser.add_argument("--learning_rate", default=1e-5, type=float)
    parser.add_argument("--lr_schedule", default="warmup_linear", type=str,
                        choices=["warmup_linear", "warmup_constant"])
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument("--max_grad_norm", default=0.0, type=float)
    parser.add_argument("--adam_b1", default=0.9, type=float)
    parser.add_argument("--adam_b2", default=0.999, type=float)
    parser.add_argument("--adam_eps", default=1e-8, type=float)
    parser.add_argument("--adam_correct_bias", action='store_true')
    parser.add_argument("--warmup_proportion", default=0.006, type=float)
    parser.add_argument("--freeze_proportions", default=0.0, type=float)
    parser.add_argument("--wandb", action='store_true',
                        help="Enable wandb logging")

    # kg
    parser.add_argument("--kg_name", type=str, help="KG name or path")
    parser.add_argument("--use_kg", action='store_true',
                        help="Enable the use of KG.")
    parser.add_argument("--padding", action='store_true',
                        help="Enable padding.")
    parser.add_argument(
        "--truncate",
        action='store_true',
        help="Enable truncation if length is more than seq length.")
    parser.add_argument("--shuffle", action='store_true',
                        help="Enable shuffling during training.")
    parser.add_argument("--dry_run", action='store_true',
                        help="Dry run to test the implementation.")
    parser.add_argument(
        "--voting_choicer",
        action='store_true',
        help="Enable the Voting choicer to select the entity type.")
    parser.add_argument("--eval_kg_tag", action='store_true',
                        help="Enable to include [ENT] tag in evaluation.")
    parser.add_argument("--use_subword_tag", action='store_true',
                        help="Enable to use separate tag for subword splits.")
    parser.add_argument("--debug", action='store_true',
                        help="Enable debug.")
    parser.add_argument("--reverse_order", action='store_true',
                        help="Reverse the feature selection order.")
    parser.add_argument("--max_entities", default=2, type=int,
                        help="Number of KG features.")
    parser.add_argument("--eval_range_with_types", action='store_true',
                        help="Enable to eval range with types.")

    args = parser.parse_args()

    # Load the hyperparameters of the config file.
    args = load_hyperparam(args)
    set_seed(args.seed)
    logging.basicConfig(filename=args.log_file, filemode='w', format=fmt)

    # Special labels occupy the first ids; data labels are appended after them.
    labels_map = {"[PAD]": 0, "[ENT]": 1, "[X]": 2, "[CLS]": 3, "[SEP]": 4}
    begin_ids = []

    # Find tagging labels in every split so that dev/test-only labels also get
    # an id. The first line of each file is a header and is skipped.
    for data_path in (args.train_path, args.dev_path, args.test_path):
        with open(data_path, mode="r", encoding="utf-8") as f:
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                for label in line.strip().split("\t")[0].split():
                    if label in labels_map:
                        continue
                    if label.startswith("B") or label.startswith("S"):
                        # read_dataset labels continuation pieces with the
                        # inside tag, so make sure that tag exists.
                        inner_tag = f'I{label[1]}{label[2:]}'
                        if inner_tag not in labels_map:
                            labels_map[inner_tag] = len(labels_map)
                        # Take the id only after the optional inner-tag
                        # insertion. This is the id `label` receives below.
                        begin_ids.append(len(labels_map))
                    labels_map[label] = len(labels_map)
    idx_to_label = {idx: label for label, idx in labels_map.items()}
    print(begin_ids)
    print("Labels: ", labels_map)
    args.labels_num = len(labels_map)

    # Build knowledge graph ('none' selects an empty KG file list).
    kg_file = [] if args.kg_name == 'none' else args.kg_name

    # Load Luke model.
    model_archive = ModelArchive.load(args.pretrained_model_path)
    tokenizer = model_archive.tokenizer

    # Load the pretrained model; strict=False tolerates key mismatches.
    encoder = LukeModel(model_archive.config)
    encoder.load_state_dict(model_archive.state_dict, strict=False)
    kg = KnowledgeGraph(kg_file=kg_file, tokenizer=tokenizer)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    # Build sequence labeling model.
    classifiers = {
        "mlp": LukeTaggerMLP,
        "lstm": LukeTaggerLSTM,
        "lstm_crf": LukeTaggerLSTMCRF,
        "lstm_ncrf": LukeTaggerLSTMNCRF,
    }
    logger.info(f'The selected classifier is:{classifiers[args.classifier]}')
    model = classifiers[args.classifier](args, encoder)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model = model.to(device)
    # nn.DataParallel does not expose the wrapped module's custom methods
    # (score/freeze/unfreeze), so call those on the unwrapped module.
    core_model = model.module if isinstance(model, nn.DataParallel) else model

    def read_dataset(path):
        """Read a tab-separated tagging file and convert it to model inputs.

        After a header line, each line is ``labels<TAB>tokens``. An optional
        third column is ignored. Lines with another number of columns are
        reported and skipped. ``kg.add_knowledge_with_vm`` expands the tokens.
        Word-level labels are then re-aligned to the expanded sequence:

        - tag 0 (non-pad) takes the next word label;
        - tag 1 (non-pad, a KG-injected token) becomes ``'[ENT]'``;
        - tag 2 becomes ``'O'``, ``'[X]'`` (with ``--use_subword_tag``) or the
          inside tag of the preceding word label;
        - every other position becomes ``PAD_TOKEN``.

        Returns a list of
        ``[token_ids, label_ids, mask, pos, vm, tag, segment_ids]``
        (at most 100 items with ``--dry_run``).
        """
        dataset = []
        count = 0
        with open(path, mode="r", encoding="utf8") as f:
            f.readline()  # skip the header line
            for line_id, line in enumerate(f):
                fields = line.strip().split("\t")
                if len(fields) in (2, 3):
                    labels, tokens = fields[0], fields[1]
                else:
                    print(
                        f'The data is not in accepted format at line no:{line_id}.. Ignored'
                    )
                    continue

                tokens, pos, vm, tag = kg.add_knowledge_with_vm(
                    args, [tokens], [labels])
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")
                tag = tag[0]

                num_tokens = sum(1 for tok in tokens
                                 if tok != tokenizer.pad_token)
                num_pad = len(tokens) - num_tokens
                labels = ([config.CLS_TOKEN] + labels.split(" ") +
                          [config.SEP_TOKEN])

                new_labels = []
                j = 0  # index of the next unconsumed word-level label
                joiner = '-'  # char between prefix and type, e.g. 'B-PER'
                for i in range(len(tokens)):
                    is_pad = tokens[i] == tokenizer.pad_token
                    if tag[i] == 0 and not is_pad:
                        cur_type = labels[j]
                        j += 1
                        if cur_type == 'O':
                            prev_label = cur_type
                        else:
                            try:
                                joiner = cur_type[1]
                            except IndexError:
                                # A one-character label carries no type.
                                logger.info(
                                    f'The label:{cur_type} is converted to O')
                                prev_label = 'O'
                                new_labels.append('O')
                                continue
                            prev_label = cur_type[2:]
                        new_labels.append(cur_type)
                    elif tag[i] == 1 and not is_pad:
                        # Token injected from the knowledge graph.
                        new_labels.append('[ENT]')
                    elif tag[i] == 2:
                        if prev_label == 'O':
                            new_labels.append('O')
                        elif args.use_subword_tag:
                            new_labels.append('[X]')
                        else:
                            new_labels.append(f'I{joiner}' + prev_label)
                    else:
                        new_labels.append(PAD_TOKEN)
                new_labels = [labels_map[label] for label in new_labels]

                # Non-pad tokens are counted as attended positions, followed
                # by num_pad masked positions.
                mask = [1] * num_tokens + [0] * num_pad
                word_segment_ids = [0] * len(tokens)
                tokens = tokenizer.convert_tokens_to_ids(tokens)
                if len(tokens) != len(new_labels):
                    raise ValueError(
                        "The length of token and label is not matching")
                dataset.append(
                    [tokens, new_labels, mask, pos, vm, tag, word_segment_ids])

                # Enable dry run: keep only the first 100 samples.
                if args.dry_run:
                    count += 1
                    if count == 100:
                        break
        return dataset

    def evaluate(args, is_test, final=False):
        """Evaluate the model on the test split (``is_test``) or the dev split.

        For each sample, positions outside the mask and the [CLS]/[SEP]
        positions are dropped. ``_NOKG`` suffixes are removed and the label
        pairs pass through ``filter_kg_labels``. ``_`` becomes ``-``, and the
        sequences are converted with ``get_bio`` and scored. With
        ``final=True`` the label sequences are also appended, one per line,
        to ``<output_file_prefix>_predictions.txt`` and ``_gold.txt``.

        Returns a dict with seqeval f1/precision/recall and the ``*_span``
        counterparts.

        Raises RuntimeError if a predicted and gold sequence differ in length.
        """
        dataset = read_dataset(args.test_path if is_test else args.dev_path)
        instances_num = len(dataset)
        batch_size = args.batch_size
        if is_test:
            logger.info(f"Batch size:{batch_size}")
            print(f"The number of test instances:{instances_num}")

        true_labels_all = []
        predicted_labels_all = []
        model.eval()
        test_batcher = Batcher(args,
                               dataset,
                               token_pad=tokenizer.pad_token_id,
                               label_pad=labels_map[PAD_TOKEN])
        # Inference only: do not build the autograd graph.
        with torch.no_grad():
            for (input_ids_batch, label_ids_batch, mask_ids_batch,
                 pos_ids_batch, vm_ids_batch,
                 segment_ids_batch) in test_batcher:
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vm_ids_batch = vm_ids_batch.long().to(device)
                segment_ids_batch = segment_ids_batch.long().to(device)
                pred, _, _ = model(input_ids_batch,
                                   segment_ids_batch,
                                   mask_ids_batch,
                                   label_ids_batch,
                                   pos_ids_batch,
                                   vm_ids_batch,
                                   use_kg=args.use_kg)

                for pred_sample, gold_sample, mask in zip(
                        pred, label_ids_batch, mask_ids_batch):
                    num_labels = int(mask.sum())
                    # Exclude the [CLS] and [SEP] positions.
                    pred_labels = [
                        idx_to_label.get(key) for key in pred_sample.tolist()
                    ][1:num_labels - 1]
                    true_labels = [
                        idx_to_label.get(key) for key in gold_sample.tolist()
                    ][1:num_labels - 1]
                    pred_labels = [p.replace('_NOKG', '') for p in pred_labels]
                    true_labels = [t.replace('_NOKG', '') for t in true_labels]
                    true_labels, pred_labels = filter_kg_labels(
                        true_labels, pred_labels)
                    pred_labels = [p.replace('_', '-') for p in pred_labels]
                    true_labels = [t.replace('_', '-') for t in true_labels]
                    biluo_tags_predicted = get_bio(pred_labels)
                    biluo_tags_true = get_bio(true_labels)
                    if len(biluo_tags_predicted) != len(biluo_tags_true):
                        logger.error(
                            'The length of the predicted labels is not same as that of true labels..'
                        )
                        # Fail loudly; a bare exit() ended with status 0.
                        raise RuntimeError(
                            'The length of the predicted labels is not same as that of true labels..'
                        )
                    predicted_labels_all.append(biluo_tags_predicted)
                    true_labels_all.append(biluo_tags_true)

        if final:
            with open(f'{args.output_file_prefix}_predictions.txt', 'a') as p, \
                    open(f'{args.output_file_prefix}_gold.txt', 'a') as g:
                # Terminate every line so repeated appends stay line-aligned.
                p.writelines(' '.join(seq) + '\n'
                             for seq in predicted_labels_all)
                g.writelines(' '.join(seq) + '\n' for seq in true_labels_all)
        return dict(
            f1=seqeval.metrics.f1_score(true_labels_all, predicted_labels_all),
            precision=seqeval.metrics.precision_score(true_labels_all,
                                                      predicted_labels_all),
            recall=seqeval.metrics.recall_score(true_labels_all,
                                                predicted_labels_all),
            f1_span=f1_score_span(true_labels_all, predicted_labels_all),
            precision_span=precision_score_span(true_labels_all,
                                                predicted_labels_all),
            recall_span=recall_score_span(true_labels_all,
                                          predicted_labels_all),
        )

    # Training phase.
    logger.info("Start training.")
    instances = read_dataset(args.train_path)
    instances_num = len(instances)
    batch_size = args.batch_size
    if args.epochs_num:
        # NOTE(review): assuming Batcher yields batch_size samples per batch,
        # this counts batches. global_steps below counts optimizer steps
        # (one per gradient_accumulation_steps batches), so confirm the
        # intended unit.
        args.num_train_steps = int(
            instances_num * args.epochs_num / batch_size) + 1

    unfreeze_steps = 0
    model_frozen = False
    if args.freeze_proportions != 0.0:
        unfreeze_steps = int(
            args.num_train_steps * args.freeze_proportions) + 1
        logger.info(
            f'Two phase training is enabled with model unfreeze at:{unfreeze_steps}'
        )
        # freeze the model
        core_model.freeze()
        model_frozen = True

    logger.info(f"Batch size:{batch_size}")
    logger.info(f"The number of training instances:{instances_num}")
    train_batcher = Batcher(args,
                            instances,
                            token_pad=tokenizer.pad_token_id,
                            label_pad=labels_map[PAD_TOKEN])
    optimizer = create_optimizer(args, model)
    scheduler = create_scheduler(args, optimizer)
    total_loss = 0.
    best_f1 = 0.0

    def maybe_no_sync(step):
        """Use the wrapper's ``no_sync`` (if any) on non-final accumulation steps."""
        if (hasattr(model, "no_sync")
                and (step + 1) % args.gradient_accumulation_steps != 0):
            return model.no_sync()
        return contextlib.nullcontext()

    def should_stop():
        """Unfreeze the encoder when due; return True when training must end."""
        nonlocal model_frozen
        if model_frozen and global_steps >= unfreeze_steps:
            # unfreeze the model and start training
            logger.info('The encoder is unfrozen for training.')
            core_model.unfreeze()
            model_frozen = False
        if global_steps >= args.num_train_steps:
            # Training completed
            logger.info('The training is completed!')
            return True
        if early_stop_steps >= args.patience:
            # Early stopping
            logger.info('The early stopping is triggered!')
            return True
        return False

    # YOU MUST LOG INTO WANDB WITH YOUR OWN ACCOUNT
    if args.wandb:
        import wandb
        wandb.init(project="kbert_pretrain")
        print(f'new args{args}')
    else:
        wandb = None

    global_steps = 0
    early_stop_steps = 0
    epoch = 0
    stop_training = False
    with tqdm(total=args.num_train_steps) as pbar:
        while True:
            model.train()
            for step, (input_ids_batch, label_ids_batch, mask_ids_batch,
                       pos_ids_batch, vm_ids_batch,
                       segment_ids_batch) in enumerate(train_batcher):
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vm_ids_batch = vm_ids_batch.long().to(device)
                segment_ids_batch = segment_ids_batch.long().to(device)

                loss, _ = core_model.score(input_ids_batch,
                                           segment_ids_batch,
                                           mask_ids_batch,
                                           label_ids_batch,
                                           pos_ids_batch,
                                           vm_ids_batch,
                                           use_kg=args.use_kg)
                if torch.cuda.device_count() > 1:
                    loss = torch.mean(loss)
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                with maybe_no_sync(step):
                    loss.backward()
                total_loss += loss.item()

                # Only every gradient_accumulation_steps-th batch updates.
                if (step + 1) % args.gradient_accumulation_steps != 0:
                    continue

                if args.max_grad_norm != 0.0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                optimizer.zero_grad()
                pbar.set_description("epoch: %d loss: %.7f" %
                                     (epoch, loss.item()))
                pbar.update()
                global_steps += 1

                if global_steps % args.report_steps == 0:
                    logger.info("Epoch id: {}, Global Steps:{}, Avg loss: "
                                "{:.10f}".format(
                                    epoch, global_steps,
                                    total_loss / args.report_steps))

                    # Evaluation phase.
                    logger.info("Start evaluate on dev dataset.")
                    results = evaluate(args, False)
                    logger.info(results)

                    logger.info("Start evaluation on test dataset.")
                    results_test = evaluate(args, True)
                    logger.info(results_test)

                    avg_loss = total_loss / args.report_steps
                    if args.wandb:
                        # Log loss and dev/test F1 at every report step.
                        wandb.log({
                            "steps": global_steps,
                            "train Loss": avg_loss,
                            "valid_acc": results['f1'],
                            "test_acc": results_test['f1'],
                            "learning_rate": args.learning_rate,
                            "batch_size": args.batch_size,
                            "lr_schedule": args.lr_schedule,
                            "weight_decay": args.weight_decay,
                            "max_grad_norm": args.max_grad_norm,
                        })

                    if results['f1'] > best_f1:
                        best_f1 = results['f1']
                        early_stop_steps = 0
                        save_model(model, args.output_model_path)
                        save_encoder(args, encoder,
                                     suffix=args.suffix_file_encoder)
                    else:
                        early_stop_steps += args.report_steps

                    # evaluate() switched to eval mode; change back.
                    model.train()
                    total_loss = 0.

                if should_stop():
                    stop_training = True
                    break

            # Also checked at epoch end; skipped if the inner loop already
            # decided to stop (avoids logging the stop reason twice).
            if stop_training or should_stop():
                break
            epoch += 1

    # Evaluation phase.
    logger.info("Final evaluation on test dataset.")
    core_model.load_state_dict(torch.load(args.output_model_path))
    results_final = evaluate(args, True, final=True)
    logger.info(results_final)
from brain.knowgraph_english import KnowledgeGraph

# Quick manual check of KnowledgeGraph.add_knowledge_with_vm on one sentence.
# Raw string: the Windows path contains backslashes ("\D", "\e") that are
# invalid escape sequences in a normal string literal.
vocab_file = r"D:\Downloads\ent_vocab_custom"
kg = KnowledgeGraph(kg_file=vocab_file, predicate=True)
text = "Delhi is the capital of India ."
tokens, pos, vm, tag = kg.add_knowledge_with_vm([text], add_pad=False,
                                                max_length=16)
print(tag)
print(pos)
print(tokens)
print(vm)
def main():
    """Train and evaluate a ``LukeTagger`` with optional KG injection.

    This variant trains for ``--epochs_num`` epochs with ``BertAdam``.
    After every epoch it evaluates on dev and test and saves a checkpoint
    whenever dev F1 improves. Finally it reloads the saved checkpoint and
    runs a test evaluation that also writes
    ``<output_file_prefix>_predictions.txt``, ``_gold.txt`` and ``_text.txt``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path", default=None, type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/tagger_model.bin", type=str,
                        help="Path of the output model.")
    parser.add_argument("--output_encoder", default="./luke-models/",
                        type=str,
                        help="Path of the output luke model.")
    parser.add_argument("--suffix_file_encoder", default="encoder", type=str,
                        help="output file suffix luke model.")
    parser.add_argument("--vocab_path", default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path", type=str, required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path", type=str, required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/google_config.json", type=str,
                        help="Path of the config file.")
    parser.add_argument("--output_file_prefix", type=str, required=True,
                        help="Prefix for file output.")

    # Model options.
    parser.add_argument("--batch_size", type=int, default=2,
                        help="Batch_size.")
    parser.add_argument("--seq_length", default=256, type=int,
                        help="Sequence length.")
    parser.add_argument("--encoder",
                        choices=["bert", "lstm", "gru",
                                 "cnn", "gatedcnn", "attn",
                                 "rcnn", "crnn", "gpt", "bilstm"],
                        default="bert",
                        help="Encoder type.")
    parser.add_argument("--bidirectional", action="store_true",
                        help="Specific to recurrent model.")

    # Subword options.
    parser.add_argument("--subword_type", choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path", type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder", choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num", type=int, default=2,
                        help="The number of subencoder layers.")

    # Optimizer options.
    parser.add_argument("--learning_rate", type=float, default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup", type=float, default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout.")
    parser.add_argument("--epochs_num", type=int, default=5,
                        help="Number of epochs.")
    parser.add_argument("--report_steps", type=int, default=2,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7,
                        help="Random seed.")

    # kg
    parser.add_argument("--kg_name", required=True, help="KG name or path")
    parser.add_argument("--use_kg", action='store_true',
                        help="Enable the use of KG.")
    parser.add_argument("--dry_run", action='store_true',
                        help="Dry run to test the implementation.")
    parser.add_argument(
        "--voting_choicer",
        action='store_true',
        help="Enable the Voting choicer to select the entity type.")
    parser.add_argument("--eval_kg_tag", action='store_true',
                        help="Enable to include [ENT] tag in evaluation.")
    parser.add_argument("--use_subword_tag", action='store_true',
                        help="Enable to use separate tag for subword splits.")
    parser.add_argument("--debug", action='store_true',
                        help="Enable debug.")
    parser.add_argument("--reverse_order", action='store_true',
                        help="Reverse the feature selection order.")
    parser.add_argument("--max_entities", default=2, type=int,
                        help="Number of KG features.")
    parser.add_argument("--eval_range_with_types", action='store_true',
                        help="Enable to eval range with types.")

    args = parser.parse_args()

    # Load the hyperparameters of the config file.
    args = load_hyperparam(args)
    set_seed(args.seed)

    # Special labels occupy the first ids; data labels are appended after them.
    labels_map = {"[PAD]": 0, "[ENT]": 1, "[X]": 2, "[CLS]": 3, "[SEP]": 4}
    begin_ids = []

    # Find tagging labels. Train is scanned first, so training label ids are
    # unchanged. Dev/test are scanned too, so labels that only occur there do
    # not raise KeyError in read_dataset. The header line is skipped.
    for data_path in (args.train_path, args.dev_path, args.test_path):
        with open(data_path, mode="r", encoding="utf-8") as f:
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                for label in line.strip().split("\t")[0].split():
                    if label in labels_map:
                        continue
                    if label.startswith(("B", "S")):
                        begin_ids.append(len(labels_map))
                    labels_map[label] = len(labels_map)
    idx_to_label = {idx: label for label, idx in labels_map.items()}
    print(begin_ids)
    print("Labels: ", labels_map)
    args.labels_num = len(labels_map)

    # Build knowledge graph ('none' selects an empty KG file list).
    kg_file = [] if args.kg_name == 'none' else args.kg_name

    # Load Luke model.
    model_archive = ModelArchive.load(args.pretrained_model_path)
    tokenizer = model_archive.tokenizer

    # Handling space character in roberta tokenizer: maps the tokenizer's
    # printable characters back to raw bytes when rebuilding text.
    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}

    # Load the pretrained model; strict=False tolerates key mismatches.
    encoder = LukeModel(model_archive.config)
    encoder.load_state_dict(model_archive.state_dict, strict=False)

    # Build sequence labeling model.
    model = LukeTagger(args, encoder)
    kg = KnowledgeGraph(kg_file=kg_file, tokenizer=tokenizer)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model = model.to(device)

    def batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids,
                     vm_ids, tag_ids, segment_ids):
        """Yield consecutive row slices of the inputs; the last may be smaller."""
        instances_num = input_ids.size()[0]
        for start in range(0, instances_num, batch_size):
            stop = start + batch_size
            yield (input_ids[start:stop, :], label_ids[start:stop, :],
                   mask_ids[start:stop, :], pos_ids[start:stop, :],
                   vm_ids[start:stop, :, :], tag_ids[start:stop, :],
                   segment_ids[start:stop, :])

    def read_dataset(path):
        """Read a tab-separated tagging file and convert it to model inputs.

        After a header line, each line is ``labels<TAB>tokens``. An optional
        third column is ignored. Lines with another number of columns are
        reported and skipped. ``kg.add_knowledge_with_vm`` expands the tokens.
        Word-level labels are then re-aligned to the expanded sequence:

        - tag 0 (non-pad) takes the next word label;
        - tag 1 (non-pad, a KG-injected token) becomes ``'[ENT]'``;
        - tag 2 becomes ``'O'``, ``'[X]'`` (with ``--use_subword_tag``) or the
          inside tag of the preceding word label;
        - every other position becomes ``PAD_TOKEN``.

        Returns a list of
        ``[token_ids, label_ids, mask, pos, vm, tag, segment_ids]``
        (at most 100 items with ``--dry_run``).
        """
        dataset = []
        count = 0
        with open(path, mode="r", encoding="utf8") as f:
            f.readline()  # skip the header line
            for line_id, line in enumerate(f):
                fields = line.strip().split("\t")
                if len(fields) in (2, 3):
                    labels, tokens = fields[0], fields[1]
                else:
                    print(
                        f'The data is not in accepted format at line no:{line_id}.. Ignored'
                    )
                    continue

                tokens, pos, vm, tag = kg.add_knowledge_with_vm(
                    [tokens], [labels],
                    use_kg=args.use_kg,
                    max_length=args.seq_length,
                    max_entities=args.max_entities,
                    reverse_order=args.reverse_order)
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")
                tag = tag[0]

                num_tokens = sum(1 for tok in tokens
                                 if tok != tokenizer.pad_token)
                num_pad = len(tokens) - num_tokens
                labels = ([config.CLS_TOKEN] + labels.split(" ") +
                          [config.SEP_TOKEN])

                new_labels = []
                j = 0  # index of the next unconsumed word-level label
                joiner = '-'  # char between prefix and type, e.g. 'B-PER'
                for i in range(len(tokens)):
                    is_pad = tokens[i] == tokenizer.pad_token
                    if tag[i] == 0 and not is_pad:
                        cur_type = labels[j]
                        j += 1
                        new_labels.append(cur_type)
                        if cur_type == 'O':
                            prev_label = cur_type
                        else:
                            joiner = cur_type[1]
                            prev_label = cur_type[2:]
                    elif tag[i] == 1 and not is_pad:
                        # Token injected from the knowledge graph.
                        new_labels.append('[ENT]')
                    elif tag[i] == 2:
                        if prev_label == 'O':
                            new_labels.append('O')
                        elif args.use_subword_tag:
                            new_labels.append('[X]')
                        else:
                            new_labels.append(f'I{joiner}' + prev_label)
                    else:
                        new_labels.append(PAD_TOKEN)
                new_labels = [labels_map[label] for label in new_labels]

                # Non-pad tokens are counted as attended positions, followed
                # by num_pad masked positions.
                mask = [1] * num_tokens + [0] * num_pad
                word_segment_ids = [0] * len(tokens)
                tokens = tokenizer.convert_tokens_to_ids(tokens)
                if len(tokens) != len(new_labels):
                    raise ValueError(
                        "The length of token and label is not matching")
                dataset.append(
                    [tokens, new_labels, mask, pos, vm, tag, word_segment_ids])

                # Enable dry run: keep only the first 100 samples.
                if args.dry_run:
                    count += 1
                    if count == 100:
                        break
        return dataset

    def evaluate(args, is_test, final=False):
        """Compute span-level precision/recall/F1 on the test or dev split.

        A span starts at a label whose id is in ``begin_ids``. It continues
        over '[X]'/'[ENT]' positions and ends just before the next PAD, 'O'
        or begin label. ``end`` is inclusive. A predicted span counts as
        correct when the same ``(start, end)`` pair appears among the gold
        spans of the batch. With ``--eval_range_with_types`` the typed scores
        are also printed. With ``final=True`` predictions, gold labels and
        decoded text are appended to ``<output_file_prefix>_*.txt``.

        Returns the untyped F1, or 0 when precision/recall is undefined.
        """
        dataset = read_dataset(args.test_path if is_test else args.dev_path)
        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])
        pos_ids = torch.LongTensor([sample[3] for sample in dataset])
        vm_ids = torch.BoolTensor([sample[4] for sample in dataset])
        tag_ids = torch.LongTensor([sample[5] for sample in dataset])
        segment_ids = torch.LongTensor([sample[6] for sample in dataset])

        instances_num = input_ids.size(0)
        batch_size = args.batch_size
        if is_test:
            print("Batch size: ", batch_size)
            print("The number of test instances:", instances_num)

        correct = 0
        correct_with_type = 0
        gold_entities_num = 0
        pred_entities_num = 0

        begin_set = set(begin_ids)
        pad_id = labels_map["[PAD]"]
        ent_id = labels_map["[ENT]"]
        x_id = labels_map["[X]"]
        o_id = labels_map.get("O")  # None if the data has no 'O' label

        model.eval()
        # Inference only: do not build the autograd graph.
        with torch.no_grad():
            for (input_ids_batch, label_ids_batch, mask_ids_batch,
                 pos_ids_batch, vm_ids_batch, _tag_ids_batch,
                 segment_ids_batch) in batch_loader(batch_size, input_ids,
                                                    label_ids, mask_ids,
                                                    pos_ids, vm_ids, tag_ids,
                                                    segment_ids):
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vm_ids_batch = vm_ids_batch.long().to(device)
                segment_ids_batch = segment_ids_batch.long().to(device)

                _, _, pred, gold, _ = model(input_ids_batch,
                                            segment_ids_batch,
                                            mask_ids_batch,
                                            label_ids_batch,
                                            pos_ids_batch,
                                            vm_ids_batch,
                                            use_kg=args.use_kg)
                # pred/gold are indexed element-wise with .item() in the
                # original code, i.e. they are 1-D over the whole batch.
                pred_ids = pred.tolist()
                gold_ids = gold.tolist()

                if final:
                    with open(f'{args.output_file_prefix}_predictions.txt', 'a') as p, \
                            open(f'{args.output_file_prefix}_gold.txt', 'a') as g, \
                            open(f'{args.output_file_prefix}_text.txt', 'a') as t:
                        predicted_labels = [
                            idx_to_label.get(key) for key in pred_ids
                        ]
                        gold_labels = [idx_to_label.get(key) for key in gold_ids]
                        num_tokens = len(predicted_labels)
                        masks = mask_ids_batch.view(-1, num_tokens).tolist()[0]
                        tokens = input_ids_batch.view(-1,
                                                      num_tokens).tolist()[0]
                        # Split the flat batch back into seq_length rows.
                        for start_idx in range(0, num_tokens, args.seq_length):
                            stop_idx = start_idx + args.seq_length
                            num_labels = sum(masks[start_idx:stop_idx])
                            pred_sample = predicted_labels[start_idx:stop_idx]
                            gold_sample = gold_labels[start_idx:stop_idx]
                            token_sample = tokens[start_idx:stop_idx][:num_labels]
                            text = ''.join(
                                tokenizer.convert_ids_to_tokens(token_sample))
                            text = bytearray([byte_decoder[c] for c in text
                                              ]).decode('utf-8')
                            p.write(' '.join(pred_sample[:num_labels]) + '\n')
                            g.write(' '.join(gold_sample[:num_labels]) + '\n')
                            t.write(text + '\n')

                gold_entities_num += sum(1 for g_id in gold_ids
                                         if g_id in begin_set)
                pred_entities_num += sum(
                    1 for p_id, g_id in zip(pred_ids, gold_ids)
                    if p_id in begin_set and g_id != pad_id)

                # Spans are positions within this batch only.
                pred_entities_pos = []
                pred_entities_pos_with_type = []
                gold_entities_pos = []
                gold_entities_pos_with_type = []

                for j, gold_id in enumerate(gold_ids):
                    if gold_id not in begin_set:
                        continue
                    start = j
                    end = len(gold_ids) - 1
                    for k in range(j + 1, len(gold_ids)):
                        if gold_ids[k] in (x_id, ent_id):
                            continue
                        if gold_ids[k] in (pad_id, o_id) or gold_ids[k] in begin_set:
                            end = k - 1
                            break
                    if args.eval_range_with_types:
                        ent_type_gold = idx_to_label.get(
                            gold_ids[start]).replace('_NOKG', '')
                        gold_entities_pos_with_type.append(
                            (start, end, ent_type_gold))
                    gold_entities_pos.append((start, end))

                for j, pred_id in enumerate(pred_ids):
                    if (pred_id not in begin_set
                            or gold_ids[j] in (pad_id, ent_id, x_id)):
                        continue
                    start = j
                    end = len(pred_ids) - 1
                    for k in range(j + 1, len(pred_ids)):
                        if pred_ids[k] == x_id or gold_ids[k] == ent_id:
                            continue
                        if pred_ids[k] in (pad_id, o_id) or pred_ids[k] in begin_set:
                            end = k - 1
                            break
                    if args.eval_range_with_types:
                        # Labels of every position in the inclusive span.
                        entity_types = [
                            idx_to_label.get(l_id)
                            for l_id in pred_ids[start:end + 1]
                        ]
                        # Run voting choicer
                        final_entity_type = voting_choicer(
                            entity_types).replace('_NOKG', '')
                        if final:
                            logger.info(
                                f'Predicted: {" ".join(entity_types)}, Selected: {final_entity_type}'
                            )
                        if args.voting_choicer:
                            entity_type = final_entity_type
                        else:
                            # Use the first prediction
                            entity_type = idx_to_label.get(
                                pred_ids[start]).replace('_NOKG', '')
                        pred_entities_pos_with_type.append(
                            (start, end, entity_type))
                    pred_entities_pos.append((start, end))

                gold_spans = set(gold_entities_pos)
                correct += sum(1 for span in pred_entities_pos
                               if span in gold_spans)
                if args.eval_range_with_types:
                    gold_typed_spans = set(gold_entities_pos_with_type)
                    correct_with_type += sum(
                        1 for span in pred_entities_pos_with_type
                        if span in gold_typed_spans)

        try:
            print("Report precision, recall, and f1:")
            precision = correct / pred_entities_num
            recall = correct / gold_entities_num
            f1 = 2 * precision * recall / (precision + recall)
            print("{:.3f}, {:.3f}, {:.3f}".format(precision, recall, f1))
        except ZeroDivisionError:
            return 0

        if args.eval_range_with_types:
            try:
                print("Report accuracy with type, precision, recall, and f1:")
                p_with_type = correct_with_type / pred_entities_num
                r_with_type = correct_with_type / gold_entities_num
                f1_with_type = 2 * p_with_type * r_with_type / (p_with_type +
                                                               r_with_type)
                print("{:.3f}, {:.3f}, {:.3f}".format(p_with_type, r_with_type,
                                                      f1_with_type))
            except ZeroDivisionError:
                # Typed scores are informational only; the untyped F1 is
                # still returned.
                pass
        return f1

    # Training phase.
    print("Start training.")
    instances = read_dataset(args.train_path)
    input_ids = torch.LongTensor([ins[0] for ins in instances])
    label_ids = torch.LongTensor([ins[1] for ins in instances])
    mask_ids = torch.LongTensor([ins[2] for ins in instances])
    pos_ids = torch.LongTensor([ins[3] for ins in instances])
    vm_ids = torch.BoolTensor([ins[4] for ins in instances])
    tag_ids = torch.LongTensor([ins[5] for ins in instances])
    segment_ids = torch.LongTensor([ins[6] for ins in instances])

    instances_num = input_ids.size(0)
    batch_size = args.batch_size
    train_steps = int(instances_num * args.epochs_num / batch_size) + 1
    train_batcher = Batcher(batch_size, input_ids, label_ids, mask_ids,
                            pos_ids, vm_ids, tag_ids, segment_ids)

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    # Parameters whose names contain any of these substrings get no decay.
    no_decay = ['bias', 'gamma', 'beta']
    named_params = list(model.named_parameters())
    optimizer_grouped_parameters = [{
        'params': [p for n, p in named_params
                   if not any(nd in n for nd in no_decay)],
        'weight_decay_rate': 0.01
    }, {
        'params': [p for n, p in named_params
                   if any(nd in n for nd in no_decay)],
        'weight_decay_rate': 0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup,
                         t_total=train_steps)

    total_loss = 0.
    best_f1 = 0.0

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                pos_ids_batch, vm_ids_batch, _tag_ids_batch,
                segment_ids_batch) in enumerate(train_batcher):
            model.zero_grad()

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            vm_ids_batch = vm_ids_batch.long().to(device)
            segment_ids_batch = segment_ids_batch.long().to(device)

            loss, _, _, _, _ = model(input_ids_batch,
                                     segment_ids_batch,
                                     mask_ids_batch,
                                     label_ids_batch,
                                     pos_ids_batch,
                                     vm_ids_batch,
                                     use_kg=args.use_kg)
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.

            loss.backward()
            optimizer.step()

        # Evaluation phase.
        print("Start evaluate on dev dataset.")
        f1 = evaluate(args, False)
        print("Start evaluation on test dataset.")
        evaluate(args, True)

        if f1 > best_f1:
            best_f1 = f1
            save_model(model, args.output_model_path)
            save_encoder(args, encoder, suffix=args.suffix_file_encoder)

    # Evaluation phase.
    print("Final evaluation on test dataset.")
    if torch.cuda.device_count() > 1:
        model.module.load_state_dict(torch.load(args.output_model_path))
    else:
        model.load_state_dict(torch.load(args.output_model_path))
    evaluate(args, True, final=True)
def main():
    """Train and evaluate a LUKE-based sequence tagger, optionally with KG injection.

    Parses command-line options, builds the label vocabulary from the
    train/dev/test files, loads the LUKE encoder from
    ``--pretrained_model_path``, trains the selected tagger head for
    ``--epochs_num`` epochs, evaluates on dev and test after every epoch,
    saves a checkpoint whenever the dev F1 improves, and finally evaluates
    the saved checkpoint on the test set (appending prediction/gold files).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path", default=None, type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path", default="./models/tagger_model.bin", type=str,
                        help="Path of the output model.")
    parser.add_argument("--output_encoder", default="./luke-models/", type=str,
                        help="Path of the output luke model.")
    parser.add_argument("--suffix_file_encoder", default="encoder", type=str,
                        help="output file suffix luke model.")
    parser.add_argument("--vocab_path", default="./models/google_vocab.txt", type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path", type=str, required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path", type=str, required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path", default="./models/google_config.json", type=str,
                        help="Path of the config file.")
    parser.add_argument("--output_file_prefix", type=str, required=True,
                        help="Prefix for file output.")
    parser.add_argument("--log_file", default='app.log')

    # Model options.
    parser.add_argument("--batch_size", type=int, default=2,
                        help="Batch_size.")
    parser.add_argument("--seq_length", default=256, type=int,
                        help="Sequence length.")
    parser.add_argument("--classifier", choices=["mlp", "lstm", "lstm_crf", "lstm_ncrf"], default="mlp",
                        help="Classifier type.")
    parser.add_argument("--bidirectional", action="store_true", help="Specific to recurrent model.")
    parser.add_argument('--freeze_encoder_weights', action='store_true',
                        help="Enable to freeze the encoder weigths.")

    # Subword options.
    parser.add_argument("--subword_type", choices=["none", "char"], default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path", type=str, default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder", choices=["avg", "lstm", "gru", "cnn"], default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num", type=int, default=2,
                        help="The number of subencoder layers.")

    # Optimizer options.
    parser.add_argument("--learning_rate", type=float, default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--schedule_lr", action='store_true',
                        help="Enable to use lr scheduler.")
    parser.add_argument("--warmup", type=float, default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout.")
    parser.add_argument("--epochs_num", type=int, default=5,
                        help="Number of epochs.")
    parser.add_argument("--report_steps", type=int, default=2,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=35,
                        help="Random seed.")

    # kg
    parser.add_argument("--kg_name", required=True, help="KG name or path")
    parser.add_argument("--use_kg", action='store_true', help="Enable the use of KG.")
    parser.add_argument("--dry_run", action='store_true',
                        help="Dry run to test the implementation.")
    parser.add_argument(
        "--voting_choicer", action='store_true',
        help="Enable the Voting choicer to select the entity type.")
    parser.add_argument("--eval_kg_tag", action='store_true',
                        help="Enable to include [ENT] tag in evaluation.")
    parser.add_argument("--use_subword_tag", action='store_true',
                        help="Enable to use separate tag for subword splits.")
    parser.add_argument("--debug", action='store_true', help="Enable debug.")
    parser.add_argument("--reverse_order", action='store_true',
                        help="Reverse the feature selection order.")
    parser.add_argument("--max_entities", default=2, type=int,
                        help="Number of KG features.")
    parser.add_argument("--eval_range_with_types", action='store_true',
                        help="Enable to eval range with types.")

    args = parser.parse_args()

    # Load the hyperparameters of the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # NOTE(review): `fmt` and `logger` are module-level names not visible here.
    logging.basicConfig(filename=args.log_file, filemode='w', format=fmt)

    # Special labels get fixed ids; dataset labels are appended in first-seen order.
    labels_map = {"[PAD]": 0, "[ENT]": 1, "[X]": 2, "[CLS]": 3, "[SEP]": 4}
    # Ids of the begin ("B*") / single ("S*") tags.
    begin_ids = []

    # Find tagging labels (first column of every non-header line).
    for file in (args.train_path, args.dev_path, args.test_path):
        with open(file, mode="r", encoding="utf-8") as f:
            f.readline()  # skip the header line
            for line in f:
                for label in line.strip().split("\t")[0].split():
                    if label in labels_map:
                        continue
                    is_begin = label.startswith("B") or label.startswith("S")
                    if is_begin:
                        # Make sure the matching inside tag (I<joiner><type>)
                        # exists: subword continuations are labelled with it.
                        inner_tag = f'I{label[1]}{label[2:]}'
                        if inner_tag not in labels_map:
                            labels_map[inner_tag] = len(labels_map)
                    labels_map[label] = len(labels_map)
                    if is_begin:
                        # Record the id actually assigned to the begin tag
                        # (it may come after a freshly inserted inner tag).
                        begin_ids.append(labels_map[label])

    idx_to_label = {idx: label for label, idx in labels_map.items()}

    print(begin_ids)
    print("Labels: ", labels_map)
    args.labels_num = len(labels_map)

    # Build knowledge graph ('none' disables it).
    if args.kg_name == 'none':
        kg_file = []
    else:
        kg_file = args.kg_name

    # Load Luke model.
    model_archive = ModelArchive.load(args.pretrained_model_path)
    tokenizer = model_archive.tokenizer

    # Load the pretrained encoder weights; strict=False tolerates missing/extra keys.
    encoder = LukeModel(model_archive.config)
    encoder.load_state_dict(model_archive.state_dict, strict=False)

    kg = KnowledgeGraph(kg_file=kg_file, tokenizer=tokenizer)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    # Build sequence labeling model.
    classifiers = {
        "mlp": LukeTaggerMLP,
        "lstm": LukeTaggerLSTM,
        "lstm_crf": LukeTaggerLSTMCRF,
        "lstm_ncrf": LukeTaggerLSTMNCRF,
    }
    logger.info(f'The selected classifier is:{classifiers[args.classifier]}')
    model = classifiers[args.classifier](args, encoder)

    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    # nn.DataParallel only dispatches forward(); custom methods such as
    # score() and load_state_dict() for the bare tagger live on `.module`.
    tagger = model.module if isinstance(model, nn.DataParallel) else model

    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids,
                     vm_ids, tag_ids, segment_ids):
        """Yield consecutive row slices of the given tensors, ``batch_size`` rows at a time.

        The final batch holds the remaining rows when the row count is not a
        multiple of ``batch_size``.
        """
        instances_num = input_ids.size()[0]
        tensors = (input_ids, label_ids, mask_ids, pos_ids, vm_ids, tag_ids, segment_ids)
        for start in range(0, instances_num, batch_size):
            end = start + batch_size
            yield tuple(t[start:end] for t in tensors)

    # Read dataset.
    def read_dataset(path):
        """Read a tab-separated tagging file and convert it to model inputs.

        The first line is a header and is skipped. Each remaining line is
        ``labels<TAB>tokens`` (an optional third column is ignored); other
        lines are reported and skipped.

        Returns:
            list of ``[token_ids, label_ids, mask, pos, vm, tag, segment_ids]``.
            With ``--dry_run`` at most 100 instances are returned.

        Raises:
            ValueError: if the aligned label sequence is not as long as the tokens.
        """
        dataset = []
        count = 0
        with open(path, mode="r", encoding="utf8") as f:
            f.readline()  # skip the header line
            for line_id, line in enumerate(f):
                fields = line.strip().split("\t")
                if len(fields) == 2:
                    labels, tokens = fields
                elif len(fields) == 3:
                    labels, tokens, _ = fields
                else:
                    print(
                        f'The data is not in accepted format at line no:{line_id}.. Ignored'
                    )
                    continue

                tokens, pos, vm, tag = kg.add_knowledge_with_vm(
                    [tokens], [labels],
                    use_kg=args.use_kg,
                    max_length=args.seq_length,
                    max_entities=args.max_entities,
                    reverse_order=args.reverse_order)
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")
                tag = tag[0]

                num_tokens = sum(1 for tok in tokens if tok != tokenizer.pad_token)
                num_pad = len(tokens) - num_tokens

                labels = [config.CLS_TOKEN] + labels.split(" ") + [config.SEP_TOKEN]

                # Align one label per (possibly expanded) token. `tag` values
                # seen here: 0 = original token (consumes the next gold label),
                # 1 = token added from the KG, 2 = continuation of the previous
                # token (presumably a subword split), anything else = padding.
                new_labels = []
                j = 0
                joiner = '-'
                prev_label = 'O'  # guard: tag 2 before any tag-0 token
                for i in range(len(tokens)):
                    if tag[i] == 0 and tokens[i] != tokenizer.pad_token:
                        cur_type = labels[j]
                        if cur_type != 'O':
                            try:
                                joiner = cur_type[1]
                                prev_label = cur_type[2:]
                            except IndexError:
                                # Labels too short to carry a prefix+type are dropped to O.
                                logger.info(
                                    f'The label:{cur_type} is converted to O')
                                prev_label = 'O'
                                j += 1
                                new_labels.append('O')
                                continue
                        else:
                            prev_label = cur_type
                        new_labels.append(cur_type)
                        j += 1
                    elif tag[i] == 1 and tokens[i] != tokenizer.pad_token:
                        # Token added from the knowledge graph.
                        new_labels.append('[ENT]')
                    elif tag[i] == 2:
                        if prev_label == 'O':
                            new_labels.append('O')
                        elif args.use_subword_tag:
                            new_labels.append('[X]')
                        else:
                            new_labels.append(f'I{joiner}' + prev_label)
                    else:
                        # NOTE(review): assumes module-level PAD_TOKEN == "[PAD]".
                        new_labels.append(PAD_TOKEN)
                new_labels = [labels_map[label] for label in new_labels]

                mask = [1] * num_tokens + [0] * num_pad
                word_segment_ids = [0] * len(tokens)

                tokens = tokenizer.convert_tokens_to_ids(tokens)

                if len(tokens) != len(new_labels):
                    raise ValueError(
                        "The length of token and label is not matching")

                dataset.append(
                    [tokens, new_labels, mask, pos, vm, tag, word_segment_ids])

                # Enable dry run: stop after 100 instances.
                if args.dry_run:
                    count += 1
                    if count == 100:
                        break
        return dataset

    # Evaluation function.
    def evaluate(args, is_test, final=False):
        """Run the model over the dev or test set and compute tagging metrics.

        Args:
            args: parsed options (reads ``test_path``/``dev_path``,
                ``batch_size``, ``use_kg`` and ``output_file_prefix``).
            is_test: evaluate ``args.test_path`` if true, else ``args.dev_path``.
            final: when true, append the predicted and gold tag sequences to
                ``<output_file_prefix>_predictions.txt`` / ``_gold.txt``.

        Returns:
            dict with seqeval ``f1``/``precision``/``recall`` and the
            span-level ``f1_span``/``precision_span``/``recall_span``.

        Raises:
            RuntimeError: if predicted and gold sequences differ in length.
        """
        if is_test:
            dataset = read_dataset(args.test_path)
        else:
            dataset = read_dataset(args.dev_path)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])
        pos_ids = torch.LongTensor([sample[3] for sample in dataset])
        vm_ids = torch.BoolTensor([sample[4] for sample in dataset])
        tag_ids = torch.LongTensor([sample[5] for sample in dataset])
        segment_ids = torch.LongTensor([sample[6] for sample in dataset])

        instances_num = input_ids.size(0)
        batch_size = args.batch_size

        if is_test:
            logger.info(f"Batch size:{batch_size}")
            print(f"The number of test instances:{instances_num}")

        true_labels_all = []
        predicted_labels_all = []

        model.eval()

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            for (input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch,
                 vm_ids_batch, tag_ids_batch, segment_ids_batch) in batch_loader(
                     batch_size, input_ids, label_ids, mask_ids, pos_ids,
                     vm_ids, tag_ids, segment_ids):
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                tag_ids_batch = tag_ids_batch.to(device)
                vm_ids_batch = vm_ids_batch.long().to(device)
                segment_ids_batch = segment_ids_batch.long().to(device)

                pred = model(input_ids_batch, segment_ids_batch, mask_ids_batch,
                             label_ids_batch, pos_ids_batch, vm_ids_batch,
                             use_kg=args.use_kg)

                for pred_sample, gold_sample, mask in zip(pred, label_ids_batch,
                                                          mask_ids_batch):
                    pred_labels = [idx_to_label.get(key) for key in pred_sample.tolist()]
                    gold_labels = [idx_to_label.get(key) for key in gold_sample.tolist()]

                    num_labels = int(mask.sum())
                    # Exclude the [CLS], and [SEP] tokens
                    pred_labels = pred_labels[1:num_labels - 1]
                    true_labels = gold_labels[1:num_labels - 1]

                    pred_labels = [p.replace('_NOKG', '') for p in pred_labels]
                    true_labels = [t.replace('_NOKG', '') for t in true_labels]

                    true_labels, pred_labels = filter_kg_labels(true_labels, pred_labels)

                    pred_labels = [p.replace('_', '-') for p in pred_labels]
                    true_labels = [t.replace('_', '-') for t in true_labels]

                    biluo_tags_predicted = get_bio(pred_labels)
                    biluo_tags_true = get_bio(true_labels)

                    if len(biluo_tags_predicted) != len(biluo_tags_true):
                        logger.error(
                            'The length of the predicted labels is not same as that of true labels..'
                        )
                        raise RuntimeError(
                            'The length of the predicted labels is not same as that of true labels..')

                    predicted_labels_all.append(biluo_tags_predicted)
                    true_labels_all.append(biluo_tags_true)

        if final:
            # Files are opened in append mode; terminate the output with a
            # newline so repeated runs do not merge their boundary lines.
            with open(f'{args.output_file_prefix}_predictions.txt', 'a') as p, \
                    open(f'{args.output_file_prefix}_gold.txt', 'a') as g:
                p.write('\n'.join(' '.join(seq) for seq in predicted_labels_all) + '\n')
                g.write('\n'.join(' '.join(seq) for seq in true_labels_all) + '\n')

        return dict(
            f1=seqeval.metrics.f1_score(true_labels_all, predicted_labels_all),
            precision=seqeval.metrics.precision_score(true_labels_all, predicted_labels_all),
            recall=seqeval.metrics.recall_score(true_labels_all, predicted_labels_all),
            f1_span=f1_score_span(true_labels_all, predicted_labels_all),
            precision_span=precision_score_span(true_labels_all, predicted_labels_all),
            recall_span=recall_score_span(true_labels_all, predicted_labels_all),
        )

    # Training phase.
    logger.info("Start training.")
    instances = read_dataset(args.train_path)

    input_ids = torch.LongTensor([ins[0] for ins in instances])
    label_ids = torch.LongTensor([ins[1] for ins in instances])
    mask_ids = torch.LongTensor([ins[2] for ins in instances])
    pos_ids = torch.LongTensor([ins[3] for ins in instances])
    vm_ids = torch.BoolTensor([ins[4] for ins in instances])
    tag_ids = torch.LongTensor([ins[5] for ins in instances])
    segment_ids = torch.LongTensor([ins[6] for ins in instances])

    instances_num = input_ids.size(0)
    batch_size = args.batch_size
    train_steps = int(instances_num * args.epochs_num / batch_size) + 1
    train_batcher = Batcher(batch_size, input_ids, label_ids, mask_ids,
                            pos_ids, vm_ids, tag_ids, segment_ids)

    logger.info(f"Batch size:{batch_size}")
    logger.info(f"The number of training instances:{instances_num}")

    # Exclude bias / LayerNorm-style ('gamma', 'beta') parameters from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate': 0.01
    }, {
        'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate': 0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup,
                         t_total=train_steps)
    # T_max counts scheduler.step() calls; the scheduler is stepped once per epoch.
    # NOTE(review): BertAdam applies its own warmup schedule on top of group lr.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs_num)

    total_loss = 0.
    best_f1 = 0.0
    saved_this_run = False

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch,
                vm_ids_batch, tag_ids_batch, segment_ids_batch) in enumerate(train_batcher):
            model.zero_grad()

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            tag_ids_batch = tag_ids_batch.to(device)
            vm_ids_batch = vm_ids_batch.long().to(device)
            segment_ids_batch = segment_ids_batch.long().to(device)

            # `tagger` is the bare module: DataParallel has no `score` attribute.
            loss = tagger.score(input_ids_batch, segment_ids_batch, mask_ids_batch,
                                label_ids_batch, pos_ids_batch, vm_ids_batch,
                                use_kg=args.use_kg)
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                logger.info(
                    "Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(
                        epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.

            loss.backward()
            optimizer.step()

        if args.schedule_lr:
            # Update learning rate
            scheduler.step()

        # Evaluation phase.
        logger.info("Start evaluate on dev dataset.")
        results = evaluate(args, False)
        logger.info(results)

        logger.info("Start evaluation on test dataset.")
        results_test = evaluate(args, True)
        logger.info(results_test)

        # Checkpoint only when the dev F1 improves.
        if results['f1'] > best_f1:
            best_f1 = results['f1']
            save_model(model, args.output_model_path)
            save_encoder(args, encoder, suffix=args.suffix_file_encoder)
            saved_this_run = True

    # Evaluation phase.
    logger.info("Final evaluation on test dataset.")
    try:
        state_dict = torch.load(args.output_model_path, map_location=device)
    except FileNotFoundError:
        logger.warning(
            f'No checkpoint found at {args.output_model_path}; evaluating the current weights.')
    else:
        if not saved_this_run:
            logger.warning(
                f'Loading {args.output_model_path}, which was not written by this run.')
        tagger.load_state_dict(state_dict)
    results_final = evaluate(args, True, final=True)
    logger.info(results_final)