def create_and_check_for_token_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Build a BertForTokenClassification and assert its logits are (batch, seq_len, num_labels).

    ``sequence_labels`` and ``choice_labels`` are accepted for a uniform tester
    signature but are not used here.
    """
    # The classification head is sized from config.num_labels, so set it before construction.
    config.num_labels = self.num_labels
    tagger = BertForTokenClassification(config=config)
    tagger.to(torch_device)
    tagger.eval()

    outputs = tagger(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=token_labels,
    )
    expected_shape = (self.batch_size, self.seq_length, self.num_labels)
    self.parent.assertEqual(outputs.logits.shape, expected_shape)
def create_and_check_bert_for_token_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
    """Run BertForTokenClassification on one batch and check the logits size and the loss output.

    The logits are read with dict-style access (``result["logits"]``), and
    ``self.check_loss_output`` is applied to the whole result.
    """
    # The token-classification head takes its width from config.num_labels.
    config.num_labels = self.num_labels
    tagger = BertForTokenClassification(config=config)
    tagger.to(torch_device)
    tagger.eval()

    result = tagger(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=token_labels,
    )
    observed_dims = list(result["logits"].size())
    wanted_dims = [self.batch_size, self.seq_length, self.num_labels]
    self.parent.assertListEqual(observed_dims, wanted_dims)
    self.check_loss_output(result)
# Test-time pipeline: load the test sentences, rebuild a BertForTokenClassification
# from the JSON config in args.bert_model_dir, restore the saved checkpoint and
# run predict() over the sentences.
params.seed = args.seed

test_sentences = load_test_sentences(args.bert_model_dir, args.test_file)

# Specify the test set size
params.test_size = len(test_sentences)
# NOTE(review): floor division. If predict() consumes exactly eval_steps batches,
# the trailing test_size % batch_size sentences are never predicted. Confirm
# against predict()/yield_data_batch.
params.eval_steps = params.test_size // params.batch_size

# Define the model
config_path = os.path.join(args.bert_model_dir, 'config.json')
config = BertConfig.from_json_file(config_path)
#update config with num_labels
# NOTE(review): the label count is hard-coded to 2. It presumably has to match
# the classifier size stored in the checkpoint loaded below. Verify.
config.update({"num_labels": 2})
model = BertForTokenClassification(config)
#model = BertForTokenClassification(config, num_labels=2)

model.to(params.device)
# Reload weights from the saved file
utils.load_checkpoint(
    os.path.join(args.model_dir, args.restore_file + '.pth.tar'), model)
if args.fp16:
    model.half()
# Wrap in DataParallel only when several GPUs exist AND multi-GPU use was requested.
if params.n_gpu > 1 and args.multi_gpu:
    model = torch.nn.DataParallel(model)

predict(model=model,
        data_iterator=yield_data_batch(test_sentences, params),
        params=params,
        sentences_file=args.test_file)
def _build_arg_parser():
    """Build the command-line parser for the (YAGO-reference) BERT NER script.

    Returns:
        argparse.ArgumentParser: parser with every required and optional flag.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.")
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    # There is no --output_dir flag: main() derives the output directory from the model name.

    ## Other parameters
    parser.add_argument(
        "--labels",
        default="",
        type=str,
        help="Path to a file containing all labels. If not specified, <data_dir>/labels.txt is used.")
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
    parser.add_argument(
        "--evaluate_during_training",
        action="store_true",
        help="Whether to run evaluation during training at each logging step.")
    parser.add_argument(
        "--test_during_training",
        action="store_true",
        help="Whether to run test during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=3e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Ratio of linear warmup steps over all training steps")
    parser.add_argument("--logging_steps", type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
    parser.add_argument(
        "--test_all_checkpoints",
        action="store_true",
        help="Test all checkpoints starting with the same prefix as model_name ending and ending with step number")
    parser.add_argument("--no_cuda", action="store_true",
                        help="Avoid using CUDA when available")
    parser.add_argument("--overwrite_output_dir", action="store_true",
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--seed", type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
    parser.add_argument("--random_start", action="store_true")
    parser.add_argument("--yago_reference", action="store_true")
    parser.add_argument("--max_reference_num", type=int, default=10,
                        help="Number of yago types considered as references")
    parser.add_argument("--additional_output_tag", type=str, default="",
                        help="Additional tag to distinguish from other models")
    parser.add_argument(
        "--do_significant_check",
        action="store_true",
        help="Whether to check if the influence of reference embedding is significant")
    return parser


def _evaluate_checkpoints(args, model_class, tokenizer, labels, pad_token_label_id,
                          mode, all_checkpoints, log_message):
    """Evaluate the saved model, or every saved checkpoint, on one data split.

    Args:
        args: parsed command-line namespace (uses output_dir, device, yago_reference).
        model_class: model class used when --yago_reference is off.
        tokenizer, labels, pad_token_label_id: forwarded to ``evaluate``.
        mode: data split name forwarded to ``evaluate`` ("dev" or "test").
        all_checkpoints: if true, evaluate every directory under
            ``args.output_dir`` (itself included) that holds ``WEIGHTS_NAME``;
            otherwise evaluate only ``args.output_dir``.
        log_message: %-style message used to log the checkpoint list.

    Returns:
        dict: metrics keyed "<checkpoint prefix>_<metric name>".
    """
    checkpoints = [args.output_dir]
    if all_checkpoints:
        checkpoints = list(
            os.path.dirname(c)
            for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True)))
        logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
    logger.info(log_message, checkpoints)

    split_results = {}
    for checkpoint in checkpoints:
        step_suffix = checkpoint.split("-")[-1]
        # Prefix metrics with the text after the last '-' when it is longer than
        # two characters, otherwise with the whole checkpoint path.
        global_step = step_suffix if len(step_suffix) > 2 else checkpoint
        if args.yago_reference:
            model = YagoRefBertForTokenClassification.from_pretrained(checkpoint)
        else:
            model = model_class.from_pretrained(checkpoint)
        model.to(args.device)
        result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id,
                             mode=mode, prefix=global_step)
        if global_step:
            result = {"{}_{}".format(global_step, k): v for k, v in result.items()}
        split_results.update(result)
    return split_results


