def build_examples(file_path, max_seq_len, masked_lm_prob, max_predictions_per_seq, vocab_list):
    """Build masked-LM pretraining examples from a whitespace-tokenized corpus.

    Args:
        file_path: text file with one space-separated token sequence per line.
        max_seq_len: maximum total length including the [CLS] and [SEP] markers.
        masked_lm_prob: per-token masking probability forwarded to
            create_masked_lm_predictions.
        max_predictions_per_seq: cap on masked positions per sequence.
        vocab_list: vocabulary used for random token replacement during masking.

    Returns:
        List of dicts with keys: guid, tokens, segment_ids,
        masked_lm_positions, masked_lm_labels.
    """
    examples = []
    max_num_tokens = max_seq_len - 2  # reserve room for [CLS] and [SEP]
    # Stream the file line by line (readlines() loaded the whole corpus into
    # memory at once) and use a context manager so the handle is closed on
    # every exit path, not just the happy one.
    with open(file_path, 'r') as f:
        for line_cnt, line in enumerate(f):
            if line_cnt % 50000 == 0:
                logger.info(f"Loading line {line_cnt}")
            guid = f'corpus-{line_cnt}'
            tokens_a = line.strip("\n").split(" ")[:max_num_tokens]
            # Drop too-short samples before building the full token sequence
            # (the original filtered after, wasting work on discarded lines).
            if len(tokens_a) < 5:
                continue
            tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
            segment_ids = [0] * (len(tokens_a) + 2)  # single-segment input
            tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(
                tokens, masked_lm_prob, max_predictions_per_seq, vocab_list)
            if line_cnt < 2:  # print a couple of samples for sanity checking
                print("-------------------------Example-----------------------")
                print("guid: %s" % (guid))
                print("tokens: %s" % " ".join([str(x) for x in tokens]))
                print("masked_lm_labels: %s" % " ".join([str(x) for x in masked_lm_labels]))
                print("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
                print("masked_lm_positions: %s" % " ".join([str(x) for x in masked_lm_positions]))
            examples.append({'guid': guid,
                             'tokens': tokens,
                             'segment_ids': segment_ids,
                             'masked_lm_positions': masked_lm_positions,
                             'masked_lm_labels': masked_lm_labels})
    return examples
# ----- Esempio n. 2 (score: 0) — scraped example separator, commented out so it is not parsed as code -----
 def __init__(self, training_path, file_id, tokenizer, reduce_memory=False):
     """Load one pregenerated masked-LM training shard into feature arrays.

     Args:
         training_path: directory (Path) containing file_{id}.json and
             file_{id}_metrics.json produced by the pregeneration step.
         file_id: shard index.
         tokenizer: tokenizer forwarded to convert_example_to_features.
         reduce_memory: if True, back the feature arrays with on-disk
             memmaps inside a TemporaryDirectory instead of holding them
             fully in RAM.
     """
     self.tokenizer = tokenizer
     self.file_id = file_id
     data_file = training_path / f"file_{self.file_id}.json"
     metrics_file = training_path / f"file_{self.file_id}_metrics.json"
     assert data_file.is_file() and metrics_file.is_file()
     metrics = json.loads(metrics_file.read_text())
     num_samples = metrics['num_training_examples']
     seq_len = metrics['max_seq_len']
     self.temp_dir = None
     self.working_dir = None
     if reduce_memory:
         # Keep the TemporaryDirectory object alive on self so the backing
         # files are not deleted while the memmaps are in use.
         self.temp_dir = TemporaryDirectory()
         self.working_dir = Path(self.temp_dir.name)
         input_ids = np.memmap(filename=self.working_dir / 'input_ids.memmap',
                               mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
         # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; it was
         # only an alias for the builtin `bool`, which is the drop-in fix.
         input_masks = np.memmap(filename=self.working_dir / 'input_masks.memmap',
                                 shape=(num_samples, seq_len), mode='w+', dtype=bool)
         segment_ids = np.memmap(filename=self.working_dir / 'segment_ids.memmap',
                                 shape=(num_samples, seq_len), mode='w+', dtype=bool)
         lm_label_ids = np.memmap(filename=self.working_dir / 'lm_label_ids.memmap',
                                  shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
         lm_label_ids[:] = -1  # -1 marks positions that carry no LM label
     else:
         input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
         input_masks = np.zeros(shape=(num_samples, seq_len), dtype=bool)
         segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=bool)
         lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
     logger.info(f"Loading training examples for {str(data_file)}")
     # Initialize i so an empty shard fails the count assert below with a
     # clear AssertionError instead of a NameError on `i`.
     i = -1
     with data_file.open() as f:
         for i, line in enumerate(f):
             example = json.loads(line.strip())
             features = convert_example_to_features(example, tokenizer, seq_len)
             input_ids[i] = features.input_ids
             segment_ids[i] = features.segment_ids
             input_masks[i] = features.input_mask
             lm_label_ids[i] = features.lm_label_ids
     assert i == num_samples - 1  # Assert that the sample count metric was true
     logger.info("Loading complete!")
     self.num_samples = num_samples
     self.seq_len = seq_len
     self.input_ids = input_ids
     self.input_masks = input_masks
     self.segment_ids = segment_ids
     self.lm_label_ids = lm_label_ids
# ----- Esempio n. 3 (score: 0) — scraped example separator, commented out so it is not parsed as code -----
def data_aug1(data):
    """Augment samples that contain a 'c' entity together with an 'a' or 'b'
    entity by dropping every 'c'-tagged token (and its tag).

    Returns a new list of records; the input list is left untouched.  The
    first few augmented samples are logged for inspection.
    """
    augmented = []
    logged = 0
    for sample in data:
        entity_types = {t.split("-")[1] for t in sample['tag'].split(" ") if "-" in t}
        # Guard clause: only samples mixing 'c' with 'a' or 'b' are augmented.
        if 'c' not in entity_types or not ({'a', 'b'} & entity_types):
            continue
        kept = [(tok, tag)
                for tok, tag in zip(sample['context'].split(" "),
                                    sample['tag'].split(" "))
                if 'c' not in tag]
        new_context = " ".join(tok for tok, _ in kept)
        new_tag = " ".join(tag for _, tag in kept)
        if logged <= 5:
            logger.info("--------- data aug1 -----------")
            logger.info(f"raw: {sample['context']}")
            logger.info(f'new: {new_context}')
            logger.info(f"raw_tag: {sample['tag']}")
            logger.info(f'tag: {new_tag}')
            logged += 1
        augmented.append({"context": new_context,
                          "tag": new_tag,
                          'id': sample['id'],
                          'raw_context': sample['raw_context']})
    logger.info(f"data aug size: {len(augmented)}")
    return augmented
# ----- Esempio n. 4 (score: 0) — scraped example separator, commented out so it is not parsed as code -----
def make_folds(args):
    """Convert train.txt into BIO/S-tagged json records and write stratified
    K-fold train/valid pickles (optionally with data augmentation).

    Each line of train.txt is a double-space-separated list of segments; a
    segment is '_'-joined tokens followed by '/label'.  Every sample gets a
    stratification class `y` derived from the set of entity types it contains.

    Args:
        args: namespace with folds, seed, do_aug and data_name attributes.
    """
    # Entity-type combination -> stratification class.  An explicit mapping
    # replaces the original if/elif ladder, which left `y` unbound (stale
    # value from the previous line, or UnboundLocalError) for any unexpected
    # combination — e.g. a 2-element set that is not ab/ac/bc.
    label_classes = {
        frozenset(): 0,
        frozenset('ab'): 1,
        frozenset('ac'): 2,
        frozenset('bc'): 3,
        frozenset('abc'): 4,
        frozenset('a'): 5,
        frozenset('b'): 6,
        frozenset('c'): 7,
    }
    train = []
    train_path = config['data_dir'] / 'train.txt'
    with open(str(train_path), 'r') as fr:
        for idx, line in enumerate(fr):
            line = line.strip("\n")
            context = []
            tags = []
            for seg in line.split("  "):
                segs = seg.split("/")
                seg_text = segs[0].split("_")
                seg_label = segs[1]
                context.extend(seg_text)
                if seg_label == 'o':
                    tags.extend(["O"] * len(seg_text))
                elif len(seg_text) == 1:
                    tags.append(f"S-{seg_label}")
                else:
                    tags.append(f"B-{seg_label}")
                    tags.extend([f"I-{seg_label}"] * (len(seg_text) - 1))
            la = frozenset(x.split("-")[1] for x in tags if '-' in x)
            if la not in label_classes:
                raise ValueError("tag is error")
            json_d = {'id': idx,
                      'context': " ".join(context),
                      'tag': " ".join(tags),
                      'raw_context': line,
                      'y': label_classes[la]}
            train.append(json_d)

    y_counter = Counter(x['y'] for x in train)
    print(y_counter)
    X = train
    y = [d['y'] for d in train]
    sss = StratifiedKFold(n_splits=args.folds, random_state=args.seed, shuffle=True)
    for fold, (train_index, test_index) in enumerate(sss.split(X, y)):
        logger.info(f'fold-{fold} info:')
        logger.info(f'raw train data size: {len(train_index)}')
        logger.info(f'raw valid data size: {len(test_index)}')
        X_train = [X[i] for i in train_index]
        if args.do_aug:
            # All three augmentations must be computed from the ORIGINAL
            # X_train before any extend, otherwise later augmentations would
            # also see the augmented samples.
            new_data1 = data_aug1(X_train)
            new_data2 = data_aug2(X_train)
            new_data3 = data_aug3(X_train)
            X_train.extend(new_data1)
            X_train.extend(new_data2)
            X_train.extend(new_data3)
            logger.info(f"After data augmentation, train data size: {len(X_train)}")
        X_test = [X[i] for i in test_index]
        train_file_name = f'{args.data_name}_train_fold_{fold}.pkl'
        dev_file_name = f'{args.data_name}_valid_fold_{fold}.pkl'
        save_pickle(X_train, file_path=config['data_dir'] / train_file_name)
        save_pickle(X_test, file_path=config['data_dir'] / dev_file_name)
def main():
    """Parse CLI arguments, set up paths/device/logging, and run train/test."""
    parser = ArgumentParser()
    parser.add_argument("--arch", default='bert_lstm_span', type=str)
    parser.add_argument("--do_train", action='store_true')
    parser.add_argument("--do_test", action='store_true')
    parser.add_argument("--save_best", action='store_true')
    parser.add_argument("--do_lower_case", action='store_true')
    parser.add_argument('--soft_label', action='store_true')
    parser.add_argument('--data_name', default='datagrand', type=str)
    parser.add_argument('--optimizer',
                        default='adam',
                        type=str,
                        choices=['adam', 'lookahead'])
    parser.add_argument('--markup',
                        default='bios',
                        type=str,
                        choices=['bio', 'bios'])
    parser.add_argument('--checkpoint', default=900000, type=int)
    parser.add_argument('--fold', default=0, type=int)
    # Was `default=50.0` with `type=int`: argparse applies `type` only to
    # command-line strings, so the default silently stayed a float.
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--resume_path", default='', type=str)
    parser.add_argument("--mode", default='max', type=str)
    parser.add_argument("--monitor", default='valid_f1', type=str)
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--sorted",
                        default=1,
                        type=int,
                        help='1 : True  0:False ')
    parser.add_argument("--n_gpu",
                        type=str,
                        default='0',
                        help='"0,1,.." or "0" or "" ')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument("--train_batch_size", default=24, type=int)
    parser.add_argument('--eval_batch_size', default=48, type=int)
    parser.add_argument("--train_max_seq_len", default=128, type=int)
    parser.add_argument("--eval_max_seq_len", default=512, type=int)
    parser.add_argument('--loss_scale', type=float, default=0)
    parser.add_argument("--warmup_proportion", default=0.1, type=float)
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float)
    parser.add_argument("--grad_clip", default=5.0, type=float)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument("--no_cuda", action='store_true')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--fp16_opt_level', type=str, default='O1')
    args = parser.parse_args()

    # Derived settings: pretrained LM checkpoint, device, and a per-fold
    # architecture name used for checkpoints and logs.
    args.pretrain_model = config[
        'checkpoint_dir'] / f'lm-checkpoint-{args.checkpoint}'
    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.arch = args.arch + f"_{args.markup}_fold_{args.fold}"
    if args.optimizer == 'lookahead':
        args.arch += "_lah"
    args.model_path = config['checkpoint_dir'] / args.arch
    args.model_path.mkdir(exist_ok=True)
    # Good practice: save your training arguments together with the trained model
    torch.save(args, config['checkpoint_dir'] / 'training_args.bin')
    seed_everything(args.seed)
    init_logger(log_file=config['log_dir'] / f"{args.arch}.log")
    logger.info("Training/evaluation parameters %s", args)

    if args.do_train:
        run_train(args)

    if args.do_test:
        run_test(args)
