Example #1
def __init__(self, config, cuda_device=-1):
    self.cuda_device = cuda_device != -1
    model_path = config['model_path']
    self.tokenizer = XLNetTokenizer.from_pretrained(model_path)
    if self.cuda_device:
        self.model = XLNetForMultipleChoice.from_pretrained(model_path).to('cuda')
    else:
        self.model = XLNetForMultipleChoice.from_pretrained(model_path)
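The constructor above expects a config dict with a 'model_path' key. A minimal, hypothetical usage (the wrapper class name is an assumption, it is not shown in the example) might look like:

# Hypothetical class name wrapping the __init__ shown above.
config = {'model_path': 'xlnet-base-cased'}
predictor = XLNetChoicePredictor(config, cuda_device=0)  # any value other than -1 moves the model to the GPU
cpu_predictor = XLNetChoicePredictor(config)             # the default cuda_device=-1 keeps the model on CPU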
Example #2
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print('gpu count:', n_gpu)

    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)

    os.makedirs(output_dir, exist_ok=True)

    model_state_dict = torch.load(output_model_file, map_location=device)
    model = XLNetForMultipleChoice.from_pretrained('xlnet-base-cased',
                                                   state_dict=model_state_dict)
    model.to(device)
    tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

    eval_data = load_and_cache_examples(data_path,
                                        'mc160',
                                        tokenizer,
                                        test=True)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=eval_batch_size)

    logger.info("***** Running Evaluation *****")
    logger.info("  Num examples = %d", len(eval_dataloader))
    logger.info("  Batch size = %d", eval_batch_size)
    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    logits_all = []
    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)

        with torch.no_grad():
            eval_output = model(input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask,
                                labels=label_ids)
        tmp_eval_loss = eval_output.loss
        logits = eval_output.logits
        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.to('cpu').numpy()
        for i in range(len(logits)):
            logits_all += [logits[i]]

        tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy

        nb_eval_examples += input_ids.size(0)
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_examples

    result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}
    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))

    output_eval_file = os.path.join(output_dir, "results.txt")
    with open(output_eval_file, "a+") as writer:
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))
Example #3
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print('gpu count:', n_gpu)

    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)

    os.makedirs(output_dir, exist_ok=True)

    model = XLNetForMultipleChoice.from_pretrained('xlnet-base-cased')
    model.to(device)
    tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

    no_decay = ['bias', 'LayerNorm.weight']
    ## note: no weight decay according to XLNet paper
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-6)

    train_data = load_and_cache_examples(data_path, 'race', tokenizer)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=train_batch_size)

    num_train_steps = len(
        train_dataloader) // gradient_accumulation_steps * num_train_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_train_steps)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataloader))
    logger.info("  Batch size = %d", train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    global_step = 0

    for ep in range(int(num_train_epochs)):
        model.train()
        max_score = 0
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            output = model(input_ids=input_ids,
                           token_type_ids=segment_ids,
                           attention_mask=input_mask,
                           labels=label_ids)
            loss = output.loss
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()  # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()
                global_step += 1

            if step % 800 == 0:
                logger.info("Training loss: {}, global step: {}".format(
                    tr_loss / nb_tr_steps, global_step))

        eval_data = load_and_cache_examples(data_path,
                                            'race',
                                            tokenizer,
                                            evaluate=True)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=eval_batch_size)

        logger.info("***** Running Dev Evaluation *****")
        logger.info("  Num examples = %d", len(eval_dataloader))
        logger.info("  Batch size = %d", eval_batch_size)
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        logits_all = []
        for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                eval_output = model(input_ids=input_ids,
                                    token_type_ids=segment_ids,
                                    attention_mask=input_mask,
                                    labels=label_ids)
            tmp_eval_loss = eval_output.loss
            logits = eval_output.logits
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            for i in range(len(logits)):
                logits_all += [logits[i]]

            tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples

        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': tr_loss / nb_tr_steps
        }
        logger.info(" Epoch: %d", (ep + 1))
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))

        output_eval_file = os.path.join(output_dir, "results.txt")
        with open(output_eval_file, "a+") as writer:
            writer.write(" Epoch: " + str(ep + 1))
            for key in sorted(result.keys()):
                writer.write("%s = %s\n" % (key, str(result[key])))

        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(
            output_dir, "pytorch_model_{}epoch.bin".format(ep + 1))
        torch.save(model_to_save.state_dict(), output_model_file)
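load_and_cache_examples is project-specific and not shown. Judging by how its batches are unpacked (input_ids, input_mask, segment_ids, label_ids) and by the TensorDataset construction in Example #5, a rough sketch of what it presumably does is the following; read_examples, convert_examples_to_features and select_field are assumed helpers with the same roles they have elsewhere in these examples:

import torch
from torch.utils.data import TensorDataset

def load_and_cache_examples(data_path, task, tokenizer, evaluate=False, test=False):
    # Assumed: read the requested split, convert it to fixed-length
    # multiple-choice features, and wrap everything in a TensorDataset.
    examples = read_examples(data_path, task, evaluate=evaluate, test=test)
    features = convert_examples_to_features(examples, tokenizer, max_seq_length)
    all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long)
    all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long)
    all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)
    return TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)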
Example #4
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--model_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The model file")
    parser.add_argument("--seq_length",
                        default=None,
                        type=int,
                        required=True,
                        help="seq_length")
    parser.add_argument("--model",
                        default=None,
                        type=str,
                        required=True,
                        help="model")
    parser.add_argument("--dataset",
                        default=None,
                        type=str,
                        required=True,
                        help="dataset")
    args = parser.parse_args()
    global max_seq_length
    max_seq_length = args.seq_length
    output_model_file = args.model_file

    output_dir = 'model-' + args.dataset
    data_path = dataset_map[args.dataset]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print('gpu count:', n_gpu)

    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)

    os.makedirs(output_dir, exist_ok=True)

    model_state_dict = torch.load(output_model_file, map_location=device)
    model = XLNetForMultipleChoice.from_pretrained(args.model,
                                                   state_dict=model_state_dict,
                                                   dropout=0,
                                                   summary_last_dropout=0)
    model.to(device)
    tokenizer = XLNetTokenizer.from_pretrained(args.model)

    eval_data = load_and_cache_examples(data_path,
                                        args.dataset,
                                        args.model,
                                        tokenizer,
                                        test=True)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=eval_batch_size)

    logger.info("***** Running Evaluation *****")
    logger.info("  Num examples = %d", len(eval_data))
    logger.info("  Batch size = %d", eval_batch_size)
    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0

    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)

        with torch.no_grad():
            eval_output = model(input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask,
                                labels=label_ids)
        tmp_eval_loss = eval_output.loss
        logits = eval_output.logits
        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.to('cpu').numpy()

        tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy

        nb_eval_examples += input_ids.size(0)
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_examples

    result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}
    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))

    output_eval_file = os.path.join(
        output_dir, "{}_{}_{}_results.txt".format(args.dataset, args.model,
                                                  max_seq_length))
    with open(output_eval_file, "a+") as writer:
        writer.write("Test:\n")
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))
Example #5
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print('gpu count:', n_gpu)
    
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)
    os.makedirs(output_dir, exist_ok=True)
    
    model_state_dict = torch.load(output_model_file, map_location=device)
    model = XLNetForMultipleChoice.from_pretrained('xlnet-large-cased', state_dict=model_state_dict)
    logger.info("Trained model: {} loaded.".format(output_model_file))

    model.to(device)
    no_decay = ['bias', 'LayerNorm.weight']
    ## note: no weight decay according to XLNet paper 
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
    processor = processors['dream']()
    label_list = processor.get_labels()

    eval_examples = processor.get_test_examples('')
    eval_features = convert_examples_to_features(
            eval_examples,
            label_list,
            max_seq_length,
            tokenizer,
            pad_on_left=True,  # pad on the left for xlnet
            pad_token_segment_id=4
        )

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", eval_batch_size)

    all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
    all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)
    all_label_ids = torch.tensor([f.label for f in eval_features], dtype=torch.long)
    
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size)

    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    logits_all = []
    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)

        with torch.no_grad():
            eval_output = model(input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask,
                                labels=label_ids)
        tmp_eval_loss = eval_output.loss
        logits = eval_output.logits
        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.to('cpu').numpy()
        for i in range(len(logits)):
            logits_all += [logits[i]]
        
        tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy

        nb_eval_examples += input_ids.size(0)
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_examples

    result = {'eval_loss': eval_loss,
              'eval_accuracy': eval_accuracy}
    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))
        output_eval_file = os.path.join(output_dir, "results.txt")
    with open(output_eval_file, "a+") as writer:
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))
Example #6
def train():

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device:{} n_gpu:{}".format(device, n_gpu))
    seed = hyperparameters["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    max_seq_length = hyperparameters["max_sent_length"]
    gradient_accumulation_steps = hyperparameters[
        "gradient_accumulation_steps"]
    num_epochs = hyperparameters["num_epoch"]
    train_batch_size = hyperparameters["train_batch_size"] // hyperparameters[
        "gradient_accumulation_steps"]
    tokenizer = XLNetTokenizer.from_pretrained("xlnet-large-cased",
                                               do_lower_case=True)
    model = XLNetForMultipleChoice.from_pretrained("xlnet-large-cased")

    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    train_examples = read_examples('../dataset/train_bert.txt')
    dev_examples = read_examples('../dataset/test_bert.txt')
    nTrain = len(train_examples)
    nDev = len(dev_examples)
    num_train_optimization_steps = int(
        nTrain / train_batch_size / gradient_accumulation_steps) * num_epochs
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=hyperparameters["learning_rate"])
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * num_train_optimization_steps),
        num_training_steps=num_train_optimization_steps)

    global_step = 0
    train_features = convert_examples_to_features(train_examples, tokenizer,
                                                  max_seq_length)
    dev_features = convert_examples_to_features(dev_examples, tokenizer,
                                                max_seq_length)
    # """
    train_dataloader = get_train_dataloader(train_features, train_batch_size)
    dev_dataloader = get_eval_dataloader(dev_features,
                                         hyperparameters["eval_batch_size"])
    print("Num of train features:{}".format(nTrain))
    print("Num of dev features:{}".format(nDev))
    best_dev_accuracy = 0
    best_dev_epoch = 0
    no_up = 0

    epoch_tqdm = trange(int(num_epochs), desc="Epoch")
    for epoch in epoch_tqdm:
        model.train()

        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            batch = tuple(t.to(device) for t in batch)
            input_ids, label_ids = batch
            loss, logits = model(input_ids=input_ids, labels=label_ids)[:2]
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            loss.backward()
            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

        train_loss, train_accuracy = evaluate(model, device, train_dataloader,
                                              "Train")
        dev_loss, dev_accuracy = evaluate(model, device, dev_dataloader, "Dev")

        if dev_accuracy > best_dev_accuracy:
            best_dev_accuracy = dev_accuracy
            best_dev_epoch = epoch + 1
            no_up = 0

        else:
            no_up += 1
        tqdm.write("\t ***** Eval results (Epoch %s) *****" % str(epoch + 1))
        tqdm.write("\t train_accuracy = %s" % str(train_accuracy))
        tqdm.write("\t dev_accuracy = %s" % str(dev_accuracy))
        tqdm.write("")
        tqdm.write("\t best_dev_accuracy = %s" % str(best_dev_accuracy))
        tqdm.write("\t best_dev_epoch = %s" % str(best_dev_epoch))
        tqdm.write("\t no_up = %s" % str(no_up))
        tqdm.write("")
        if no_up >= hyperparameters["patience"]:
            epoch_tqdm.close()
            break
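get_train_dataloader and get_eval_dataloader are not shown in Example #6. Since the training loop unpacks each batch into just (input_ids, label_ids), a minimal sketch (assuming feature objects with .input_ids and .label attributes) could be:

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

def get_train_dataloader(features, batch_size):
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)
    data = TensorDataset(all_input_ids, all_label_ids)
    return DataLoader(data, sampler=RandomSampler(data), batch_size=batch_size)

def get_eval_dataloader(features, batch_size):
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)
    data = TensorDataset(all_input_ids, all_label_ids)
    return DataLoader(data, sampler=SequentialSampler(data), batch_size=batch_size)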
Example #7
def main(config, model_filename):
    if not os.path.exists(config.output_dir):
        os.makedirs(config.output_dir)

    if not os.path.exists(config.cache_dir):
        os.makedirs(config.cache_dir)

    model_file = os.path.join(config.output_dir, model_filename)

    # Prepare the device
    # gpu_ids = [int(device_id) for device_id in config.gpu_ids.split()]
    gpu_ids = [2]
    device, n_gpu = get_device(gpu_ids[0])
    if n_gpu > 1:
        n_gpu = len(gpu_ids)

    # Set Random Seeds
    random.seed(config.seed)
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(config.seed)
        torch.backends.cudnn.deterministic = True

    tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
    model = XLNetForMultipleChoice.from_pretrained(
        './xlnet_model')  # ./xlnet_model

    cache_train_dataset = "cached_dataset_train"
    cache_dev_dataset = "cached_dataset_dev"
    if os.path.exists(config.cache_dir + '/' + cache_train_dataset):
        logger.info("Loading features from cached file %s",
                    config.cache_dir + '/' + cache_train_dataset)
        train_dataset = torch.load(config.cache_dir + '/' +
                                   cache_train_dataset)
        dev_dataset = torch.load(config.cache_dir + '/' + cache_dev_dataset)
    else:
        train_dataset, dev_dataset, test_dataset = load_data(
            config.data_path, device, tokenizer, config.cache_dir, 128, 1024)
        logger.info("save cached file in  %s", config.cache_dir)
        torch.save(train_dataset, config.cache_dir + '/' + cache_train_dataset)
        torch.save(dev_dataset, config.cache_dir + '/' + cache_dev_dataset)
    train_sampler = RandomSampler(train_dataset)
    dev_sampler = RandomSampler(dev_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=config.train_batch_size,
                                  num_workers=8,
                                  pin_memory=False)
    dev_dataloader = DataLoader(dev_dataset,
                                sampler=dev_sampler,
                                batch_size=config.dev_batch_size,
                                num_workers=8,
                                pin_memory=False)
    # train_iterator = trange(int(config.epoch_num))
    # if config.model_name == "GAReader":
    #     from Bert_GAReader.GAReader.GAReader import GAReader
    #     model = GAReader(
    #         config.bert_word_dim, config.output_dim, config.hidden_size,
    #         config.rnn_num_layers, config.ga_layers, config.bidirectional,
    #         config.dropout, bert_config)
    #     print(model)
    # no_decay = ['bias', 'LayerNorm.weight']

    # optimizer = optim.Adam(model.parameters(), lr=config.lr)
    optimizer = optim.SGD(model.parameters(), lr=config.lr)
    # optimizer = optim.AdamW(optimizer_grouped_parameter,lr=config.lr)
    criterion = nn.CrossEntropyLoss()

    model = model.to(device)
    criterion = criterion.to(device)

    if config.do_train:
        train(config.epoch_num, model, train_dataloader, dev_dataloader,
              optimizer, criterion, ['0', '1', '2', '3', '4'], model_file,
              config.log_dir, config.print_step, config.clip, device)

    model.load_state_dict(torch.load(model_file))

    test_loss, test_acc, test_report = evaluate(model, dev_dataloader,
                                                criterion,
                                                ['0', '1', '2', '3'], device)
    print("-------------- Test -------------")
    print("\t Loss: {} | Acc: {} | Macro avg F1: {} | Weighted avg F1: {}".
          format(test_loss, test_acc, test_report['macro avg']['f1-score'],
                 test_report['weighted avg']['f1-score']))
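get_device in Example #7 is project-specific. Given that it is called with a GPU id and returns a (device, n_gpu) pair, a plausible sketch is:

import torch

def get_device(gpu_id):
    # Return the requested CUDA device (or the CPU) plus the visible GPU count.
    if torch.cuda.is_available():
        return torch.device("cuda:{}".format(gpu_id)), torch.cuda.device_count()
    return torch.device("cpu"), 0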
Example #8
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--model_file",
                        default=None,
                        type=str,
                        required=False,
                        help="The model file")
    parser.add_argument("--batch_size",
                        default=None,
                        type=int,
                        required=True,
                        help="Batch size")
    parser.add_argument("--seq_length",
                        default=None,
                        type=int,
                        required=True,
                        help="seq_length")
    parser.add_argument("--model",
                        default=None,
                        type=str,
                        required=True,
                        help="model")
    parser.add_argument("--dataset",
                        default=None,
                        type=str,
                        required=True,
                        help="dataset")
    parser.add_argument("--epochs",
                        default=5,
                        type=int,
                        required=False,
                        help="Epochs")

    args = parser.parse_args()
    global max_seq_length
    max_seq_length = args.seq_length

    global num_train_epochs
    num_train_epochs = args.epochs

    output_dir = 'model-' + args.dataset
    data_path = dataset_map[args.dataset]
    gradient_accumulation_steps = args.batch_size

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info('gpu count: %d', n_gpu)

    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)

    os.makedirs(output_dir, exist_ok=True)
    if (args.model_file):
        model_state_dict = torch.load(args.model_file, map_location=device)
        model = XLNetForMultipleChoice.from_pretrained(
            args.model,
            state_dict=model_state_dict,
            dropout=0,
            summary_last_dropout=0)
    else:
        ## note: dropout rate set to zero, increased accuracy
        model = XLNetForMultipleChoice.from_pretrained(args.model,
                                                       dropout=0,
                                                       summary_last_dropout=0)
    model.to(device)
    tokenizer = XLNetTokenizer.from_pretrained(args.model)

    no_decay = ['bias', 'LayerNorm.weight']
    ## note: no weight decay according to XLNet paper
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    ## note: Adam epsilon used 1e-6
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-6)

    train_data = load_and_cache_examples(data_path, args.dataset, args.model,
                                         tokenizer)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=train_batch_size)

    ## note: Used gradient accumulation steps to simulate larger batch size
    num_train_steps = len(
        train_dataloader) // gradient_accumulation_steps * num_train_epochs

    ## note: Warmup proportion of 0.1
    num_warmup_steps = num_train_steps // 10
    logger.info("  Num warmup steps = %d", num_warmup_steps)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_train_steps)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_data))
    logger.info("  Batch size = %d",
                train_batch_size * gradient_accumulation_steps)
    logger.info("  Num steps = %d", num_train_steps)

    global_step = 0

    for ep in range(int(num_train_epochs)):
        model.train()
        max_score = 0
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            output = model(input_ids=input_ids,
                           token_type_ids=segment_ids,
                           attention_mask=input_mask,
                           labels=label_ids)
            loss = output.loss
            if n_gpu > 1:
                loss = loss.mean()
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            if (step + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()  # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()
                global_step += 1

            if step % 800 == 0:
                logger.info("Training loss: {}, global step: {}".format(
                    tr_loss / nb_tr_steps, global_step))

        eval_data = load_and_cache_examples(data_path,
                                            args.dataset,
                                            args.model,
                                            tokenizer,
                                            evaluate=True)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=eval_batch_size)

        logger.info("***** Running Dev Evaluation *****")
        logger.info("  Num examples = %d", len(eval_data))
        logger.info("  Batch size = %d", eval_batch_size)
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                eval_output = model(input_ids=input_ids,
                                    token_type_ids=segment_ids,
                                    attention_mask=input_mask,
                                    labels=label_ids)
            tmp_eval_loss = eval_output.loss
            logits = eval_output.logits
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()

            tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples

        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': tr_loss / nb_tr_steps
        }
        logger.info(" Epoch: %d", (ep + 1))
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))

        output_eval_file = os.path.join(
            output_dir, "{}_{}_{}_results.txt".format(args.dataset, args.model,
                                                      max_seq_length))
        with open(output_eval_file, "a+") as writer:
            writer.write("Epoch: " + str(ep + 1) + "\n")
            for key in sorted(result.keys()):
                writer.write("%s = %s\n" % (key, str(result[key])))
            writer.write("\n")
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(
            output_dir,
            "{}_{}_{}_epoch{}_{}.bin".format(args.dataset, args.model,
                                             max_seq_length, ep + 1,
                                             int(eval_accuracy * 100)))
        torch.save(model_to_save.state_dict(), output_model_file)

    # testdata
    test_data = load_and_cache_examples(data_path,
                                        args.dataset,
                                        args.model,
                                        tokenizer,
                                        test=True)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data,
                                 sampler=test_sampler,
                                 batch_size=eval_batch_size)

    logger.info("***** Running Test Evaluation *****")
    logger.info("  Num examples = %d", len(test_data))
    logger.info("  Batch size = %d", eval_batch_size)
    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0

    for input_ids, input_mask, segment_ids, label_ids in test_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)

        with torch.no_grad():
            eval_output = model(input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask,
                                labels=label_ids)
        tmp_eval_loss = eval_output.loss
        logits = eval_output.logits
        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.to('cpu').numpy()

        tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy

        nb_eval_examples += input_ids.size(0)
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_examples

    result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}
    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))

    output_eval_file = os.path.join(
        output_dir, "{}_{}_{}_results.txt".format(args.dataset, args.model,
                                                  max_seq_length))
    with open(output_eval_file, "a+") as writer:
        writer.write("Test:\n")
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))