def _check_reference_significance(args, config, config_class, num_labels):
    """Log how the trained YAGO reference embeddings relate to the word embeddings.

    This is a temporary diagnostic. It looks at every vocabulary id that has an
    entry in the pruned YAGO reference pickle. For each one it logs the cosine
    similarity between the trained word embedding and the token's weighted
    reference embedding. At the end it logs three means:

    - that similarity;
    - the cosine similarity to the same token's word embedding in a baseline
      model trained without YAGO references;
    - the norm of the reference embeddings.
    """
    cache_dir = args.cache_dir if args.cache_dir else None
    model = YagoRefBertForTokenClassification.from_pretrained(args.output_dir, config=config)
    # NOTE(review): hard-coded baseline (no YAGO references) from an earlier run.
    baseline_dir = "/work/smt2/qfeng/Project/huggingface/models/base-cased_1-9/"
    bertconfig = config_class.from_pretrained(
        baseline_dir,
        num_labels=num_labels,
        cache_dir=cache_dir)
    model_noyago = BertForTokenClassification.from_pretrained(
        baseline_dir,
        # Decide TF-vs-PyTorch from the path that is actually loaded,
        # not from --model_name_or_path.
        from_tf=bool(".ckpt" in baseline_dir),
        config=bertconfig,
        cache_dir=cache_dir)
    model.to(args.device)
    model.eval()
    model_noyago.to(args.device)
    model_noyago.eval()

    ref_path = ('/work/smt3/wwang/TAC2019/qihui_data/yago/YagoReference_prune{}.pickle'
                .format("" if args.do_lower_case else "_cased"))
    # pickle.load can execute arbitrary code: only point this at trusted, locally built files.
    with open(ref_path, 'rb') as ref_pickle:  # TODO:
        ref_dict = pickle.load(ref_pickle)

    cos_sim = []
    cos_sim_ww = []
    ref_vec_norm = []
    cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
    embeddings = model.bert.embeddings
    # Inspection only. Without no_grad, every stored similarity would keep its
    # autograd graph alive for the whole vocabulary loop.
    with torch.no_grad():
        for token_id in range(config.vocab_size):
            if token_id not in ref_dict:
                continue
            reference_ids = embeddings.ref_ids[token_id].to(args.device)
            reference_weights = embeddings.ref_weights[token_id].to(args.device)
            # Weighted sum of this token's reference embeddings.
            reference_embedding = torch.sum(
                model.bert.reference_embeddings(reference_ids)
                * torch.unsqueeze(reference_weights, dim=-1),
                dim=-2)
            token_tensor = torch.tensor(token_id, dtype=torch.long, device=args.device)
            word_embedding = embeddings.word_embeddings(token_tensor)
            noyago_word_embedding = model_noyago.bert.embeddings.word_embeddings(token_tensor)
            assert word_embedding.size() == reference_embedding.size()

            ref_vec_norm.append(torch.norm(reference_embedding))
            similarity = cos(word_embedding, reference_embedding)
            cos_sim.append(similarity)
            cos_sim_ww.append(cos(word_embedding, noyago_word_embedding))
            logger.info(similarity)

    if not cos_sim:
        logger.warning("No vocabulary id has a YAGO reference entry; nothing to compare.")
        return
    avg_sim = sum(cos_sim) / len(cos_sim)
    avg_sim_ww = sum(cos_sim_ww) / len(cos_sim_ww)
    # Despite the name, this is the mean norm of the reference embeddings.
    avg_ratio = sum(ref_vec_norm) / len(ref_vec_norm)
    logger.info(avg_sim)
    logger.info(avg_sim_ww)
    logger.info(avg_ratio)