def run_train(args):
    """Train the BERT+LSTM span model for one fold.

    Builds train/valid datasets from the fold pickles, creates per-module
    optimizer parameter groups (BERT at args.learning_rate, LSTM and the two
    span heads at a fixed 5e-4), and runs the Trainer loop with plateau LR
    scheduling, checkpointing and optional fp16.
    """
    processor = BertProcessor(vocab_path=args.pretrain_model / 'vocab.txt',
                              do_lower_case=args.do_lower_case)
    processor.tokenizer.save_vocabulary(str(args.model_path))
    label_list = processor.get_labels()
    label2id = {label: i for i, label in enumerate(label_list)}

    # NOTE(review): the cached_file names below do not include args.fold, so
    # all folds appear to share one cache file — confirm create_examples /
    # create_features rebuild or invalidate the cache per fold.
    train_data = processor.get_train(
        config['data_dir'] / f"{args.data_name}_train_fold_{args.fold}.pkl")
    train_examples = processor.create_examples(lines=train_data,
                                               example_type='train',
                                               cached_file=config['data_dir'] /
                                               "cached_train_span_examples")
    train_features = processor.create_features(
        examples=train_examples,
        max_seq_len=args.train_max_seq_len,
        cached_file=config['data_dir'] /
        "cached_train_span_features_{}".format(args.train_max_seq_len))
    train_dataset = processor.create_dataset(train_features,
                                             is_sorted=args.sorted)
    # Sorted data must be consumed sequentially; otherwise shuffle per epoch.
    if args.sorted:
        train_sampler = SequentialSampler(train_dataset)
    else:
        train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    valid_data = processor.get_dev(
        config['data_dir'] / f'{args.data_name}_valid_fold_{args.fold}.pkl')
    valid_examples = processor.create_examples(lines=valid_data,
                                               example_type='valid',
                                               cached_file=config['data_dir'] /
                                               "cached_valid_span_examples")
    valid_features = processor.create_features(
        examples=valid_examples,
        max_seq_len=args.eval_max_seq_len,
        cached_file=config['data_dir'] /
        "cached_valid_span_features_{}".format(args.eval_max_seq_len))

    logger.info("initializing model")
    if args.resume_path:
        args.resume_path = Path(args.resume_path)
        model = BERTLSTMSpan.from_pretrained(args.resume_path,
                                             label2id=label2id,
                                             soft_label=args.soft_label)
    else:
        model = BERTLSTMSpan.from_pretrained(args.pretrain_model,
                                             label2id=label2id,
                                             soft_label=args.soft_label)
    model = model.to(args.device)
    t_total = int(
        len(train_dataloader) / args.gradient_accumulation_steps * args.epochs)

    def _param_groups(named_params, lr):
        """Split named parameters into weight-decay / no-decay groups at `lr`.

        Replaces four copies of the same hand-written group pair; order is
        preserved (decay group first, then no-decay), matching the original.
        """
        named_params = list(named_params)  # may be a generator; consumed twice
        no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
        with_decay = [p for n, p in named_params
                      if not any(nd in n for nd in no_decay)]
        without_decay = [p for n, p in named_params
                         if any(nd in n for nd in no_decay)]
        return [{'params': with_decay, 'weight_decay': 0.01, 'lr': lr},
                {'params': without_decay, 'weight_decay': 0.0, 'lr': lr}]

    optimizer_grouped_parameters = (
        _param_groups(model.bert.named_parameters(), args.learning_rate) +
        _param_groups(model.bilstm.named_parameters(), 0.0005) +
        _param_groups(model.start_fc.named_parameters(), 0.0005) +
        _param_groups(model.end_fc.named_parameters(), 0.0005))

    if args.optimizer == 'adam':
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)
    else:
        # Lookahead wraps a base BertAdam; k/alpha are the paper defaults.
        base_optimizer = BertAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  warmup=args.warmup_proportion,
                                  t_total=t_total)
        optimizer = Lookahead(base_optimizer, k=5, alpha=0.5)

    lr_scheduler = BERTReduceLROnPlateau(optimizer,
                                         lr=args.learning_rate,
                                         mode=args.mode,
                                         factor=0.5,
                                         patience=5,
                                         verbose=1,
                                         epsilon=1e-8,
                                         cooldown=0,
                                         min_lr=0,
                                         eps=1e-8)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    logger.info("initializing callbacks")
    train_monitor = TrainingMonitor(file_dir=config['figure_dir'],
                                    arch=args.arch)
    model_checkpoint = ModelCheckpoint(checkpoint_dir=args.model_path,
                                       mode=args.mode,
                                       monitor=args.monitor,
                                       arch=args.arch,
                                       save_best_only=args.save_best)

    # **************************** training model ***********************
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Num Epochs = %d", args.epochs)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    trainer = Trainer(
        n_gpu=args.n_gpu,
        model=model,
        logger=logger,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        label2id=label2id,
        criterion=SpanLoss(),
        training_monitor=train_monitor,
        fp16=args.fp16,
        resume_path=args.resume_path,
        grad_clip=args.grad_clip,
        model_checkpoint=model_checkpoint,
        gradient_accumulation_steps=args.gradient_accumulation_steps)
    trainer.train(train_data=train_dataloader,
                  valid_data=valid_features,
                  epochs=args.epochs,
                  seed=args.seed)
