Example no. 1

def create_and_check_for_token_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    config.num_labels = self.num_labels
    model = BertForTokenClassification(config=config)
    model.to(torch_device)
    model.eval()
    result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
    self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
Example no. 2
def create_and_check_bert_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask,
        sequence_labels, token_labels, choice_labels):
    config.num_labels = self.num_labels
    model = BertForTokenClassification(config=config)
    model.to(torch_device)
    model.eval()
    result = model(input_ids,
                   attention_mask=input_mask,
                   token_type_ids=token_type_ids,
                   labels=token_labels)
    self.parent.assertListEqual(
        list(result["logits"].size()),
        [self.batch_size, self.seq_length, self.num_labels])
    self.check_loss_output(result)
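Both test helpers above receive their config and tensors from a model tester's prepare_config_and_inputs and only assert output shapes. As a self-contained illustration of the same check, here is a minimal sketch with made-up sizes (all values below are illustrative assumptions, not taken from any test suite):

import torch
from transformers import BertConfig, BertForTokenClassification

# Tiny config so the sketch runs quickly on CPU; sizes are assumptions.
batch_size, seq_length, num_labels, vocab_size = 2, 7, 5, 99
config = BertConfig(vocab_size=vocab_size, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64, num_labels=num_labels)
model = BertForTokenClassification(config)
model.eval()

input_ids = torch.randint(0, vocab_size, (batch_size, seq_length))
input_mask = torch.ones_like(input_ids)
token_type_ids = torch.zeros_like(input_ids)
token_labels = torch.randint(0, num_labels, (batch_size, seq_length))

result = model(input_ids, attention_mask=input_mask,
               token_type_ids=token_type_ids, labels=token_labels)
assert result.logits.shape == (batch_size, seq_length, num_labels)
assert result.loss.dim() == 0  # scalar loss, as check_loss_output expects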
Example no. 3

    params.seed = args.seed

    test_sentences = load_test_sentences(args.bert_model_dir, args.test_file)

    # Specify the test set size
    params.test_size = len(test_sentences)
    params.eval_steps = params.test_size // params.batch_size

    # Define the model
    config_path = os.path.join(args.bert_model_dir, 'config.json')
    config = BertConfig.from_json_file(config_path)

    # Update the config with num_labels
    config.update({"num_labels": 2})
    model = BertForTokenClassification(config)
    # model = BertForTokenClassification(config, num_labels=2)

    model.to(params.device)
    # Reload weights from the saved file
    utils.load_checkpoint(
        os.path.join(args.model_dir, args.restore_file + '.pth.tar'), model)
    if args.fp16:
        model.half()
    if params.n_gpu > 1 and args.multi_gpu:
        model = torch.nn.DataParallel(model)

    predict(model=model,
            data_iterator=yield_data_batch(test_sentences, params),
            params=params,
            sentences_file=args.test_file)
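The fragment above depends on project helpers that are not shown (load_test_sentences, yield_data_batch, utils.load_checkpoint, predict). As a rough, assumed reconstruction, a predict of the following shape would be consistent with how it is invoked; the batch layout is a guess:

import torch

def predict(model, data_iterator, params, sentences_file):
    # Hypothetical sketch: the real project-specific predict() is not shown.
    # Assumes the iterator yields (input_ids, input_mask) tensor pairs already
    # placed on params.device.
    model.eval()
    predictions = []
    with torch.no_grad():
        for _ in range(params.eval_steps):
            input_ids, input_mask = next(data_iterator)
            logits = model(input_ids, attention_mask=input_mask)[0]
            predictions.extend(logits.argmax(dim=-1).tolist())
    return predictions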
Example no. 4
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the training files for the CoNLL-2003 NER task."
    )
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    # parser.add_argument("--output_dir", default=None, type=str, required=True,
    #                     help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument(
        "--labels",
        default="",
        type=str,
        help=
        "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
    )
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train",
                        action="store_true",
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action="store_true",
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_predict",
                        action="store_true",
                        help="Whether to run predictions on the test set.")
    parser.add_argument(
        "--evaluate_during_training",
        action="store_true",
        help="Whether to run evaluation during training at each logging step.")
    parser.add_argument(
        "--test_during_training",
        action="store_true",
        help="Whether to run test during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Ratio of linear warmup steps over all training steps")

    parser.add_argument("--logging_steps",
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument("--save_steps",
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument(
        "--test_all_checkpoints",
        action="store_true",
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action="store_true",
                        help="Avoid using CUDA when available")
    parser.add_argument("--overwrite_output_dir",
                        action="store_true",
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument("--server_ip",
                        type=str,
                        default="",
                        help="For distant debugging.")
    parser.add_argument("--server_port",
                        type=str,
                        default="",
                        help="For distant debugging.")
    parser.add_argument("--random_start", action="store_true")
    parser.add_argument("--yago_reference", action="store_true")
    parser.add_argument("--max_reference_num",
                        type=int,
                        default=10,
                        help="Number of yago types considered as references")
    parser.add_argument("--additional_output_tag",
                        type=str,
                        default="",
                        help="Additional tag to distinguish from other models")
    parser.add_argument(
        "--do_significant_check",
        action="store_true",
        help=
        "Whether to check if the influence of reference embedding is significant"
    )

    args = parser.parse_args()

    DEFAULT_DATA_REPO = '/work/smt2/qfeng/Project/huggingface/datasets/'
    DEFAULT_CACHE_REPO = '/work/smt2/qfeng/Project/huggingface/pretrain/'
    DEFAULT_OUTPUT_REPO = '/work/smt2/qfeng/Project/huggingface/models/'

    if '/' not in args.data_dir:
        args.data_dir = DEFAULT_DATA_REPO + args.data_dir
    if '/' not in args.cache_dir:
        if args.cache_dir == "":
            args.cache_dir = DEFAULT_CACHE_REPO + args.model_name_or_path[
                len('bert-'):]
        else:
            args.cache_dir = DEFAULT_CACHE_REPO + args.cache_dir
    if args.labels == "":
        if os.path.exists(os.path.join(args.data_dir, 'labels.txt')):
            args.labels = os.path.join(args.data_dir, 'labels.txt')
        else:
            raise ValueError("Invalid or missing labels file!")
    if '-uncased' in args.model_name_or_path or '_uncased' in args.model_name_or_path:
        args.do_lower_case = True
    elif '-cased' in args.model_name_or_path or '_cased' in args.model_name_or_path:
        args.do_lower_case = False

    # Name the output directory after the model, a time tag, and whether yago references are used
    output_dir = args.model_name_or_path.split('/')[-1]
    if args.yago_reference:
        output_dir += "_yagoref"
    if args.additional_output_tag != "":
        output_dir += '_' + args.additional_output_tag
    else:
        now_time = datetime.datetime.now()
        output_dir += '_' + '-'.join(
            str(i) for i in list(now_time.timetuple()[1:3]))  # 'month-day'

    args.output_dir = DEFAULT_OUTPUT_REPO + output_dir
    logger.info("Writing the model to {}".format(args.output_dir))

    if args.tokenizer_name == "":
        args.tokenizer_name = 'bert-base-uncased' if args.do_lower_case else 'bert-base-cased'

    if args.yago_reference:
        REFERENCE_SIZE = 959

    # if args.yago_reference:
    #     with open('/work/smt3/wwang/TAC2019/qihui_data/yago/YagoReference.pickle', 'rb') as ref_pickle: #TODO:
    #         ref_dict = pickle.load(ref_pickle)

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend, which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Scale step intervals by the GPU count; guard against n_gpu == 0 on CPU-only runs
    args.logging_steps = int(args.logging_steps / max(args.n_gpu, 1))
    args.save_steps = int(args.save_steps / max(args.n_gpu, 1))

    # Set seed
    set_seed(args)

    # Prepare the CoNLL-2003 task
    labels = get_labels(args.labels)
    num_labels = len(labels)
    # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
    pad_token_label_id = CrossEntropyLoss().ignore_index

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    # bertconfig = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
    #                                       num_labels=num_labels,
    #                                       cache_dir=args.cache_dir if args.cache_dir else None)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    if not args.yago_reference:
        config = config_class.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path,
            num_labels=num_labels,
            cache_dir=args.cache_dir if args.cache_dir else None)
        """
            Test the ablation of pretrained model
            """
        if args.random_start:
            model = BertForTokenClassification(config)
        else:
            model = model_class.from_pretrained(
                args.model_name_or_path,
                from_tf=bool(".ckpt" in args.model_name_or_path),
                config=config,
                cache_dir=args.cache_dir if args.cache_dir else None)
    else:
        # config = YagoRefBertConfig(bertconfig.__dict__, reference_size=REFERENCE_SIZE,
        #                            num_labels=num_labels,
        #                            cache_dir=args.cache_dir if args.cache_dir else None
        #                            )
        config = YagoRefBertConfig.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path,
            reference_size=REFERENCE_SIZE,
            num_labels=num_labels,
            cache_dir=args.cache_dir if args.cache_dir else None)
        logger.info("number of labels %d", config.num_labels)
        logger.info("vocab size: %d", config.vocab_size)
        model = YagoRefBertForTokenClassification.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            cache_dir=args.cache_dir if args.cache_dir else None)

    # model = model_class.from_pretrained(args.model_name_or_path,
    #                                     from_tf=bool(".ckpt" in args.model_name_or_path),
    #                                     config=config,
    #                                     cache_dir=args.cache_dir if args.cache_dir else None)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    if args.overwrite_cache:
        load_and_cache_examples(args,
                                tokenizer,
                                labels,
                                pad_token_label_id,
                                mode="train")
        load_and_cache_examples(args,
                                tokenizer,
                                labels,
                                pad_token_label_id,
                                mode="dev")
        load_and_cache_examples(args,
                                tokenizer,
                                labels,
                                pad_token_label_id,
                                mode="test")
        args.overwrite_cache = False

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                tokenizer,
                                                labels,
                                                pad_token_label_id,
                                                mode="train")
        global_step, tr_loss = train(args, train_dataset, model, tokenizer,
                                     labels, pad_token_label_id)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(
            model,
            "module") else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        # tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME,
                              recursive=True)))
            logging.getLogger("pytorch_transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(
                checkpoint.split("-")[-1]) > 2 else checkpoint
            if args.yago_reference:
                model = YagoRefBertForTokenClassification.from_pretrained(
                    checkpoint)
            else:
                model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            result, _ = evaluate(args,
                                 model,
                                 tokenizer,
                                 labels,
                                 pad_token_label_id,
                                 mode="dev",
                                 prefix=global_step)
            if global_step:
                result = {
                    "{}_{}".format(global_step, k): v
                    for k, v in result.items()
                }
            results.update(result)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            for key in sorted(results.keys()):
                writer.write("{} = {}\n".format(key, str(results[key])))

    if args.do_predict and args.local_rank in [-1, 0]:
        # tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.test_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME,
                              recursive=True)))
            logging.getLogger("pytorch_transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce logging
        logger.info("Test the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(
                checkpoint.split("-")[-1]) > 2 else checkpoint
            if args.yago_reference:
                model = YagoRefBertForTokenClassification.from_pretrained(
                    checkpoint)
            else:
                model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            result, _ = evaluate(args,
                                 model,
                                 tokenizer,
                                 labels,
                                 pad_token_label_id,
                                 mode="test",
                                 prefix=global_step)
            if global_step:
                result = {
                    "{}_{}".format(global_step, k): v
                    for k, v in result.items()
                }
            results.update(result)
        output_eval_file = os.path.join(args.output_dir, "test_results.txt")
        with open(output_eval_file, "w") as writer:
            for key in sorted(results.keys()):
                writer.write("{} = {}\n".format(key, str(results[key])))

        if args.yago_reference:
            model = YagoRefBertForTokenClassification.from_pretrained(
                args.output_dir, config=config)
        else:
            model = model_class.from_pretrained(args.output_dir)
        model.to(args.device)
        result, predictions = evaluate(args,
                                       model,
                                       tokenizer,
                                       labels,
                                       pad_token_label_id,
                                       mode="test")
        # Save predictions
        output_test_predictions_file = os.path.join(args.output_dir,
                                                    "test_predictions.txt")
        with open(output_test_predictions_file, "w") as writer:
            with open(os.path.join(args.data_dir, "test.txt"), "r") as f:
                example_id = 0
                for line in f:
                    if line.startswith(
                            "-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)
                        if not predictions[example_id]:
                            example_id += 1
                    elif predictions[example_id]:
                        output_line = line.split(
                        )[0] + " " + predictions[example_id].pop(0) + "\n"
                        writer.write(output_line)
                    else:
                        logger.warning(
                            "Maximum sequence length exceeded: No prediction for '%s'.",
                            line.split()[0])

    # Temporary code: check whether the values of reference_embedding are significant
    if args.do_significant_check and args.yago_reference:
        model = YagoRefBertForTokenClassification.from_pretrained(
            args.output_dir, config=config)
        bertconfig = config_class.from_pretrained(
            "/work/smt2/qfeng/Project/huggingface/models/base-cased_1-9/",
            num_labels=num_labels,
            cache_dir=args.cache_dir if args.cache_dir else None)
        model_noyago = BertForTokenClassification.from_pretrained(
            "/work/smt2/qfeng/Project/huggingface/models/base-cased_1-9/",
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=bertconfig,
            cache_dir=args.cache_dir if args.cache_dir else None)
        model.to(args.device)
        model.eval()
        model_noyago.to(args.device)
        model_noyago.eval()
        # logger.info(model.bert.embeddings.word_embeddings.weight.size())
        with open(
                '/work/smt3/wwang/TAC2019/qihui_data/yago/YagoReference_prune{}.pickle'
                .format("" if args.do_lower_case else "_cased"),
                'rb') as ref_pickle:  #TODO:
            ref_dict = pickle.load(ref_pickle)
        cos_sim = []
        cos_sim_ww = []
        ref_vec_norm = []
        cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
        for id in range(config.vocab_size):
            if id in ref_dict:
                reference_ids = model.bert.embeddings.ref_ids[id].to(device)
                reference_weights = model.bert.embeddings.ref_weights[id].to(
                    device)
                # logger.info(reference_ids.size())
                # logger.info(reference_weights.size())
                reference_embedding = torch.sum(
                    model.bert.reference_embeddings(reference_ids) *
                    torch.unsqueeze(reference_weights, dim=-1),
                    dim=-2)
                word_embedding = model.bert.embeddings.word_embeddings(
                    torch.tensor(id, dtype=torch.long, device=args.device))
                # ref_id_list = torch.tensor(list(ref_dict[id].keys()), dtype=torch.long, device=args.device)
                # ref_id_weight = torch.tensor(list(ref_dict[id].values()), dtype=torch.float, device=args.device)
                # reference_embeddings = model.bert.embeddings.reference_embeddings(ref_id_list)*torch.unsqueeze(ref_id_weight, dim=-1)
                # # logger.info(reference_embeddings.size())
                # reference_embedding = torch.sum(reference_embeddings,dim=-2)
                vec_norm = torch.norm(reference_embedding)
                ref_vec_norm.append(vec_norm)
                noyago_word_embedding = model_noyago.bert.embeddings.word_embeddings(
                    torch.tensor(id, dtype=torch.long, device=args.device))
                # logger.info(vec_norm/torch.norm(word_embedding))
                cos_sim.append(cos(word_embedding, reference_embedding))
                cos_sim_ww.append(cos(word_embedding, noyago_word_embedding))
                logger.info(cos(word_embedding, reference_embedding))
                assert (word_embedding.size() == reference_embedding.size())
        # word_embedding_norm = torch.norm(model.bert.embeddings.word_embeddings.weight, p=2, dim=1)
        # reference_embedding_norm = torch.norm(model.bert.embeddings.reference_embeddings.weight, p=2, dim=1)
        avg_sim = sum(cos_sim) / len(cos_sim)
        avg_sim_ww = sum(cos_sim_ww) / len(cos_sim_ww)
        avg_ratio = sum(ref_vec_norm) / len(ref_vec_norm)
        logger.info(avg_sim)
        logger.info(avg_sim_ww)
        logger.info(avg_ratio)
    return results
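main() relies on helpers defined elsewhere (get_labels, set_seed, load_and_cache_examples, train, evaluate). The first two are small enough to sketch; these are plausible implementations consistent with how the script uses them (the script checks that labels.txt exists, so no CoNLL-2003 fallback is needed here), not the project's actual code:

import random
import numpy as np
import torch

def get_labels(path):
    # One label per line; main() guarantees the file exists before this is called.
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]

def set_seed(args):
    # Seed every RNG the run touches, mirroring the Hugging Face NER example.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)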
Example no. 5

genes, labels = read_non_split_file(
    '/home/brian/Downloads/all_samples_6-mer_train.txt')
seq_ids, masks, labels = tokenize_and_pad_samples(genes, labels)
print(seq_ids[0])
print(len(seq_ids))
print("Finished making data")

batch_size = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BertForTokenClassification(
    BertConfig.from_json_file(
        '/home/brian/attentive_splice/bert_configuration_all_hex.json'))
model.resize_token_embeddings(4099)
model.to(device)
optimizer = Adam(model.parameters(), lr=1e-3)  # alternative: lr=3e-5
class_weights = torch.tensor(np.array([1.0, 165.0])).float().to(device)  # .to(device) keeps CPU-only runs working
loss = CrossEntropyLoss(weight=class_weights)
last_i = 0


def load_model_from_saved():
    global last_i  # update the module-level resume index, not a function-local copy
    with open('/home/brian/bert_last_i.txt', 'r') as last_i_file:
        i = last_i_file.read()
        last_i = int(i)
        model.load_state_dict(torch.load("/home/brian/bert_splice_weights.pt"))


def save_weights():
    print("Saving weights")
    # Assumed counterpart to load_model_from_saved above
    torch.save(model.state_dict(), "/home/brian/bert_splice_weights.pt")
Example no. 6

class TorchBertSequenceTagger(TorchModel):
    """BERT-based model on PyTorch for text tagging. It predicts a label for every token (not subtoken) in the text.
    You can use it for sequence labeling tasks, such as morphological tagging or named entity recognition.

    Args:
        n_tags: number of distinct tags
        pretrained_bert: pretrained Bert checkpoint path or key title (e.g. "bert-base-uncased")
        return_probas: set this to `True` if you need the probabilities instead of raw answers
        bert_config_file: path to Bert configuration file, or None, if `pretrained_bert` is a string name
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
        hidden_keep_prob: keep_prob for Bert hidden layers
        optimizer: optimizer name from `torch.optim`
        optimizer_parameters: dictionary with optimizer's parameters,
                              e.g. {'lr': 0.1, 'weight_decay': 0.001, 'momentum': 0.9}
        learning_rate_drop_patience: how many validations with no improvements to wait
        learning_rate_drop_div: the divider of the learning rate after `learning_rate_drop_patience` unsuccessful
            validations
        load_before_drop: whether to load best model before dropping learning rate or not
        clip_norm: clip gradients by norm
        min_learning_rate: min value of learning rate if learning rate decay is used
    """

    def __init__(self,
                 n_tags: int,
                 pretrained_bert: str,
                 bert_config_file: Optional[str] = None,
                 return_probas: bool = False,
                 attention_probs_keep_prob: Optional[float] = None,
                 hidden_keep_prob: Optional[float] = None,
                 optimizer: str = "AdamW",
                 optimizer_parameters: dict = {"lr": 1e-3, "weight_decay": 1e-6},
                 learning_rate_drop_patience: int = 20,
                 learning_rate_drop_div: float = 2.0,
                 load_before_drop: bool = True,
                 clip_norm: Optional[float] = None,
                 min_learning_rate: float = 1e-07,
                 **kwargs) -> None:

        self.n_classes = n_tags
        self.return_probas = return_probas
        self.attention_probs_keep_prob = attention_probs_keep_prob
        self.hidden_keep_prob = hidden_keep_prob
        self.clip_norm = clip_norm

        self.pretrained_bert = pretrained_bert
        self.bert_config_file = bert_config_file

        super().__init__(optimizer=optimizer,
                         optimizer_parameters=optimizer_parameters,
                         learning_rate_drop_patience=learning_rate_drop_patience,
                         learning_rate_drop_div=learning_rate_drop_div,
                         load_before_drop=load_before_drop,
                         min_learning_rate=min_learning_rate,
                         **kwargs)

    def train_on_batch(self,
                       input_ids: Union[List[List[int]], np.ndarray],
                       input_masks: Union[List[List[int]], np.ndarray],
                       y_masks: Union[List[List[int]], np.ndarray],
                       y: List[List[int]],
                       *args, **kwargs) -> Dict[str, float]:
        """

        Args:
            input_ids: batch of indices of subwords
            input_masks: batch of masks which determine what should be attended
            args: arguments passed  to _build_feed_dict
                and corresponding to additional input
                and output tensors of the derived class.
            kwargs: keyword arguments passed to _build_feed_dict
                and corresponding to additional input
                and output tensors of the derived class.

        Returns:
            dict with fields 'loss', 'head_learning_rate', and 'bert_learning_rate'
        """
        b_input_ids = torch.from_numpy(input_ids).to(self.device)
        b_input_masks = torch.from_numpy(input_masks).to(self.device)
        subtoken_labels = [token_labels_to_subtoken_labels(y_el, y_mask, input_mask)
                           for y_el, y_mask, input_mask in zip(y, y_masks, input_masks)]
        b_labels = torch.from_numpy(np.array(subtoken_labels)).to(torch.int64).to(self.device)
        self.optimizer.zero_grad()

        loss, logits = self.model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_masks,
                                  labels=b_labels)
        loss.backward()
        # Clip the norm of the gradients to 1.0.
        # This is to help prevent the "exploding gradients" problem.
        if self.clip_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm)

        self.optimizer.step()
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {'loss': loss.item()}

    def __call__(self,
                 input_ids: Union[List[List[int]], np.ndarray],
                 input_masks: Union[List[List[int]], np.ndarray],
                 y_masks: Union[List[List[int]], np.ndarray]) -> Union[List[List[int]], List[np.ndarray]]:
        """ Predicts tag indices for a given subword tokens batch

        Args:
            input_ids: indices of the subwords
            input_masks: mask that determines where to attend and where not to
            y_masks: mask which determines the first subword units in the the word

        Returns:
            Label indices or class probabilities for each token (not subtoken)

        """
        b_input_ids = torch.from_numpy(input_ids).to(self.device)
        b_input_masks = torch.from_numpy(input_masks).to(self.device)

        with torch.no_grad():
            # Forward pass, calculate logit predictions
            logits = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_masks)

            # Move logits and labels to CPU and to numpy arrays
            logits = token_from_subtoken(logits[0].detach().cpu(), torch.from_numpy(y_masks))

        if self.return_probas:
            pred = torch.nn.functional.softmax(logits, dim=-1)
            pred = pred.detach().cpu().numpy()
        else:
            logits = logits.detach().cpu().numpy()
            pred = np.argmax(logits, axis=-1)
            seq_lengths = np.sum(y_masks, axis=1)
            pred = [p[:l] for l, p in zip(seq_lengths, pred)]

        return pred

    @overrides
    def load(self, fname=None):
        if fname is not None:
            self.load_path = fname

        if self.pretrained_bert and not Path(self.pretrained_bert).is_file():
            self.model = BertForTokenClassification.from_pretrained(
                self.pretrained_bert, num_labels=self.n_classes,
                output_attentions=False, output_hidden_states=False)
        elif self.bert_config_file and Path(self.bert_config_file).is_file():
            self.bert_config = BertConfig.from_json_file(str(expand_path(self.bert_config_file)))

            if self.attention_probs_keep_prob is not None:
                self.bert_config.attention_probs_dropout_prob = 1.0 - self.attention_probs_keep_prob
            if self.hidden_keep_prob is not None:
                self.bert_config.hidden_dropout_prob = 1.0 - self.hidden_keep_prob
            self.model = BertForTokenClassification(config=self.bert_config)
        else:
            raise ConfigError("No pre-trained BERT model is given.")

        self.model.to(self.device)
        
        self.optimizer = getattr(torch.optim, self.optimizer_name)(
            self.model.parameters(), **self.optimizer_parameters)
        if self.lr_scheduler_name is not None:
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
                self.optimizer, **self.lr_scheduler_parameters)

        if self.load_path:
            log.info(f"Load path {self.load_path} is given.")
            if isinstance(self.load_path, Path) and not self.load_path.parent.is_dir():
                raise ConfigError("Provided load path is incorrect!")

            weights_path = Path(self.load_path.resolve())
            weights_path = weights_path.with_suffix(".pth.tar")
            if weights_path.exists():
                log.info(f"Load path {weights_path} exists.")
                log.info(f"Initializing `{self.__class__.__name__}` from saved.")

                # now load the weights, optimizer from saved
                log.info(f"Loading weights from {weights_path}.")
                checkpoint = torch.load(weights_path, map_location=self.device)
                self.model.load_state_dict(checkpoint["model_state_dict"])
                self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                self.epochs_done = checkpoint.get("epochs_done", 0)
            else:
                log.info(f"Init from scratch. Load path {weights_path} does not exist.")