def main():
    """Parse flags, then optionally train, evaluate (dev) and test a BERT token tagger.

    Data and cache paths given without a '/' are resolved against hard-coded
    repositories. The output directory is always derived from the model name,
    plus "_yagoref" when --yago_reference is set, plus either
    --additional_output_tag or the current month-day.

    Returns:
        dict: dev and/or test metrics keyed "<checkpoint prefix>_<metric>".
        Test metrics overwrite dev metrics that share a key.
    """
    args = _build_arg_parser().parse_args()

    DEFAULT_DATA_REPO = '/work/smt2/qfeng/Project/huggingface/datasets/'
    DEFAULT_CACHE_REPO = '/work/smt2/qfeng/Project/huggingface/pretrain/'
    DEFAULT_OUTPUT_REPO = '/work/smt2/qfeng/Project/huggingface/models/'
    if '/' not in args.data_dir:
        args.data_dir = DEFAULT_DATA_REPO + args.data_dir
    if '/' not in args.cache_dir:
        if args.cache_dir == "":
            # NOTE(review): drops the first len('bert-') characters unconditionally,
            # which only makes sense for "bert-*" shortcut names.
            args.cache_dir = DEFAULT_CACHE_REPO + args.model_name_or_path[len('bert-'):]
        else:
            args.cache_dir = DEFAULT_CACHE_REPO + args.cache_dir
    if args.labels == "":
        if os.path.exists(os.path.join(args.data_dir, 'labels.txt')):
            args.labels = os.path.join(args.data_dir, 'labels.txt')
        else:
            raise ValueError("Invalid or missing labels file!")
    # The casing marker in the model name overrides --do_lower_case.
    if '-uncased' in args.model_name_or_path or '_uncased' in args.model_name_or_path:
        args.do_lower_case = True
    elif '-cased' in args.model_name_or_path or '_cased' in args.model_name_or_path:
        args.do_lower_case = False

    # Name the output directory after the model, the use of YAGO references,
    # and either the user tag or the current month-day.
    output_dir = args.model_name_or_path.split('/')[-1]
    if args.yago_reference:
        output_dir += "_yagoref"
    if args.additional_output_tag != "":
        output_dir += '_' + args.additional_output_tag
    else:
        now_time = datetime.datetime.now()
        output_dir += '_' + '-'.join(str(i) for i in now_time.timetuple()[1:3])  # 'month-date'
    args.output_dir = DEFAULT_OUTPUT_REPO + output_dir
    logger.info("output model to file {}".format(output_dir))
    if args.tokenizer_name == "":
        args.tokenizer_name = 'bert-base-uncased' if args.do_lower_case else 'bert-base-cased'
    if args.yago_reference:
        REFERENCE_SIZE = 959

    if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
            and args.do_train and not args.overwrite_output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # With --no_cuda the run is CPU-only, so report zero GPUs even if some are visible.
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # n_gpu is 0 on CPU-only runs; dividing by it raised ZeroDivisionError.
    gpu_divisor = max(args.n_gpu, 1)
    args.logging_steps = int(args.logging_steps / gpu_divisor)
    args.save_steps = int(args.save_steps / gpu_divisor)

    # Set seed
    set_seed(args)

    # Prepare CONLL-2003 task
    labels = get_labels(args.labels)
    num_labels = len(labels)
    # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
    pad_token_label_id = CrossEntropyLoss().ignore_index

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    cache_dir = args.cache_dir if args.cache_dir else None
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=cache_dir)
    config_name = args.config_name if args.config_name else args.model_name_or_path
    if not args.yago_reference:
        config = config_class.from_pretrained(
            config_name,
            num_labels=num_labels,
            cache_dir=cache_dir)
        if args.random_start:
            # Ablation: same architecture without loading pretrained weights.
            model = BertForTokenClassification(config)
        else:
            model = model_class.from_pretrained(
                args.model_name_or_path,
                from_tf=bool(".ckpt" in args.model_name_or_path),
                config=config,
                cache_dir=cache_dir)
    else:
        config = YagoRefBertConfig.from_pretrained(
            config_name,
            reference_size=REFERENCE_SIZE,
            num_labels=num_labels,
            cache_dir=cache_dir)
        logger.info("number of labels %d", config.num_labels)
        logger.info("vocab size: %d", config.vocab_size)
        model = YagoRefBertForTokenClassification.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            cache_dir=cache_dir)

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    if args.overwrite_cache:
        # Rebuild every cached split once up front. Clearing the flag presumably
        # lets the later loads reuse these caches.
        for split in ("train", "dev", "test"):
            load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode=split)
        args.overwrite_cache = False

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="train")
        global_step, tr_loss = train(args, train_dataset, model, tokenizer, labels, pad_token_label_id)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        results.update(
            _evaluate_checkpoints(args, model_class, tokenizer, labels, pad_token_label_id,
                                  mode="dev",
                                  all_checkpoints=args.eval_all_checkpoints,
                                  log_message="Evaluate the following checkpoints: %s"))
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            for key in sorted(results.keys()):
                writer.write("{} = {}\n".format(key, str(results[key])))

    if args.do_predict and args.local_rank in [-1, 0]:
        test_results = _evaluate_checkpoints(args, model_class, tokenizer, labels, pad_token_label_id,
                                             mode="test",
                                             all_checkpoints=args.test_all_checkpoints,
                                             log_message="Test the following checkpoints: %s")
        results.update(test_results)
        # Write only test metrics: `results` may also hold dev metrics from --do_eval.
        output_test_file = os.path.join(args.output_dir, "test_results.txt")
        with open(output_test_file, "w") as writer:
            for key in sorted(test_results.keys()):
                writer.write("{} = {}\n".format(key, str(test_results[key])))

        if args.yago_reference:
            model = YagoRefBertForTokenClassification.from_pretrained(args.output_dir, config=config)
        else:
            model = model_class.from_pretrained(args.output_dir)
        model.to(args.device)
        _, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")

        # Save predictions: echo test.txt, appending the predicted tag to every token line.
        output_test_predictions_file = os.path.join(args.output_dir, "test_predictions.txt")
        with open(output_test_predictions_file, "w") as writer:
            with open(os.path.join(args.data_dir, "test.txt"), "r") as f:
                example_id = 0
                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)
                        # Advance to the next example once the current one is used up.
                        if not predictions[example_id]:
                            example_id += 1
                    elif predictions[example_id]:
                        output_line = line.split()[0] + " " + predictions[example_id].pop(0) + "\n"
                        writer.write(output_line)
                    else:
                        logger.warning("Maximum sequence length exceeded: No prediction for '%s'.",
                                       line.split()[0])

    # Temporary code: check whether the values of reference_embedding are significant
    if args.do_significant_check and args.yago_reference:
        _check_reference_significance(args, config, config_class, num_labels)

    return results