def main():
    """Build the LM pretraining corpus, vocab, shard files and masked-LM data.

    Stages (each behind a flag): --do_corpus collects and dedups raw text,
    --do_split shards the corpus with `split`, --do_vocab builds and saves the
    vocabulary, --do_data pregenerates dynamic-masking training files.
    """
    parser = ArgumentParser()
    parser.add_argument("--do_data", default=False, action='store_true')
    parser.add_argument("--do_corpus", default=False, action='store_true')
    parser.add_argument("--do_vocab", default=False, action='store_true')
    parser.add_argument("--do_split", default=False, action='store_true')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--min_freq', default=0, type=int)
    parser.add_argument("--line_per_file", default=1000000000, type=int)
    parser.add_argument("--file_num",
                        type=int,
                        default=10,
                        help="Number of dynamic masking to pregenerate")
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument(
        "--short_seq_prob",
        type=float,
        default=0.1,
        help="Probability of making a short sentence as a training example")
    parser.add_argument(
        "--masked_lm_prob",
        type=float,
        default=0.15,
        help="Probability of masking each token for the LM task")
    parser.add_argument(
        "--max_predictions_per_seq",
        type=int,
        default=20,
        help="Maximum number of tokens to mask in each sequence")
    args = parser.parse_args()
    seed_everything(args.seed)
    vocab = Vocabulary(min_freq=args.min_freq, add_unused=False)
    if args.do_corpus:
        corpus = []
        # train.txt: segments of '_'-joined tokens with a '/label' suffix.
        train_path = str(config['data_dir'] / 'train.txt')
        with open(train_path, 'r') as fr:
            for ex_id, line in enumerate(fr):
                line = line.strip("\n")
                lines = [
                    " ".join(x.split("/")[0].split("_"))
                    for x in line.split("  ")
                ]
                if ex_id == 0:
                    logger.info(f"Train example: {' '.join(lines)}")
                corpus.append(" ".join(lines))
        # test.txt and corpus.txt: plain '_'-joined tokens, no labels.
        test_path = str(config['data_dir'] / 'test.txt')
        with open(test_path, 'r') as fr:
            for ex_id, line in enumerate(fr):
                line = line.strip("\n")
                lines = line.split("_")
                if ex_id == 0:
                    logger.info(f"Test example: {' '.join(lines)}")
                corpus.append(" ".join(lines))
        corpus_path = str(config['data_dir'] / 'corpus.txt')
        with open(corpus_path, 'r') as fr:
            for ex_id, line in enumerate(fr):
                line = line.strip("\n")
                lines = line.split("_")
                if ex_id == 0:
                    logger.info(f"Corpus example: {' '.join(lines)}")
                corpus.append(" ".join(lines))
        # Dedup, then shuffle so shard contents are not ordered by source.
        corpus = list(set(corpus))
        logger.info(f"corpus size: {len(corpus)}")
        random_order = list(range(len(corpus)))
        np.random.shuffle(random_order)
        corpus = [corpus[i] for i in random_order]
        new_corpus_path = config['data_dir'] / "corpus/corpus.txt"
        # Always ensure the parent directory exists.  The original only ran
        # mkdir when the corpus FILE was missing — the wrong path's existence.
        new_corpus_path.parent.mkdir(parents=True, exist_ok=True)
        with open(new_corpus_path, 'w') as fw:  # write handle (was misnamed `fr`)
            for line in corpus:
                fw.write(line + "\n")

    if args.do_split:
        new_corpus_path = config['data_dir'] / "corpus/corpus.txt"
        split_save_path = config['data_dir'] / "corpus/train"
        split_save_path.mkdir(parents=True, exist_ok=True)
        line_per_file = args.line_per_file
        # NOTE(review): shell command built from config paths — fine for
        # trusted local paths, but breaks on paths containing spaces;
        # consider subprocess.run with an argument list.
        command = f'split -a 4 -l {line_per_file} -d {new_corpus_path} {split_save_path}/shard_'
        os.system(f"{command}")

    if args.do_vocab:
        vocab.read_data(data_path=config['data_dir'] / "corpus/train")
        vocab.build_vocab()
        vocab.save(file_path=config['data_dir'] / 'corpus/vocab_mapping.pkl')
        vocab.save_bert_vocab(file_path=config['checkpoint_dir'] / 'vocab.txt')
        logger.info(f"vocab size: {len(vocab)}")
        bert_base_config['vocab_size'] = len(vocab)
        save_json(data=bert_base_config,
                  file_path=config['checkpoint_dir'] / 'config.json')

    if args.do_data:
        vocab_list = vocab.load_bert_vocab(config['checkpoint_dir'] /
                                           'vocab.txt')
        data_path = config['data_dir'] / "corpus/train"
        # Shard files have no extension; skip anything with a "." in the path.
        files = sorted([
            f for f in data_path.iterdir() if f.exists() and "." not in str(f)
        ])
        logger.info("--- pregenerate training data parameters ---")
        logger.info(f'max_seq_len: {args.max_seq_len}')
        logger.info(f"max_predictions_per_seq: {args.max_predictions_per_seq}")
        logger.info(f"masked_lm_prob: {args.masked_lm_prob}")
        logger.info(f"seed: {args.seed}")
        logger.info(f"file num : {args.file_num}")
        for idx in range(args.file_num):
            logger.info(f"pregenetate file_{idx}.json")
            save_filename = data_path / f"file_{idx}.json"
            num_instances = 0
            with save_filename.open('w') as fw:
                for file_idx in range(len(files)):
                    file_path = files[file_idx]
                    file_examples = build_examples(
                        file_path,
                        max_seq_len=args.max_seq_len,
                        masked_lm_prob=args.masked_lm_prob,
                        max_predictions_per_seq=args.max_predictions_per_seq,
                        vocab_list=vocab_list)
                    for instance in file_examples:
                        fw.write(json.dumps(instance) + '\n')
                        num_instances += 1
            metrics_file = data_path / f"file_{idx}_metrics.json"
            print(f"num_instances: {num_instances}")
            # Use a distinct handle name: the original `as metrics_file`
            # shadowed the Path object with the open file handle.
            with metrics_file.open('w') as metrics_fp:
                metrics = {
                    "num_training_examples": num_instances,
                    "max_seq_len": args.max_seq_len
                }
                metrics_fp.write(json.dumps(metrics))
