Example #1
def main():
    #rs_writer = SummaryWriter("./log")
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        choices=[
                            "semeval_QA_EXPT", "semeval_QA_T",
                            "travel_experience", "semeval_single"
                        ],
                        help="Name of the task to train")
    parser.add_argument("--vocab_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The path of BERT pre-trained vocab file")
    parser.add_argument("--init_checkpoint",
                        default=None,
                        type=str,
                        required=True,
                        help="The path of BERT pre-trained .ckpt file")
    parser.add_argument("--bert_config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The path of BERT .json file")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory of training result")
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The path of training dataset")

    # Other parameters
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="The size of training batch")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="The size of evaluation batch")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum sentence length of input after WordPiece tonkenization\n"
        "Greater than the max will be truncated, smaller than the max will be padding"
    )
    parser.add_argument("--local_rank",
                        default=-1,
                        type=int,
                        help="Local_rank for distributed training on gpus")
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        help="Random seed for initialization")
    parser.add_argument(
        "--accumulate_gradients",
        default=1,
        type=int,
        help=
        "Number of steps to accumulate gradient on (divide the batch_size and accumulate)"
    )
    parser.add_argument(
        '--gradient_accumulation_steps',
        default=1,
        type=int,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument('--save_steps',
                        type=int,
                        default=100,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        '--layers',
        type=int,
        nargs='+',
        default=[-2],
        help="choose the layers that used for downstream tasks, "
        "-2 means use pooled output, -1 means all layer,"
        "else means the detail layers. default is -2")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="The number of epochs on training data")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The learning rate of model")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for\n"
        "0.1 means 10% of training set")
    parser.add_argument('--layer_learning_rate_decay',
                        type=float,
                        default=0.95)
    parser.add_argument('--layer_learning_rate',
                        type=float,
                        nargs='+',
                        default=[2e-5] * 12,
                        help="learning rate in each group")
    parser.add_argument("--do_train",
                        default=False,
                        action="store_true",
                        help="Whether training the data or not")
    parser.add_argument("--do_eval",
                        default=False,
                        action="store_true",
                        help="Whether evaluating the data or not")
    parser.add_argument("--do_predict",
                        default=False,
                        action="store_true",
                        help="Whether predicting the data or not")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action="store_true",
        help=
        "To lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--no_cuda",
                        default=False,
                        action="store_true",
                        help="Whether use the GPU device or not")
    parser.add_argument("--discr",
                        default=False,
                        action='store_true',
                        help="Whether to do discriminative fine-tuning.")
    parser.add_argument('--pooling_type',
                        default=None,
                        type=str,
                        choices=[None, 'mean', 'max'])

    args = parser.parse_args()

    viz = visdom.Visdom()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    if args.accumulate_gradients < 1:
        raise ValueError(
            "Invalid accumulate_gradients parameter: {}, should be >= 1".
            format(args.accumulate_gradients))

    args.train_batch_size = int(args.train_batch_size /
                                args.accumulate_gradients)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    bert_config = BertConfig.from_json_file(args.bert_config_file)

    if args.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}"
            .format(args.max_seq_length, bert_config.max_position_embeddings))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    # prepare dataloaders
    processors = {
        "semeval_QA_EXPT": Semeval_QA_EXPT_Processor,
        "semeval_QA_T": Semeval_QA_T_Processor,
        "travel_experience": Travel_exp_data,
        "semeval_single": Semeval_single_Processor
    }

    processor = processors[args.task_name]()
    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file,
                                           do_lower_case=args.do_lower_case)

    # training set
    train_examples = processor.get_train_examples(args.data_dir)
    num_train_steps = int(
        len(train_examples) / args.train_batch_size * args.num_train_epochs)

    train_features = convert_examples_to_features(train_examples, label_list,
                                                  args.max_seq_length,
                                                  tokenizer)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_label_ids)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    # test set
    if args.do_eval:
        test_examples = processor.get_test_examples(args.data_dir)
        test_features = convert_examples_to_features(test_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in test_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in test_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in test_features],
                                     dtype=torch.long)

        test_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        test_dataloader = DataLoader(test_data,
                                     batch_size=args.eval_batch_size,
                                     shuffle=False)

    # Model and optimizer: layer-wise (discriminative) learning-rate variant, kept as a disabled block
    """
    model = BertForSequenceClassification(bert_config, len(label_list), args.layers, pooling=args.pooling_type)
    if args.init_checkpoint is not None:
        model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
    model.to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    no_decay = ['bias', 'gamma', 'beta']
    if args.discr:
        if len(args.layer_learning_rate) > 1:
            groups = [(f'layer.{i}.', args.layer_learning_rate[i]) for i in range(12)]
        else:
            lr = args.layer_learning_rate[0]
            groups = [(f'layer.{i}.', lr * pow(args.layer_learning_rate_decay, 11 - i)) for i in range(12)]
        group_all = [f'layer.{i}.' for i in range(12)]
        no_decay_optimizer_parameters = []
        decay_optimizer_parameters = []
        for g, l in groups:
            no_decay_optimizer_parameters.append(
                {
                    'params': [p for n, p in model.named_parameters() if
                               not any(nd in n for nd in no_decay) and any(nd in n for nd in [g])],
                    'weight_decay_rate': 0.01, 'lr': l
                }
            )
            decay_optimizer_parameters.append(
                {
                    'params': [p for n, p in model.named_parameters() if
                               any(nd in n for nd in no_decay) and any(nd in n for nd in [g])],
                    'weight_decay_rate': 0.0, 'lr': l
                }
            )

        group_all_parameters = [
            {'params': [p for n, p in model.named_parameters() if
                        not any(nd in n for nd in no_decay) and not any(nd in n for nd in group_all)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in model.named_parameters() if
                        any(nd in n for nd in no_decay) and not any(nd in n for nd in group_all)],
             'weight_decay_rate': 0.0},
        ]
        optimizer_parameters = no_decay_optimizer_parameters + decay_optimizer_parameters + group_all_parameters

    else:
        optimizer_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.0}
        ]

    optimizer = BERTAdam(optimizer_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)
    """

    model = BertForSequenceClassification(bert_config, len(label_list))
    if args.init_checkpoint is not None:
        model.bert.load_state_dict(
            torch.load(args.init_checkpoint, map_location='cpu'))
    model.to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    no_decay = ['bias', 'gamma', 'beta']
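    # Parameters whose names contain one of the substrings above (biases and,
    # where the checkpoint uses these names, the LayerNorm gamma/beta weights)
    # are excluded from weight decay; everything else gets weight decay 0.01.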
    optimizer_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay_rate':
        0.01
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay_rate':
        0.0
    }]

    optimizer = BERTAdam(optimizer_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)

    # train
    output_log_file = os.path.join(args.output_dir, "log.txt")
    print("output_log_file=", output_log_file)
    with open(output_log_file, "w") as writer:
        if args.do_eval:
            writer.write(
                "epoch\tglobal_step\tloss\ttest_loss\ttest_accuracy\n")
        else:
            writer.write("epoch\tglobal_step\tloss\n")

    global_step = 0
    epoch = 0
    best_epoch, best_accuracy = 0, 0
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):
        epoch += 1
        model.train()
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
            print(loss.item())
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            viz.line([loss.item()], [global_step],
                     win='tr_loss',
                     update='append')
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()  # We have accumulated enough gradients
                model.zero_grad()
                global_step += 1

                # Save a model checkpoint every `save_steps` update steps
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    save_output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(save_output_dir):
                        os.makedirs(save_output_dir)
                    torch.save(
                        model.state_dict(),
                        os.path.join(save_output_dir, "pytorch_model.bin"))
                viz.line([optimizer.get_lr()[0]], [global_step - 1],
                         win="lr",
                         update="append")

        # eval_test
        if args.do_eval:
            model.eval()
            test_loss, test_accuracy = 0, 0
            nb_test_steps, nb_test_examples = 0, 0
            with open(
                    os.path.join(args.output_dir,
                                 "test_ep_" + str(epoch) + ".txt"),
                    "w") as f_test:
                for input_ids, input_mask, segment_ids, label_ids in tqdm(
                        test_dataloader, desc="Testing"):
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    segment_ids = segment_ids.to(device)
                    label_ids = label_ids.to(device)

                    with torch.no_grad():
                        tmp_test_loss, logits = model(input_ids, segment_ids,
                                                      input_mask, label_ids)

                    logits = F.softmax(logits, dim=-1)
                    logits = logits.detach().cpu().numpy()
                    label_ids = label_ids.to('cpu').numpy()
                    outputs = np.argmax(logits, axis=1)
                    for output_i in range(len(outputs)):
                        f_test.write(str(outputs[output_i]))
                        for ou in logits[output_i]:
                            f_test.write(" " + str(ou))
                        f_test.write("\n")
                    tmp_test_accuracy = np.sum(outputs == label_ids)
                    viz.line([tmp_test_loss.mean().item()], [nb_test_steps],
                             win='eval_loss',
                             update='append')
                    test_loss += tmp_test_loss.mean().item()
                    test_accuracy += tmp_test_accuracy

                    nb_test_examples += input_ids.size(0)
                    nb_test_steps += 1

            test_loss = test_loss / nb_test_steps
            test_accuracy = test_accuracy / nb_test_examples
            viz.line([test_accuracy], [epoch],
                     win='test_acc',
                     update='append')
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                best_epoch = epoch
                torch.save(model.state_dict(),
                           os.path.join(args.output_dir, "best_model.bin"))

        result = collections.OrderedDict()
        if args.do_eval:
            result = {
                'epoch': epoch,
                'global_step': global_step,
                'loss': tr_loss / nb_tr_steps,
                'test_loss': test_loss,
                'test_accuracy': test_accuracy
            }
        else:
            result = {
                'epoch': epoch,
                'global_step': global_step,
                'loss': tr_loss / nb_tr_steps
            }

        logger.info("***** Eval results *****")
        with open(output_log_file, "a+") as writer:
            for key in result.keys():
                logger.info("  %s = %s\n", key, str(result[key]))
                writer.write("%s\t" % (str(result[key])))
            writer.write("\n")
        print("The best Epoch is: ", best_epoch)
        print("The best test_accuracy is: ", best_accuracy)
Example #2
def main():
    parser = argparse.ArgumentParser()
    BERT_DIR = "model/uncased_L-12_H-768_A-12/"
    ## Required parameters
    parser.add_argument("--bert_config_file", default=BERT_DIR+"bert_config.json", \
                        type=str, help="The config json file corresponding to the pre-trained BERT model. "
                             "This specifies the model architecture.")
    parser.add_argument("--vocab_file", default=BERT_DIR+"vocab.txt", type=str, \
                        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument("--output_dir", default="out", type=str, \
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--train_file", type=str, \
                        help="SQuAD json for training. E.g., train-v1.1.json", \
                        default="")
    parser.add_argument("--predict_file", type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json", \
                        default="")
    parser.add_argument("--init_checkpoint", type=str,
                        help="Initial checkpoint (usually from a pre-trained BERT model).", \
                        default=BERT_DIR+"pytorch_model.bin")
    parser.add_argument(
        "--do_lower_case",
        default=True,
        action='store_true',
        help="Whether to lower case the input text. Should be True for uncased "
        "models and False for cased models.")
    parser.add_argument(
        "--max_seq_length",
        default=300,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=128,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=10.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
        "of training.")
    parser.add_argument("--save_checkpoints_steps",
                        default=1000,
                        type=int,
                        help="How often to save the model checkpoint.")
    parser.add_argument("--iterations_per_loop",
                        default=1000,
                        type=int,
                        help="How many steps to make in each estimator call.")
    parser.add_argument(
        "--n_best_size",
        default=3,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")

    parser.add_argument(
        "--verbose_logging",
        default=False,
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--accumulate_gradients",
        type=int,
        default=1,
        help=
        "Number of steps to accumulate gradient on (divide the batch_size and accumulate)"
    )
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument('--eval_period', type=int, default=2000)

    parser.add_argument('--max_n_answers', type=int, default=5)
    parser.add_argument('--merge_query', type=int, default=-1)
    parser.add_argument('--reduce_layers', type=int, default=-1)
    parser.add_argument('--reduce_layers_to_tune', type=int, default=-1)

    parser.add_argument('--only_comp', action="store_true", default=False)

    parser.add_argument('--train_subqueries_file', type=str, default="")  #500
    parser.add_argument('--predict_subqueries_file', type=str,
                        default="")  #500
    parser.add_argument('--prefix', type=str, default="")  #500

    parser.add_argument('--model', type=str, default="qa")  #500
    parser.add_argument('--pooling', type=str, default="max")
    parser.add_argument('--debug', action="store_true", default=False)
    parser.add_argument('--output_dropout_prob', type=float, default=0)
    parser.add_argument('--wait_step', type=int, default=30)
    parser.add_argument('--with_key', action="store_true", default=False)
    parser.add_argument('--add_noise', action="store_true", default=False)

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    if args.accumulate_gradients < 1:
        raise ValueError(
            "Invalid accumulate_gradients parameter: {}, should be >= 1".
            format(args.accumulate_gradients))

    args.train_batch_size = int(args.train_batch_size /
                                args.accumulate_gradients)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
        if not args.predict_file:
            raise ValueError(
                "If `do_train` is True, then `predict_file` must be specified."
            )

    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    bert_config = BertConfig.from_json_file(args.bert_config_file)

    if args.do_train and args.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (args.max_seq_length, bert_config.max_position_embeddings))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        logger.info("Output directory () already exists and is not empty.")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file,
                                           do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_steps = None

    eval_dataloader, eval_examples, eval_features, _ = get_dataloader(
        logger=logger,
        args=args,
        input_file=args.predict_file,
        subqueries_file=args.predict_subqueries_file,
        is_training=False,
        batch_size=args.predict_batch_size,
        num_epochs=1,
        tokenizer=tokenizer)
    if args.do_train:
        train_dataloader, train_examples, _, num_train_steps = get_dataloader(
                logger=logger, args=args, \
                input_file=args.train_file, \
                subqueries_file=args.train_subqueries_file, \
                is_training=True,
                batch_size=args.train_batch_size,
                num_epochs=args.num_train_epochs,
                tokenizer=tokenizer)

    #a = input()
    if args.model == 'qa':
        model = BertForQuestionAnswering(bert_config, 4)
        metric_name = "F1"
    elif args.model == 'classifier':
        if args.reduce_layers != -1:
            bert_config.num_hidden_layers = args.reduce_layers
        model = BertClassifier(bert_config, 2, args.pooling)
        metric_name = "F1"
    elif args.model == "span-predictor":
        if args.reduce_layers != -1:
            bert_config.num_hidden_layers = args.reduce_layers
        if args.with_key:
            Model = BertForQuestionAnsweringWithKeyword
        else:
            Model = BertForQuestionAnswering
        model = Model(bert_config, 2)
        metric_name = "Accuracy"
    else:
        raise NotImplementedError()

    if args.init_checkpoint is not None and args.do_predict and \
                len(args.init_checkpoint.split(','))>1:
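        # Several comma-separated checkpoints were supplied for prediction:
        # build one BertForQuestionAnswering per checkpoint and hand the whole
        # list of models to `predict`.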
        assert args.model == "qa"
        model = [model]
        for i, checkpoint in enumerate(args.init_checkpoint.split(',')):
            if i > 0:
                model.append(BertForQuestionAnswering(bert_config, 4))
            print("Loading from", checkpoint)
            state_dict = torch.load(checkpoint, map_location='cpu')
            filter = lambda x: x[7:] if x.startswith('module.') else x
            state_dict = {filter(k): v for (k, v) in state_dict.items()}
            model[-1].load_state_dict(state_dict)
            model[-1].to(device)

    else:
        if args.init_checkpoint is not None:
            print("Loading from", args.init_checkpoint)
            state_dict = torch.load(args.init_checkpoint, map_location='cpu')
            if args.reduce_layers != -1:
                state_dict = {k:v for k, v in state_dict.items() \
                    if not '.'.join(k.split('.')[:3]) in \
                    ['encoder.layer.{}'.format(i) for i in range(args.reduce_layers, 12)]}
            if args.do_predict:
                filter = lambda x: x[7:] if x.startswith('module.') else x
                state_dict = {filter(k): v for (k, v) in state_dict.items()}
                model.load_state_dict(state_dict)
            else:
                model.bert.load_state_dict(state_dict)
                if args.reduce_layers_to_tune != -1:
                    # Freeze the embeddings and all but the top
                    # `reduce_layers_to_tune` encoder layers.
                    for param in model.bert.embeddings.parameters():
                        param.requires_grad = False
                    n_layers = 12 if args.reduce_layers == -1 else args.reduce_layers
                    for i in range(n_layers - args.reduce_layers_to_tune):
                        for param in model.bert.encoder.layer[i].parameters():
                            param.requires_grad = False

        model.to(device)

        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)
        elif n_gpu > 1:
            model = torch.nn.DataParallel(model)

    if args.do_train:
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate': 0.01
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate': 0.0
        }]

        optimizer = BERTAdam(optimizer_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps)

        global_step = 0

        best_f1 = 0
        wait_step = 0
        model.train()
        global_step = 0
        stop_training = False

        for epoch in range(int(args.num_train_epochs)):
            tr_loss = 0
            tqdm_bar = tqdm(train_dataloader,
                            desc="Training",
                            disable=args.local_rank not in [-1, 0])
            for step, batch in enumerate(tqdm_bar):
                global_step += 1
                batch = [t.to(device) for t in batch]
                loss = model(batch, global_step)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                tr_loss += loss.item()
                if global_step % args.gradient_accumulation_steps == 0:
                    optimizer.step()  # We have accumulated enough gradients
                    model.zero_grad()
                    tqdm_bar.desc = "Epoch training loss: {:.2e} lr: {:.2e}".format(
                        tr_loss / (step + 1),
                        optimizer.get_lr()[0])
                if global_step % args.eval_period == 0:
                    model.eval()
                    f1 =  predict(args, model, eval_dataloader, eval_examples, eval_features, \
                                  device, write_prediction=False)
                    logger.info("%s: %.3f on epoch=%d" %
                                (metric_name, f1 * 100.0, epoch))
                    if best_f1 < f1:
                        logger.info("Saving model with best %s: %.3f -> %.3f on epoch=%d" % \
                                (metric_name, best_f1*100.0, f1*100.0, epoch))
                        model_state_dict = {
                            k: v.cpu()
                            for (k, v) in model.state_dict().items()
                        }
                        torch.save(
                            model_state_dict,
                            os.path.join(args.output_dir, "best-model.pt"))
                        model = model.to(device)
                        best_f1 = f1
                        wait_step = 0
                        stop_training = False
                    else:
                        wait_step += 1
                        if best_f1 > 0.1 and wait_step == args.wait_step:
                            stop_training = True
                    model.train()
            logger.info("Training loss %.5f (epoch=%d)" % (tr_loss /
                                                           (step + 1), epoch))
            logger.info("Best %s: %.3f up to epoch=%d" %
                        (metric_name, best_f1 * 100.0, epoch))
            if stop_training:
                break

    elif args.do_predict:
        if type(model) == list:
            model = [m.eval() for m in model]
        else:
            model.eval()
        f1 = predict(args, model, eval_dataloader, eval_examples,
                     eval_features, device)
        logger.info("Final %s score: %.3f%%" % (metric_name, f1 * 100.0))
Example #3
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help=
        "The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--vocab_file",
        default=None,
        type=str,
        required=True,
        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--init_checkpoint",
        default=None,
        type=str,
        help="Initial checkpoint (usually from a pre-trained BERT model).")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--save_checkpoints_steps",
                        default=1000,
                        type=int,
                        help="How often to save the model checkpoint.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--accumulate_gradients",
        type=int,
        default=1,
        help=
        "Number of steps to accumulate gradient on (divide the batch_size and accumulate)"
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    args = parser.parse_args()

    processors = {
        "ag": AGNewsProcessor,
        "ag_sep": AGNewsProcessor_sep,
        "ag_sep_aug": AGNewsProcessor_sep_aug,
        "imdb": IMDBProcessor,
        "imdb_sep": IMDBProcessor_sep,
        "imdb_sep_aug": IMDBProcessor_sep_aug,
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    if args.accumulate_gradients < 1:
        raise ValueError(
            "Invalid accumulate_gradients parameter: {}, should be >= 1".
            format(args.accumulate_gradients))

    args.train_batch_size = int(args.train_batch_size /
                                args.accumulate_gradients)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    bert_config = BertConfig.from_json_file(args.bert_config_file)

    if args.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}"
            .format(args.max_seq_length, bert_config.max_position_embeddings))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file,
                                           do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size *
            args.num_train_epochs)

    model = BertForSequenceClassification(bert_config, len(label_list))
    if args.init_checkpoint is not None:
        model.bert.load_state_dict(
            torch.load(args.init_checkpoint, map_location='cpu'))
    model.to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    no_decay = ['bias', 'gamma', 'beta']

    def is_no_decay(name):
        # Biases and the LayerNorm gamma/beta parameters are kept out of weight decay.
        return any(nd in name for nd in no_decay)

    def encoder_layer_index(name):
        # Return the encoder layer index embedded in a parameter name, or None for
        # parameters outside the encoder stack. Works both with and without the
        # "module." prefix added by (Distributed)DataParallel.
        marker = "encoder.layer."
        if marker not in name:
            return None
        return int(name.split(marker)[1].split(".")[0])

    # Discriminative fine-tuning: encoder layer i is trained with
    # learning_rate / 2 ** (11 - i), so lower layers are updated more slowly;
    # parameters outside the encoder stack use the default learning rate.
    optimizer_parameters = []
    for in_no_decay, decay_rate in [(False, 0.01), (True, 0.0)]:
        optimizer_parameters.append({
            'params': [
                p for n, p in model.named_parameters()
                if encoder_layer_index(n) is None
                and is_no_decay(n) == in_no_decay
            ],
            'weight_decay_rate': decay_rate
        })
        for layer in range(12):
            optimizer_parameters.append({
                'params': [
                    p for n, p in model.named_parameters()
                    if encoder_layer_index(n) == layer
                    and is_no_decay(n) == in_no_decay
                ],
                'lr': args.learning_rate / 2**(11 - layer),
                'weight_decay_rate': decay_rate
            })

    optimizer = BERTAdam(optimizer_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)

    print("init=", optimizer.get_lr())
    global_step = 0
    eval_examples = processor.get_dev_examples(args.data_dir)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length,
                                                 tokenizer)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                 dtype=torch.long)

    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                              all_label_ids)
    eval_dataloader = DataLoader(eval_data,
                                 batch_size=args.eval_batch_size,
                                 shuffle=False)

    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        epoch = 0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            epoch += 1
            model.train()
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()  # We have accumulated enough gradients
                    #     print("middle=",optimizer.get_lr())
                    #     print("len(middle)=",len(optimizer.get_lr()))
                    model.zero_grad()
                    global_step += 1

            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            with open(
                    os.path.join(args.output_dir,
                                 "results_ep" + str(epoch) + ".txt"),
                    "w") as f:
                for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    segment_ids = segment_ids.to(device)
                    label_ids = label_ids.to(device)

                    with torch.no_grad():
                        tmp_eval_loss, logits = model(input_ids, segment_ids,
                                                      input_mask, label_ids)

                    logits = logits.detach().cpu().numpy()
                    label_ids = label_ids.to('cpu').numpy()
                    outputs = np.argmax(logits, axis=1)
                    for output in outputs:
                        f.write(str(output) + "\n")
                    tmp_eval_accuracy = np.sum(outputs == label_ids)

                    eval_loss += tmp_eval_loss.mean().item()
                    eval_accuracy += tmp_eval_accuracy

                    nb_eval_examples += input_ids.size(0)
                    nb_eval_steps += 1

            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = eval_accuracy / nb_eval_examples

            result = {
                'eval_loss': eval_loss,
                'eval_accuracy': eval_accuracy,
                'global_step': global_step,
                'loss': tr_loss / nb_tr_steps
            }

            output_eval_file = os.path.join(
                args.output_dir, "eval_results_ep" + str(epoch) + ".txt")
            print("output_eval_file=", output_eval_file)
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))