genes, labels = read_non_split_file(
    '/home/brian/Downloads/all_samples_6-mer_train.txt')
seq_ids, masks, labels = tokenize_and_pad_samples(genes, labels)
print(seq_ids[0])
print(len(seq_ids))
print("Finished making data")

batch_size = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BertForTokenClassification(
    BertConfig.from_json_file(
        '/home/brian/attentive_splice/bert_configuration_all_hex.json'))
model.resize_token_embeddings(4099)
model.to(device)
optimizer = Adam(model.parameters(), lr=1e-3)  #lr=3e-5)

# Class 1 gets 165x the weight of class 0 (presumably to counter label imbalance).
# Create the weights on `device` instead of an unconditional `.cuda()`, which
# crashed when the CPU fallback above was taken.
class_weights = torch.tensor([1.0, 165.0], dtype=torch.float, device=device)
loss = CrossEntropyLoss(weight=class_weights)

last_i = 0


def load_model_from_saved():
    """Restore the resume index `last_i` and the model weights from disk."""
    # Without this declaration the assignment below only created a local, so
    # the module-level resume index stayed 0.
    # NOTE(review): assumes this script runs at module scope (not inside a function).
    global last_i
    with open('/home/brian/bert_last_i.txt', 'r') as last_i_file:
        i = last_i_file.read()
    last_i = int(i)
    # map_location lets weights saved on a GPU load on a CPU-only host as well.
    model.load_state_dict(torch.load("/home/brian/bert_splice_weights.pt", map_location=device))