# ----- Esempio n. 8 (score: 0) — scraped example separator, commented out so it is not parsed as code -----
def main():
    """Pretrain a masked-LM BERT over pre-generated JSON shards.

    Expects ``file_{i}.json`` / ``file_{i}_metrics.json`` pairs under
    ``config['data_dir'] / "corpus/train"`` (produced by the data-generation
    step), builds the model and optimizer, and runs the training loop with
    optional fp16 (apex amp) and distributed (NCCL) support, periodically
    logging averaged loss/accuracy and saving checkpoints under
    ``config['checkpoint_dir']``.
    """
    parser = ArgumentParser()
    parser.add_argument("--file_num", type=int, default=10,
                        help="Number of pregenerate file")
    parser.add_argument("--reduce_memory", action="store_true",
                        help="Store training data as on-disc memmaps to massively reduce memory usage")
    parser.add_argument("--epochs", type=int, default=4,
                        help="Number of epochs to train for")
    # NOTE(review): no `type=int` on the next two options — a value supplied
    # on the command line arrives as `str`, which breaks the
    # `global_step % args.num_eval_steps` checks in the loop below. The int
    # defaults mask the problem only when the flags are not passed.
    parser.add_argument('--num_eval_steps', default=2000)
    parser.add_argument('--num_save_steps', default=5000)
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--train_batch_size", default=18, type=int,
                        help="Total batch size for training.")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument('--max_grad_norm', default=1.0, type=float)
    parser.add_argument("--learning_rate", default=2e-4, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--fp16_opt_level', type=str, default='O2',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    args = parser.parse_args()

    # Shard directory produced by the pre-generation step.
    pregenerated_data = config['data_dir'] / "corpus/train"
    assert pregenerated_data.is_dir(), \
        "--pregenerated_data should point to the folder of files made by prepare_lm_data_mask.py!"

    # Count total training examples from the per-shard metrics files. If fewer
    # shards exist than --file_num, warn and loop over what is available;
    # abort only when no shard exists at all.
    samples_per_epoch = 0
    for i in range(args.file_num):
        data_file = pregenerated_data / f"file_{i}.json"
        metrics_file = pregenerated_data / f"file_{i}_metrics.json"
        if data_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch += metrics['num_training_examples']
        else:
            if i == 0:
                exit("No training data was found!")
            print(f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs}).")
            print("This script will loop over the available data, but training diversity may be negatively impacted.")
            break
    logger.info(f"samples_per_epoch: {samples_per_epoch}")
    # Device selection: single-process (possibly multi-GPU via DataParallel)
    # vs. one-GPU-per-process distributed training with the NCCL backend.
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(f"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        f"device: {device} , distributed training: {bool(args.local_rank != -1)}, 16-bits training: {args.fp16}")

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1")
    # Per-forward-pass micro-batch; the effective batch is restored by
    # accumulating gradients over `gradient_accumulation_steps` passes.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    seed_everything(args.seed)
    tokenizer = BertTokenizer(vocab_file=config['checkpoint_dir'] / 'vocab.txt')
    total_train_examples = samples_per_epoch * args.epochs

    num_train_optimization_steps = int(
        total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
    args.warmup_steps = int(num_train_optimization_steps * args.warmup_proportion)

    # Prepare model
    with open(str(config['checkpoint_dir'] / 'config.json'), "r", encoding='utf-8') as reader:
        json_config = json.loads(reader.read())
    print(json_config)
    bert_config = BertConfig.from_json_file(str(config['checkpoint_dir'] / 'config.json'))
    # Trains from randomly initialized weights; the commented line below shows
    # how to resume from an existing checkpoint instead.
    model = BertForMaskedLM(config=bert_config)
    # model = BertForMaskedLM.from_pretrained(config['checkpoint_dir'] / 'checkpoint-580000')
    model.to(device)
    # Prepare optimizer: no weight decay on bias and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
    global_step = 0
    metric = LMAccuracy()
    tr_acc = AverageMeter()
    tr_loss = AverageMeter()

    train_logs = {}
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {total_train_examples}")
    logger.info(f"  Batch size = {args.train_batch_size}")
    logger.info(f"  Num steps = {num_train_optimization_steps}")
    logger.info(f"  warmup_steps = {args.warmup_steps}")

    seed_everything(args.seed)  # Added here for reproducibility
    # One pass over every pregenerated shard per epoch; each shard is loaded
    # as its own dataset so masks differ across shards.
    for epoch in range(args.epochs):
        for idx in range(args.file_num):
            epoch_dataset = PregeneratedDataset(file_id=idx, training_path=pregenerated_data, tokenizer=tokenizer,
                                                reduce_memory=args.reduce_memory)
            if args.local_rank == -1:
                train_sampler = RandomSampler(epoch_dataset)
            else:
                train_sampler = DistributedSampler(epoch_dataset)
            train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
            model.train()
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids = batch
                outputs = model(input_ids=input_ids, token_type_ids=segment_ids,
                                attention_mask=input_mask, masked_lm_labels=lm_label_ids)
                pred_output = outputs[1]
                loss = outputs[0]
                metric(logits=pred_output.view(-1, bert_config.vocab_size), target=lm_label_ids.view(-1))
                if args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                nb_tr_steps += 1
                tr_acc.update(metric.value(), n=input_ids.size(0))
                tr_loss.update(loss.item(), n=1)

                # Apply one optimizer update per `gradient_accumulation_steps`
                # micro-batches; gradients are clipped on the master params
                # under amp, otherwise on the model params.
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    # NOTE(review): scheduler stepped before optimizer — the
                    # reverse of the usual torch>=1.1 ordering; presumably fine
                    # for this WarmupLinearSchedule, but confirm.
                    lr_scheduler.step()
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                # NOTE(review): with gradient_accumulation_steps > 1 this check
                # (and the save check below) re-fires on every micro-batch
                # while global_step is unchanged — confirm intended.
                if global_step % args.num_eval_steps == 0:
                    train_logs['loss'] = tr_loss.avg
                    train_logs['acc'] = tr_acc.avg
                    show_info = f'\n[Training]:[{epoch}/{args.epochs}]{global_step}/{num_train_optimization_steps} ' + "-".join(
                        [f' {key}: {value:.4f} ' for key, value in train_logs.items()])
                    logger.info(show_info)
                    tr_acc.reset()
                    tr_loss.reset()

                if global_step % args.num_save_steps == 0:
                    if args.local_rank in [-1, 0] and args.num_save_steps > 0:
                        # Save model checkpoint
                        output_dir = config['checkpoint_dir'] / f'lm-checkpoint-{global_step}'
                        if not output_dir.exists():
                            output_dir.mkdir()
                        # save model
                        model_to_save = model.module if hasattr(model,
                                                                'module') else model  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(str(output_dir))
                        torch.save(args, str(output_dir / 'training_args.bin'))
                        logger.info("Saving model checkpoint to %s", output_dir)

                        # save config
                        output_config_file = output_dir / CONFIG_NAME
                        with open(str(output_config_file), 'w') as f:
                            f.write(model_to_save.config.to_json_string())

                        # save vocab
                        tokenizer.save_vocabulary(output_dir)