Example #1
def pretrain(args, data_path):
    print('[pretrain] create config, model')
    if args.model == 'bert':
        if args.redefined_tokenizer:
            bert_tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
                                                           do_lower_case=True)
        else:
            bert_tokenizer = BertTokenizer.from_pretrained(
                './pretrained_weights/bert-base-uncased-vocab.txt',
                do_lower_case=True)
    elif args.model == 'biobert':
        if args.redefined_tokenizer:
            bert_tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
                                                           do_lower_case=False)
        else:
            bert_tokenizer = BertTokenizer.from_pretrained(
                './pretrained_weights/biobert_pretrain_output_all_notes_150000/vocab.txt',
                do_lower_case=False)
    elif args.model == 'bert-tiny':
        if args.redefined_tokenizer:
            bert_tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
                                                           do_lower_case=True)
        else:
            bert_tokenizer = BertTokenizer.from_pretrained(
                './pretrained_weights/bert-tiny-uncased-vocab.txt',
                do_lower_case=True)

    if args.model == 'bert':
        config = BertConfig.from_pretrained(
            './pretrained_weights/bert-base-uncased-config.json')
        if args.Y == 'full':
            config.Y = 8921
        else:
            config.Y = int(args.Y)
        config.gpu = args.gpu
        config.redefined_vocab_size = len(bert_tokenizer)
        if args.max_sequence_length is None:
            config.redefined_max_position_embeddings = MAX_LENGTH
        else:
            config.redefined_max_position_embeddings = args.max_sequence_length
        config.last_module = args.last_module
        config.model = args.model

        if args.from_scratch:
            model = BertForMaskedLM(config=config)
        else:
            model = BertForMaskedLM.from_pretrained(
                './pretrained_weights/bert-base-uncased-pytorch_model.bin',
                config=config)
    elif args.model == 'biobert':
        config = BertConfig.from_pretrained(
            './pretrained_weights/biobert_pretrain_output_all_notes_150000/bert_config.json'
        )
        if args.Y == 'full':
            config.Y = 8921
        else:
            config.Y = int(args.Y)
        config.gpu = args.gpu
        config.redefined_vocab_size = len(bert_tokenizer)
        if args.max_sequence_length is None:
            config.redefined_max_position_embeddings = MAX_LENGTH
        else:
            config.redefined_max_position_embeddings = args.max_sequence_length
        config.last_module = args.last_module
        config.model = args.model
        if args.from_scratch:
            model = BertForMaskedLM(config=config)
        else:
            model = BertForMaskedLM.from_pretrained(
                './pretrained_weights/biobert_pretrain_output_all_notes_150000/pytorch_model.bin',
                config=config)
    elif args.model == 'bert-tiny':
        config = BertConfig.from_pretrained(
            './pretrained_weights/bert-tiny-uncased-config.json')
        if args.Y == 'full':
            config.Y = 8921
        else:
            config.Y = int(args.Y)
        config.gpu = args.gpu
        config.redefined_vocab_size = len(bert_tokenizer)
        if args.max_sequence_length is None:
            config.redefined_max_position_embeddings = MAX_LENGTH
        else:
            config.redefined_max_position_embeddings = args.max_sequence_length
        config.last_module = args.last_module
        config.model = args.model
        if args.from_scratch:
            model = BertForMaskedLM(config=config)
        else:
            model = BertForMaskedLM.from_pretrained(
                './pretrained_weights/bert-tiny-uncased-pytorch_model.bin',
                config=config)

    if args.gpu:
        model.cuda()

    print('[pretrain] prepare optimizer, scheduler')
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # apply weight decay to everything except biases and LayerNorm parameters;
    # the per-group values below override any optimizer-level default
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    pretrain_optimizer = optim.Adam(optimizer_grouped_parameters, lr=args.lr)
    length = datasets.data_length(args.data_path, args.version)
    t_total = length // args.pretrain_batch_size * args.pretrain_epochs
    pretrain_scheduler = get_linear_schedule_with_warmup(
        pretrain_optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    print_every = 25

    model.train()
    model.zero_grad()

    print('[pretrain] create dataloader')
    train_dataset = datasets.pretrain_data_generator(
        args,
        data_path,
        args.pretrain_batch_size,
        version=args.version,
        bert_tokenizer=bert_tokenizer)

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.pretrain_batch_size)

    print('[pretrain] start epoch')
    for epoch in range(args.pretrain_epochs):
        losses = []
        for batch_idx, data in tqdm(enumerate(train_dataloader)):
            inputs, labels = random_mask_tokens(args, data, bert_tokenizer)
            if args.gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()

            # single-segment inputs: all token type ids are 0
            token_type_ids = torch.zeros_like(inputs)
            # attend only to non-padding tokens (padding id is 0)
            attention_mask = (inputs > 0).long()
            position_ids = torch.arange(inputs.size(1)).expand(
                inputs.size(0), inputs.size(1))
            if args.gpu:
                position_ids = position_ids.cuda()
            # zero out position ids at padding positions
            position_ids = position_ids * (inputs > 0).long()

            outputs = model(input_ids=inputs,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            masked_lm_labels=labels)
            loss = outputs[0]
            losses.append(loss.item())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            pretrain_optimizer.step()
            pretrain_scheduler.step()
            model.zero_grad()

            if batch_idx % print_every == 0:
                # print the average loss of the last 10 batches
                print(
                    "Train epoch: {} [batch #{}, batch_size {}, seq length {}]\tLoss: {:.6f}"
                    .format(epoch, batch_idx,
                            data.size()[0],
                            data.size()[1], np.mean(losses[-10:])))

        loss = sum(losses) / len(losses)
        print('Epoch %d: %.4f' % (epoch, loss))

    model.save_pretrained(args.pretrain_ckpt_dir)
    print('Save pretrained model --> %s' % (args.pretrain_ckpt_dir))
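
The training loop above depends on a random_mask_tokens helper that is not shown in this example. A minimal sketch of such a helper is given below, following the standard BERT masking recipe (select about 15% of positions; replace 80% of them with [MASK], 10% with a random token, and leave 10% unchanged). The function name, signature, mlm_probability default, and the -100 ignore index are assumptions for illustration and may differ from the original implementation.

import torch

def random_mask_tokens_sketch(inputs, tokenizer, mlm_probability=0.15):
    # Illustrative BERT-style masking; the real random_mask_tokens may differ.
    inputs = inputs.clone()
    labels = inputs.clone()

    # Sample positions to mask, skipping special tokens ([CLS], [SEP], padding).
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
        for ids in labels.tolist()
    ]
    probability_matrix.masked_fill_(
        torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    # -100 is ignored by the MLM loss in recent transformers versions
    # (older versions used -1).
    labels[~masked_indices] = -100

    # 80% of the masked positions become [MASK].
    indices_replaced = (torch.bernoulli(torch.full(labels.shape, 0.8)).bool()
                        & masked_indices)
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% become a random token; the remaining 10% keep the original token.
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                      & masked_indices & ~indices_replaced)
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    return inputs, labels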
Example #2
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("-conf", type=str)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument(
        "--gpu",
        type=str,
        default=None,
        help="binary mask of GPUs to use (for example, '10100000' means use device_id 0 and 2)")

    args = parser.parse_args()
    config = configparser.ConfigParser()
    config.read(args.conf)

    hidden_size = int(config["model"]["hidden_size"])
    num_hidden_layers = int(config["model"]["num_hidden_layers"])
    num_attention_heads = int(config["model"]["num_attention_heads"])
    intermediate_size = int(config["model"]["intermediate_size"])
    max_position_embeddings = int(config["model"]["max_position_embeddings"])
    #
    vocab_size = int(config["vocab"]["vocab_size"])
    mask_id = int(config["vocab"]["mask_id"])
    #
    log_path = config["log"]["log_path"]
    log_dir = os.path.dirname(log_path)
    os.makedirs(log_dir, exist_ok=True)
    log_step = int(config["log"]["log_step"])
    #
    train_size = int(config["data"]["train_size"])
    #
    save_prefix = config["save"]["save_prefix"]
    save_dir = os.path.dirname(save_prefix)
    os.makedirs(save_dir, exist_ok=True)
    save_epoch = int(config["save"]["save_epoch"])
    #
    batch_size = int(config["train"]["batch_size"])
    if args.debug:
        batch_size = 10
    num_epochs = int(config["train"]["num_epochs"])
    learning_rate = float(config["train"]["learning_rate"])
    warmup_proportion = float(config["train"]["warmup_proportion"])
    weight_decay = float(config["train"]["weight_decay"])
    #
    num_to_mask = int(config["mask"]["num_to_mask"])
    max_seq_len = int(config["mask"]["max_seq_len"])

    if args.debug:
        logging.basicConfig(format="%(asctime)s %(message)s",
                            level=logging.DEBUG)
    else:
        logging.basicConfig(filename=log_path,
                            format="%(asctime)s %(message)s",
                            level=logging.DEBUG)

    bertconfig = modeling_bert.BertConfig(
        vocab_size_or_config_json_file=vocab_size,
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings)
    model = BertForMaskedLM(config=bertconfig)
    total_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logging.info(f"number of trainable parameters: {total_params}")

    if args.gpu is not None:
        device_ids = []
        for device_id, flag in enumerate(args.gpu):
            if flag == "1":
                device_ids.append(device_id)
        multi_gpu = True
        device = torch.device("cuda:{}".format(device_ids[0]))
    else:
        multi_gpu = False
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging.info(f"device: {device}")
    if "model_path" in config["train"]:
        model_path = config["train"]["model_path"]
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        logging.info(f"load model from {model_path}")
    model.to(device)
    if multi_gpu:
        logging.info(f"GPU: device_id={device_ids}")
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    model.train()

    # optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # apply the configured weight decay to everything except biases and
    # LayerNorm parameters; the per-group values below take precedence,
    # so weight_decay is not passed to BertAdam again
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    t_total = (train_size // batch_size) * num_epochs
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=warmup_proportion,
                         t_total=t_total)
    logging.info("start training...")

    for epoch in range(num_epochs):
        if "train_dir" in config["data"]:
            train_dir = config["data"]["train_dir"]
            datpaths = os.listdir(train_dir)
            random.shuffle(datpaths)
            for step_ds, path in enumerate(datpaths):
                path = os.path.join(train_dir, path)
                dataset = LMDataset(path)
                num_steps = (len(dataset) // batch_size) + 1
                logging.info(f"dataset from: {path}")
                loss_ds = train_dataset(dataset=dataset,
                                        model=model,
                                        optimizer=optimizer,
                                        multi_gpu=multi_gpu,
                                        device=device,
                                        epoch=epoch,
                                        batch_size=batch_size,
                                        num_steps=num_steps,
                                        log_step=log_step,
                                        num_to_mask=num_to_mask,
                                        mask_id=mask_id,
                                        max_seq_len=max_seq_len)
                logging.info(
                    f"step {step_ds + 1} / {len(datpaths)}: {(loss_ds / num_steps):.6f}"
                )
        else:
            train_path = config["data"]["train_path"]
            dataset = LMDataset(train_path)
            num_steps = (len(dataset) // batch_size) + 1
            loss_epoch = train_dataset(dataset=dataset,
                                       model=model,
                                       optimizer=optimizer,
                                       multi_gpu=multi_gpu,
                                       device=device,
                                       epoch=epoch,
                                       batch_size=batch_size,
                                       num_steps=num_steps,
                                       log_step=log_step,
                                       num_to_mask=num_to_mask,
                                       mask_id=mask_id,
                                       max_seq_len=max_seq_len)
            logging.info(
                f"epoch {epoch + 1} / {num_epochs} : {(loss_epoch / num_steps):.6f}"
            )

        if (epoch + 1) % save_epoch == 0:
            # the f-strings already contain the epoch number, so no extra .format() call is needed
            save_path = f"{save_prefix}.network.epoch{(epoch + 1):d}"
            optimizer_save_path = f"{save_prefix}.optimizer.epoch{(epoch + 1):d}"
            if multi_gpu:
                # unwrap DataParallel so the checkpoint loads on a single device
                torch.save(model.module.state_dict(), save_path)
            else:
                torch.save(model.state_dict(), save_path)
            logging.info(f"model saved: {save_path}")
            torch.save(optimizer.state_dict(), optimizer_save_path)
            logging.info(f"optimizer saved: {optimizer_save_path}")