def save_weights():
    print("Saving weights")
class TorchBertSequenceTagger(TorchModel):
    """BERT-based model on PyTorch for text tagging.

    It predicts a label for every token (not subtoken) in the text.
    You can use it for sequence labeling tasks, such as morphological tagging or named entity recognition.

    Args:
        n_tags: number of distinct tags
        pretrained_bert: pretrained Bert checkpoint path or key title (e.g. "bert-base-uncased")
        return_probas: set this to `True` if you need the probabilities instead of raw answers
        bert_config_file: path to Bert configuration file, or None, if `pretrained_bert` is a string name
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
            (only applied when the model is built from `bert_config_file`)
        hidden_keep_prob: keep_prob for Bert hidden layers
            (only applied when the model is built from `bert_config_file`)
        optimizer: optimizer name from `torch.optim`
        optimizer_parameters: dictionary with optimizer's parameters,
                              e.g. {'lr': 0.1, 'weight_decay': 0.001, 'momentum': 0.9};
                              defaults to {"lr": 1e-3, "weight_decay": 1e-6}
        learning_rate_drop_patience: how many validations with no improvements to wait
        learning_rate_drop_div: the divider of the learning rate after `learning_rate_drop_patience` unsuccessful
            validations
        load_before_drop: whether to load best model before dropping learning rate or not
        clip_norm: clip gradients by norm
        min_learning_rate: min value of learning rate if learning rate decay is used
    """

    def __init__(self,
                 n_tags: int,
                 pretrained_bert: str,
                 bert_config_file: Optional[str] = None,
                 return_probas: bool = False,
                 attention_probs_keep_prob: Optional[float] = None,
                 hidden_keep_prob: Optional[float] = None,
                 optimizer: str = "AdamW",
                 optimizer_parameters: Optional[dict] = None,
                 learning_rate_drop_patience: int = 20,
                 learning_rate_drop_div: float = 2.0,
                 load_before_drop: bool = True,
                 clip_norm: Optional[float] = None,
                 min_learning_rate: float = 1e-07,
                 **kwargs) -> None:
        if optimizer_parameters is None:
            # Build a fresh dict per instance. A dict literal as the default would
            # be one object shared by every instance.
            optimizer_parameters = {"lr": 1e-3, "weight_decay": 1e-6}

        self.n_classes = n_tags
        self.return_probas = return_probas
        self.attention_probs_keep_prob = attention_probs_keep_prob
        self.hidden_keep_prob = hidden_keep_prob
        self.clip_norm = clip_norm

        self.pretrained_bert = pretrained_bert
        self.bert_config_file = bert_config_file

        super().__init__(optimizer=optimizer,
                         optimizer_parameters=optimizer_parameters,
                         learning_rate_drop_patience=learning_rate_drop_patience,
                         learning_rate_drop_div=learning_rate_drop_div,
                         load_before_drop=load_before_drop,
                         min_learning_rate=min_learning_rate,
                         **kwargs)

    def train_on_batch(self,
                       input_ids: Union[List[List[int]], np.ndarray],
                       input_masks: Union[List[List[int]], np.ndarray],
                       y_masks: Union[List[List[int]], np.ndarray],
                       y: List[List[int]],
                       *args, **kwargs) -> Dict[str, float]:
        """Run one optimisation step on a batch.

        Args:
            input_ids: batch of indices of subwords
            input_masks: batch of masks which determine what should be attended
            y_masks: batch of masks marking the first subword of every token
            y: batch of token-level tag indices
            args: ignored; accepted for interface compatibility
            kwargs: ignored; accepted for interface compatibility

        Returns:
            dict with the single field 'loss' (the batch loss as a Python float)
        """
        # np.asarray is a no-op for ndarrays and also accepts the nested lists
        # allowed by the type hints (torch.from_numpy alone rejects lists).
        b_input_ids = torch.from_numpy(np.asarray(input_ids)).to(self.device)
        b_input_masks = torch.from_numpy(np.asarray(input_masks)).to(self.device)
        subtoken_labels = [token_labels_to_subtoken_labels(y_el, y_mask, input_mask)
                           for y_el, y_mask, input_mask in zip(y, y_masks, input_masks)]
        b_labels = torch.from_numpy(np.array(subtoken_labels)).to(torch.int64).to(self.device)

        self.optimizer.zero_grad()
        outputs = self.model(input_ids=b_input_ids,
                             token_type_ids=None,
                             attention_mask=b_input_masks,
                             labels=b_labels)
        # With labels given, the loss is the first output. Indexing (rather than
        # tuple unpacking) works for both tuple and ModelOutput return values.
        loss = outputs[0]
        loss.backward()
        # Optionally clip the gradient norm to `clip_norm` to counter exploding gradients.
        if self.clip_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm)

        self.optimizer.step()
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {'loss': loss.item()}

    def __call__(self,
                 input_ids: Union[List[List[int]], np.ndarray],
                 input_masks: Union[List[List[int]], np.ndarray],
                 y_masks: Union[List[List[int]], np.ndarray]) -> Union[List[List[int]], List[np.ndarray]]:
        """Predict tag indices for a given subword tokens batch.

        Args:
            input_ids: indices of the subwords
            input_masks: mask that determines where to attend and where not to
            y_masks: mask which determines the first subword units in the word

        Returns:
            Label indices or class probabilities for each token (not subtoken),
            truncated to the number of tokens in each sample.
        """
        b_input_ids = torch.from_numpy(np.asarray(input_ids)).to(self.device)
        b_input_masks = torch.from_numpy(np.asarray(input_masks)).to(self.device)
        y_masks = np.asarray(y_masks)

        with torch.no_grad():
            # Forward pass, calculate logit predictions
            logits = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_masks)

            # Move logits to CPU and map subtoken logits to token logits via y_masks
            logits = token_from_subtoken(logits[0].detach().cpu(), torch.from_numpy(y_masks))

        if self.return_probas:
            pred = torch.nn.functional.softmax(logits, dim=-1)
            pred = pred.detach().cpu().numpy()
        else:
            logits = logits.detach().cpu().numpy()
            pred = np.argmax(logits, axis=-1)
        seq_lengths = np.sum(y_masks, axis=1)
        pred = [p[:l] for l, p in zip(seq_lengths, pred)]

        return pred

    @overrides
    def load(self, fname=None):
        """Build the model and optimizer, then restore a checkpoint if one exists.

        Args:
            fname: optional path overriding `self.load_path`
                (``.pth.tar`` is substituted as the file suffix).

        Raises:
            ConfigError: if neither a pretrained model nor a config file is usable,
                or if the load path's parent directory does not exist.
        """
        if fname is not None:
            self.load_path = fname

        if self.pretrained_bert and not Path(self.pretrained_bert).is_file():
            self.model = BertForTokenClassification.from_pretrained(
                self.pretrained_bert, num_labels=self.n_classes,
                output_attentions=False, output_hidden_states=False)
        elif self.bert_config_file and Path(self.bert_config_file).is_file():
            self.bert_config = BertConfig.from_json_file(str(expand_path(self.bert_config_file)))
            if self.attention_probs_keep_prob is not None:
                self.bert_config.attention_probs_dropout_prob = 1.0 - self.attention_probs_keep_prob
            if self.hidden_keep_prob is not None:
                self.bert_config.hidden_dropout_prob = 1.0 - self.hidden_keep_prob
            self.model = BertForTokenClassification(config=self.bert_config)
        else:
            raise ConfigError("No pre-trained BERT model is given.")

        self.model.to(self.device)

        self.optimizer = getattr(torch.optim, self.optimizer_name)(
            self.model.parameters(), **self.optimizer_parameters)
        if self.lr_scheduler_name is not None:
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
                self.optimizer, **self.lr_scheduler_parameters)

        if self.load_path:
            log.info(f"Load path {self.load_path} is given.")
            # Accept str as well as Path: calling .resolve() on a str load path raised AttributeError.
            load_path = Path(self.load_path)
            if not load_path.parent.is_dir():
                raise ConfigError("Provided load path is incorrect!")

            weights_path = load_path.resolve().with_suffix(".pth.tar")
            if weights_path.exists():
                log.info(f"Load path {weights_path} exists.")
                log.info(f"Initializing `{self.__class__.__name__}` from saved.")

                # now load the weights, optimizer from saved
                log.info(f"Loading weights from {weights_path}.")
                checkpoint = torch.load(weights_path, map_location=self.device)
                self.model.load_state_dict(checkpoint["model_state_dict"])
                self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                self.epochs_done = checkpoint.get("epochs_done", 0)
            else:
                log.info(f"Init from scratch. Load path {weights_path} does not exist.")