def test(
    args,
    new_dirs=None,
    dev_as_test=None
):  # Load a fine-tuned model and evaluate it on the test split (inference runs on the GPU here)
    processor = data_utils.AscProcessor()
    label_list = processor.get_labels()
    tokenizer = BertTokenizer.from_pretrained(
        modelconfig.MODEL_ARCHIVE_MAP[args.bert_model])
    if dev_as_test:
        data_dir = os.path.join(args.data_dir, 'dev_as_test')
    else:
        data_dir = args.data_dir
    eval_examples = processor.get_test_examples(data_dir)
    eval_features = data_utils.convert_examples_to_features(
        eval_examples, label_list, args.max_seq_length, tokenizer, "asc")

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                 dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                   dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                  dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                 dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_segment_ids, all_input_mask,
                              all_label_ids)
    # Run prediction for full data
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    model = torch.load(os.path.join(new_dirs, "model.pt"))
    model.cuda()
    model.eval()

    full_logits = []
    full_label_ids = []
    for step, batch in enumerate(eval_dataloader):
        batch = tuple(t.cuda() for t in batch)
        input_ids, segment_ids, input_mask, label_ids = batch

        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)

        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.cpu().numpy()

        full_logits.extend(logits.tolist())
        full_label_ids.extend(label_ids.tolist())

    output_eval_json = os.path.join(new_dirs, "predictions.json")
    with open(output_eval_json, "w") as fw:
        json.dump({"logits": full_logits, "label_ids": full_label_ids}, fw)
Example #2
def evaluate(model, result_file):
    with open(args.predict_example_files, 'rb') as f:
        eval_examples = pickle.load(f)

    with torch.no_grad():
        tokenizer = BertTokenizer.from_pretrained('../roberta_wwm_ext',
                                                  do_lower_case=True)
        model.eval()
        pred_answers, ref_answers = [], []

        for step, example in enumerate(tqdm(eval_examples)):
            start_probs, end_probs = [], []
            question_text = example['question_text']

            for p_num, doc_tokens in enumerate(
                    example['doc_tokens'][:args.max_para_num]):
                (input_ids, input_ids_q, input_mask, input_mask_q, segment_ids) = \
                    predict_data.predict_data(question_text, doc_tokens['doc_tokens'], tokenizer, args.max_seq_length,
                                              args.max_query_length)
                # input_ids = torch.tensor(input_ids).unsqueeze(0)
                # input_mask = torch.tensor(input_mask).unsqueeze(0)
                # segment_ids = torch.tensor(segment_ids).unsqueeze(0)
                start_prob, end_prob, can_logit = model(
                    input_ids,
                    input_ids_q,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                    attention_mask_q=input_mask_q)
                start_probs.append(start_prob.squeeze(0))
                end_probs.append(end_prob.squeeze(0))

            best_answer, docs_index = find_best_answer(example, start_probs,
                                                       end_probs)

            print(best_answer)
            pred_answers.append({
                'question_id': example['id'],
                'question': example['question_text'],
                # 'question_type': example['question_type'],
                'answers': [best_answer],
                'entity_answers': [[]],
                'yesno_answers': []
            })
            if 'answers' in example:
                ref_answers.append({
                    'question_id': example['id'],
                    # 'question_type': example['question_type'],
                    'answers': example['answers'],
                    'entity_answers': [[]],
                    'yesno_answers': []
                })
        with open(result_file, 'w', encoding='utf-8') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        with open("../metric/ref_dev.json", 'w', encoding='utf-8') as fout:
            for pred_answer in ref_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
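
# Hypothetical usage sketch for evaluate() above. It assumes the module-level
# `args` already provides predict_example_files, max_para_num, max_seq_length,
# and max_query_length; the checkpoint and output paths are illustrative only.
reader_model = torch.load("output/best_reader.pt")
evaluate(reader_model, result_file="output/dev_predictions.json")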
def predict_single_sentence(model_weight_path, text):
    # initialize basic options
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    output_mode = 'classification'
    processor = PuriProcessor()
    label_list = processor.get_labels()
    num_labels = len(label_list)

    bert_model = 'bert-base-multilingual-uncased'
    tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)

    # convert text to feature
    example = processor.create_single_example(text)
    feature = convert_single_example_to_feature(example, 128, tokenizer,
                                                output_mode)

    # create_model
    model = BertForSequenceClassification.from_pretrained(
        model_weight_path, num_labels=num_labels)

    model.eval()

    input_ids = torch.tensor(feature.input_ids, dtype=torch.long).unsqueeze(0)
    input_mask = torch.tensor(feature.input_mask,
                              dtype=torch.long).unsqueeze(0)
    segment_ids = torch.tensor(feature.segment_ids,
                               dtype=torch.long).unsqueeze(0)
    #     label_ids  = feature.label_ids

    with torch.no_grad():
        logits = model(input_ids, segment_ids, input_mask, labels=None)
    preds = logits.detach().cpu().numpy()
    result = np.argmax(preds, axis=1)

    return preds, result
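
# Hypothetical usage sketch for predict_single_sentence() above. The checkpoint
# directory and input sentence are illustrative assumptions.
scores, pred_ids = predict_single_sentence("out/puri_bert",
                                            "this is a test sentence")
print(scores, pred_ids)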
Example #4
def infer(input, target_start_id, target_end_id, args):

    sent = input.split(" ")
    assert 0 <= target_start_id < target_end_id <= len(sent)
    target = " ".join(sent[target_start_id:target_end_id])

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    label_list = ["0", "1"]
    num_labels = len(label_list)
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=True)
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model, num_labels=num_labels)
    model.to(device)

    print(f"input: {input}\ntarget: {target}")
    examples = construct_context_gloss_pairs_through_nltk(
        input, target_start_id, target_end_id)
    eval_features, candidate_results = convert_to_features(examples, tokenizer)
    input_ids = torch.tensor([f.input_ids for f in eval_features],
                             dtype=torch.long)
    input_mask = torch.tensor([f.input_mask for f in eval_features],
                              dtype=torch.long)
    segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                               dtype=torch.long)

    model.eval()
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    with torch.no_grad():
        logits = model(input_ids=input_ids,
                       token_type_ids=segment_ids,
                       attention_mask=input_mask,
                       labels=None)
    logits_ = F.softmax(logits, dim=-1)
    logits_ = logits_.detach().cpu().numpy()
    output = np.argmax(logits_, axis=0)[1]
    print(f"results:\ngloss: {candidate_results[output][1]}")
Example #5
def main(args):
    args.data_dir = os.path.join(args.data_dir, args.task_name)
    args.output_dir = os.path.join(args.output_dir, args.task_name)
    logger.info("args = %s", args)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor,
        "emo": EmoProcessor,
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification",
        "emo": "classification"
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)

        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

        # device = torch.device('cpu')

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        logger.info("Output directory already exists and is not empty.")
    if not os.path.exists(args.output_dir):
        try:
            os.makedirs(args.output_dir)
        except OSError:
            logger.info("Could not create output directory %s", args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # tokenizer = BertTokenizer.from_pretrained(args.vocab_file, do_lower_case=args.do_lower_case)
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(
            args.local_rank))

    # use bert to aug train_examples
    ori_train_examples = processor.get_train_examples(args.data_dir)
    eval_examples = processor.get_dev_examples(args.data_dir)
    test_examples = processor.get_test_examples(args.data_dir)

    num_train_optimization_steps = int(
        len(ori_train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs

    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    if args.use_saved == 1:
        bert_saved_dir = args.ckpt
        model = BertForNSPAug.from_pretrained(bert_saved_dir,
                                              cache_dir=args.ckpt_cache_dir,
                                              num_labels=num_labels,
                                              args=args)
    else:
        model = BertForNSPAug.from_pretrained(args.bert_model,
                                              cache_dir=cache_dir,
                                              num_labels=num_labels,
                                              args=args)
    model.cuda()
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.do_train:
        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

        global_step = 0
        best_val_acc = 0.0
        first_time = time.time()

        logger.info(
            "*********************************** Running training ***********************************"
        )
        logger.info("  Num original examples = %d", len(ori_train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        model.train()
        aug_ratio = 0.0
        # aug_ratio = 0.2
        aug_seed = np.random.randint(0, 1000)
        for epoch in range(int(args.num_train_epochs)):
            logger.info("epoch=%d,  aug_ratio = %f,  aug_seed=%d", epoch,
                        aug_ratio, aug_seed)
            train_examples = Aug_each_ckpt(ori_train_examples,
                                           label_list,
                                           model,
                                           tokenizer,
                                           args=args,
                                           num_show=args.num_show,
                                           output_mode=output_mode,
                                           seed=aug_seed,
                                           aug_ratio=aug_ratio,
                                           use_bert=False)
            if aug_ratio + args.aug_ratio_each < 1.0:
                aug_ratio += args.aug_ratio_each
            aug_seed += 1

            train_features = convert_examples_to_features(
                train_examples,
                label_list,
                args.max_seq_length,
                tokenizer,
                num_show=args.num_show,
                output_mode=output_mode,
                args=args)
            logger.info(
                "*********************************** Done convert features ***********************************"
            )
            all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                         dtype=torch.long)
            all_input_mask = torch.tensor(
                [f.input_mask for f in train_features], dtype=torch.long)
            all_segment_ids = torch.tensor(
                [f.segment_ids for f in train_features], dtype=torch.long)
            if output_mode == "classification":
                all_label_ids = torch.tensor(
                    [f.label_id for f in train_features], dtype=torch.long)
            elif output_mode == "regression":
                all_label_ids = torch.tensor(
                    [f.label_id for f in train_features], dtype=torch.float)

            token_real_label = torch.tensor(
                [f.token_real_label for f in train_features], dtype=torch.long)
            train_data = TensorDataset(all_input_ids, all_input_mask,
                                       all_segment_ids, all_label_ids,
                                       token_real_label)
            if args.local_rank == -1:
                train_sampler = RandomSampler(train_data)
            else:
                train_sampler = DistributedSampler(train_data)
            train_dataloader = DataLoader(train_data,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size)

            logger.info(
                "*********************************** begin training ***********************************"
            )
            tr_loss, tr_seq_loss, tr_aug_loss, train_seq_accuracy, train_aug_accuracy = 0, 0, 0, 0, 0
            nb_tr_examples, nb_tr_steps, nb_tr_tokens = 0, 0, 0
            preds = []
            all_labels = []
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.cuda() for t in batch)
                input_ids, input_mask, segment_ids, label_ids, token_real_label = batch
                seq_logits, aug_logits, aug_loss = model(
                    input_ids,
                    segment_ids,
                    input_mask,
                    labels=None,
                    token_real_label=token_real_label)
                if output_mode == "classification":
                    # if task_name == "emo":
                    #     loss_fct =
                    # else:
                    loss_fct = CrossEntropyLoss()
                    seq_loss = loss_fct(seq_logits.view(-1, num_labels),
                                        label_ids.view(-1))
                    # print("[classification]label_ids: {}, size: {}".format(label_ids.view(-1), label_ids.view(-1).size()))
                    # print("[classification]seq_logits size: {}".format(seq_logits.view(-1, num_labels).size()))
                elif output_mode == "regression":
                    loss_fct = MSELoss()
                    seq_loss = loss_fct(seq_logits.view(-1),
                                        label_ids.view(-1))

                token_real_label = token_real_label.detach().cpu().numpy()

                w = args.aug_loss_weight
                loss = (1 - w) * seq_loss + w * aug_loss

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), 10000.0)

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                batch_loss = seq_loss.mean().item()
                tr_seq_loss += seq_loss.mean().item()
                seq_logits = seq_logits.detach().cpu().numpy()
                label_ids = label_ids.detach().cpu().numpy()
                if len(preds) == 0:
                    preds.append(seq_logits)
                    all_labels.append(label_ids)
                else:
                    preds[0] = np.append(preds[0], seq_logits, axis=0)
                    all_labels[0] = np.append(all_labels[0], label_ids, axis=0)

                aug_logits = aug_logits.detach().cpu().numpy()
                tmp_train_aug_accuracy, tmp_tokens = accuracy(aug_logits,
                                                              token_real_label,
                                                              type="aug")
                train_aug_accuracy += tmp_train_aug_accuracy
                nb_tr_tokens += tmp_tokens
                tr_aug_loss += aug_loss.mean().item()

                if global_step % 20 == 0:
                    loss = tr_loss / nb_tr_steps
                    seq_loss = tr_seq_loss / nb_tr_steps
                    aug_loss = tr_aug_loss / nb_tr_steps
                    tmp_pred = preds[0]
                    tmp_labels = all_labels[0]
                    if output_mode == "classification":
                        tmp_pred = np.argmax(tmp_pred, axis=1)
                    elif output_mode == "regression":
                        tmp_pred = np.squeeze(tmp_pred)
                    res = accuracy(tmp_pred, tmp_labels, task_name=task_name)

                    if nb_tr_tokens != 0:
                        aug_avg = train_aug_accuracy / nb_tr_tokens
                    else:
                        aug_avg = 0.0
                    log_string = ""
                    log_string += "epoch={:<5d}".format(epoch)
                    log_string += " step={:<9d}".format(global_step)
                    log_string += " total_loss={:<9.7f}".format(loss)
                    log_string += " seq_loss={:<9.7f}".format(seq_loss)
                    log_string += " aug_loss={:<9.7f}".format(aug_loss)
                    log_string += " batch_loss={:<9.7f}".format(batch_loss)
                    log_string += " lr={:<9.7f}".format(optimizer.get_lr()[0])
                    log_string += " |g|={:<9.7f}".format(total_norm)
                    #log_string += " tr_seq_acc={:<9.7f}".format(seq_avg)
                    log_string += " tr_aug_acc={:<9.7f}".format(aug_avg)
                    log_string += " mins={:<9.2f}".format(
                        float(time.time() - first_time) / 60)
                    for key in sorted(res.keys()):
                        log_string += "  " + key + "= " + str(res[key])
                    logger.info(log_string)

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            train_loss = tr_loss / nb_tr_steps

            logger.info(
                "*********************************** training epoch done ***********************************"
            )

            if args.do_eval and (args.local_rank == -1
                                 or torch.distributed.get_rank()
                                 == 0) and epoch % 1 == 0:
                tot_time = float(time.time() - first_time) / 60
                eval_loss, eval_seq_loss, eval_aug_loss, eval_res, eval_aug_accuracy, res_parts = \
                    do_evaluate(args, processor, label_list, tokenizer, model,
                                epoch, output_mode, num_labels, task_name,
                                eval_examples, type="dev")

                eval_res["tot_time"] = tot_time
                if "acc" in eval_res:
                    tmp_acc = eval_res["acc"]
                elif "mcc" in eval_res:
                    tmp_acc = eval_res["mcc"]
                else:
                    tmp_acc = eval_res["corr"]

                result = {
                    'eval_total_loss': eval_loss,
                    'eval_seq_loss': eval_seq_loss,
                    'eval_aug_loss': eval_aug_loss,
                    'eval_aug_accuracy': eval_aug_accuracy,
                    'global_step': global_step,
                    'train_loss': train_loss,
                    'train_batch_size': args.train_batch_size,
                    'args': args
                }

                if tmp_acc >= best_val_acc:
                    best_val_acc = tmp_acc
                    dev_test = "dev"
                    result.update({'best_epoch': epoch})

                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model itself
                    output_model_dir = os.path.join(args.output_dir,
                                                    "dev_" + str(tmp_acc))
                    if not os.path.exists(output_model_dir):
                        os.makedirs(output_model_dir)
                    output_model_file = os.path.join(output_model_dir,
                                                     WEIGHTS_NAME)
                    torch.save(model_to_save.state_dict(), output_model_file)
                    output_config_file = os.path.join(output_model_dir,
                                                      CONFIG_NAME)
                    with open(output_config_file, 'w') as f:
                        f.write(model_to_save.config.to_json_string())

                result.update(eval_res)
                result.update(res_parts)

                # output_eval_file = os.path.join(args.output_dir,
                # 								dev_test + "_results_" + str(tmp_acc) + ".txt")
                # with open(output_eval_file, "w") as writer:
                logger.info(
                    "****************************** eval results ***********************************"
                )
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    # writer.write("%s = %s\n" % (key, str(result[key])))
            else:
                # Evaluation was skipped this epoch, so only training statistics
                # are available (the eval_* values would otherwise be undefined).
                result = {
                    'global_step': global_step,
                    'train_loss': train_loss,
                    'train_batch_size': args.train_batch_size,
                    'args': args
                }
                logger.info(
                    "****************************** train results ***********************************"
                )
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))

            # write test results
            if args.do_test:
                # res_file = os.path.join(args.output_dir,
                # 							"test_" + str(tmp_acc)+".tsv")

                # idx, preds = do_test(args, label_list, task_name, processor, tokenizer, output_mode, model)

                # dataframe = pd.DataFrame({'index': range(idx), 'prediction': preds})
                # dataframe.to_csv(res_file, index=False, sep='\t')
                # logger.info("  Num test length = %d", idx)
                logger.info(
                    "*********************************** Running test ***********************************"
                )
                logger.info("  Num examples = %d", len(test_examples))
                logger.info("  Batch size = %d", args.eval_batch_size)

                test_loss, test_seq_loss, test_aug_loss, test_res, test_aug_accuracy, res_parts = \
                    do_evaluate(args, processor, label_list, tokenizer, model,
                                epoch, output_mode, num_labels, task_name,
                                test_examples, type="test")
                result = {
                    'test_total_loss': test_loss,
                    'test_seq_loss': test_seq_loss,
                    'test_aug_loss': test_aug_loss,
                    'test_aug_accuracy': test_aug_accuracy,
                    'global_step': global_step,
                    'args': args
                }
                result.update(test_res)
                result.update(res_parts)

                logger.info(
                    "****************************** test results ***********************************"
                )
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))

                logger.info(
                    "*********************************** test done ***********************************"
                )
Example #6
seed = 123
beam_size = 5
length_penalty = 0
forbid_ignore_word = None
max_tgt_length = 40
batch_size = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed)

tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=False)
data_tokenizer = WhitespaceTokenizer()
tokenizer.max_len = max_seq_length
pair_num_relation = 0
bi_uni_pipeline = []
bi_uni_pipeline.append(
    seq2seq_loader.Preprocess4Seq2seqDecoder(
        list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids,
        max_seq_length, max_tgt_length=max_tgt_length, mode="s2s",
        num_qkv=num_qkv, s2s_special_token=False, s2s_add_segment=False,
        s2s_share_segment=False, pos_shift=False))
from apex import amp
amp_handle = amp.init(enable_caching=True)
logger.info("enable fp16 with amp")

cls_num_labels = 2
type_vocab_size = 6
mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
    ["[MASK]", "[SEP]", "[S2S_SOS]"])
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--on_memory",
        action='store_true',
        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError(
            "`do_train` must be set: this script only implements training.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    #train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.train_file)
        train_dataset = BERTDataset(args.train_file,
                                    tokenizer,
                                    seq_len=args.max_seq_length,
                                    corpus_lines=None,
                                    on_memory=args.on_memory)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    model = BertForPreTraining.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            #TODO: check if this works with current data generator from disk that relies on next(file)
            # (it doesn't return item back by index)
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                loss = model(input_ids, segment_ids, input_mask, lm_label_ids,
                             is_next)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save a trained model
        logger.info("** ** * Saving fine - tuned model ** ** * ")
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        if args.do_train:
            torch.save(model_to_save.state_dict(), output_model_file)
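
# Hypothetical invocation sketch for the LM fine-tuning entry point above; the
# script name and paths are illustrative assumptions.
#
#   python simple_lm_finetuning.py \
#       --train_file data/corpus.txt \
#       --bert_model bert-base-uncased \
#       --output_dir out/lm_finetuned \
#       --do_train --do_lower_case --on_memory \
#       --train_batch_size 32 --num_train_epochs 3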
Example #8
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--load_finetuned_model',
                        action='store_true',
                        default=False,
                        help="Load finetuned model.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "compq": COMPQProcessor,
    }

    output_modes = {
        "compq": "classification",
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    if args.load_finetuned_model:
        print("Loading finetuned model....")
        model = BertForSequenceClassification.from_pretrained(
            args.output_dir, num_labels=num_labels)
        tokenizer = BertTokenizer.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)

    else:
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, num_labels=num_labels)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0

    if args.do_train:
        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)
            warmup_linear = WarmupLinearSchedule(
                warmup=args.warmup_proportion,
                t_total=num_train_optimization_steps)

        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, output_mode)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.float)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # define a new function to compute loss values for both output_modes
                logits = model(input_ids, segment_ids, input_mask, labels=None)

                if output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, num_labels),
                                    label_ids.view(-1))
                elif output_mode == "regression":
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(
                            global_step, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer, output_mode)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.float)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss = 0
        nb_eval_steps = 0
        preds = []
        softmax_preds = []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask, labels=None)

            # create eval loss and other metric required by the task
            if output_mode == "classification":
                loss_fct = CrossEntropyLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels),
                                         label_ids.view(-1))
            elif output_mode == "regression":
                loss_fct = MSELoss()
                tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))

            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if len(preds) == 0:
                preds.append(logits.detach().cpu().numpy())
                softmax_preds.append(Softmax(1)(logits).detach().cpu().numpy())
            else:
                preds[0] = np.append(preds[0],
                                     logits.detach().cpu().numpy(),
                                     axis=0)
                softmax_preds[0] = np.append(
                    softmax_preds[0],
                    Softmax(1)(logits).detach().cpu().numpy(),
                    axis=0)

        eval_loss = eval_loss / nb_eval_steps
        preds = preds[0]
        softmax_preds = softmax_preds[0]
        output_prediction_file = os.path.join(args.output_dir,
                                              "predictions.txt")
        with open(output_prediction_file, 'w') as writer:
            for i, pred in enumerate(softmax_preds):
                writer.write(
                    str(pred[0]) + '\t' + str(pred[1]) + '\t' +
                    eval_examples[i].text_a + '\n')

        if output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(task_name, preds, all_label_ids.numpy())
        loss = tr_loss / nb_tr_steps if args.do_train else None

        result['eval_loss'] = eval_loss
        result['global_step'] = global_step
        result['loss'] = loss

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
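        # compute_metrics above is assumed to follow the GLUE-style contract,
        # mapping (task_name, preds, labels) to a dict of scalar metrics; a
        # minimal sketch for a plain accuracy metric:
        #
        # def compute_metrics(task_name, preds, labels):
        #     return {"acc": (preds == labels).mean()}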
Example #9
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .csv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )
    parser.add_argument("--init_checkpoint",
                        default=None,
                        type=str,
                        required=True,
                        help="The checkpoint file from pretraining")

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps",
                        default=-1,
                        type=int,
                        help="Total number of training steps to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        print(
            "WARNING: Output directory ({}) already exists and is not empty.".
            format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = read_swag_examples(os.path.join(
            args.data_dir, 'train.csv'),
                                            is_training=True)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    model = BertForMultipleChoice.from_pretrained(
        args.bert_model,
        cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE,
                               'distributed_{}'.format(args.local_rank)),
        num_choices=4)
    model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'),
                          strict=False)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # hack to remove pooler, which is not used
    # thus it produces None grads that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      tokenizer,
                                                      args.max_seq_length,
                                                      True)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor(select_field(train_features, 'input_ids'),
                                     dtype=torch.long)
        all_input_mask = torch.tensor(select_field(train_features,
                                                   'input_mask'),
                                      dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(train_features,
                                                    'segment_ids'),
                                       dtype=torch.long)
        all_label = torch.tensor([f.label for f in train_features],
                                 dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
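
        # select_field is assumed to gather one field across the per-choice
        # features built for SWAG, yielding a nested
        # [num_examples, num_choices, seq_len] list (a sketch):
        #
        # def select_field(features, field):
        #     return [[choice[field] for choice in f.choices_features]
        #             for f in features]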

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                # Terminate early for benchmarking
                if args.max_steps > 0 and global_step > args.max_steps:
                    break

                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.fp16 and args.loss_scale != 1.0:
                    # rescale loss for fp16 training
                    # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
                    loss = loss * args.loss_scale
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used, which handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

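    # warmup_linear here is assumed to be the plain schedule function from
    # older pytorch-pretrained-bert releases, roughly (a sketch):
    #
    # def warmup_linear(x, warmup=0.002):
    #     return x / warmup if x < warmup else 1.0 - x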
    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForMultipleChoice(config, num_choices=4)
        # noinspection PyUnresolvedReferences
        model.load_state_dict(torch.load(output_model_file))
    else:
        model = BertForMultipleChoice.from_pretrained(args.bert_model,
                                                      num_choices=4)
        model.load_state_dict(torch.load(args.init_checkpoint,
                                         map_location='cpu'),
                              strict=False)
    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = read_swag_examples(os.path.join(
            args.data_dir, 'val.csv'),
                                           is_training=True)
        eval_features = convert_examples_to_features(eval_examples, tokenizer,
                                                     args.max_seq_length, True)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'),
                                     dtype=torch.long)
        all_input_mask = torch.tensor(select_field(eval_features,
                                                   'input_mask'),
                                      dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(eval_features,
                                                    'segment_ids'),
                                       dtype=torch.long)
        all_label = torch.tensor([f.label for f in eval_features],
                                 dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples

        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': tr_loss / nb_tr_steps if args.do_train else None
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
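        # accuracy() above is assumed to be the usual run_swag helper that
        # counts correct argmax predictions in a batch (a sketch):
        #
        # def accuracy(out, labels):
        #     return np.sum(np.argmax(out, axis=1) == labels)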
Example #10
import sys
import random
import re
from metrics import bleu_metric
import numpy as np
import nltk

from tokenization import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("unilm_v2_bert_pretrain",
                                          do_lower_case=True)

from nltk.corpus import stopwords
stop_words = set(stopwords.words('english'))


def move_stop_words(text):
    item = " ".join([w for w in text.split() if w.lower() not in stop_words])
    return item


re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re_art.sub(' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        return re_punc.sub(' ', text)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
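
# Quick sanity check of the normalizer: articles, punctuation and extra
# whitespace are all stripped.
assert normalize_answer("The quick, brown fox!") == "quick brown fox"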
Example #11
def demo(cfg):
    if not os.path.exists(cfg.ckpt):
        print('Invalid ckpt path:', cfg.ckpt)
        exit(1)
    ckpt = torch.load(cfg.ckpt, map_location=lambda storage, loc: storage)
    print(cfg.ckpt, 'loaded')
    loaded_cfg = ckpt['cfg'].__dict__

    del loaded_cfg['test_set']
    del loaded_cfg['use_pretrain']
    del loaded_cfg['num_workers']
    del loaded_cfg['num_episodes']
    del loaded_cfg['memory_num']
    del loaded_cfg['memory_len']

    cfg.__dict__.update(loaded_cfg)
    cfg.model = cfg.model.upper()
    pprint(cfg.__dict__)

    model = create_a3c_model(cfg)
    model.load_state_dict(ckpt['model'])
    model.cuda()

    tokenizer = BertTokenizer.from_pretrained(cfg.bert_model)

    env = Environment(cfg, cfg.test_set, tokenizer, shuffle=True)
    env.set_model(model)
    env.set_gpu_id(torch.cuda.current_device())
    print(env.dataset.path, 'loaded')
    while True:
        model.eval()
        env.reset()
        print('-' * 80)
        print('Data ID:', env.data_idx)
        print()
        print('[Context]')
        print(' '.join(tokenizer.convert_ids_to_tokens(env.data.ctx_words)))
        print()

        ques = ' '.join(tokenizer.convert_ids_to_tokens(env.data.ques_words))
        ques = ques.replace(' ##', '')
        ques = ques.replace('##', '')
        print('[Question]')
        print(ques)
        print()
        answs = []
        indices = list(set(env.data.indices))
        for indice in indices:
            s_idx = indice[0]
            e_idx = indice[1]
            answ = ' '.join(
                tokenizer.convert_ids_to_tokens(
                    env.data.ctx_words[s_idx:e_idx + 1]))
            answ = answ.replace(' ##', '')
            answ = answ.replace('##', '')
            answs.append(answ)
        print('[Answer]')
        for i, answ in enumerate(answs):
            print('%d.' % (i + 1), answ)
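
        # The paired replace(' ##', '') / replace('##', '') calls above form a
        # minimal WordPiece detokenizer; factored out it would look like this
        # (a sketch):
        #
        # def wordpiece_detok(tokens):
        #     return ' '.join(tokens).replace(' ##', '').replace('##', '')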

        input('\nPress enter to continue\n')
        while not env.is_done():
            if len(env.memory) < cfg.memory_num - 1:
                env._append_current()
                env.sent_ptr += 1
            else:
                if cfg.model == 'LIFO':
                    break
                env._append_current()
                env.sent_ptr += 1

                batch, solvable, mem_solvable = env.observe()
                batch = {k: v.cuda() for k, v in batch.items()}

                result = model.mem_forward(**batch)
                logit, value = result['logit'], result['value']

                prob = F.softmax(logit, 1)
                _, action = prob.max(1, keepdim=True)

                env.step(action=action.item(), **result)

                _print_mem_result(cfg, tokenizer, batch, prob, action, result,
                                  solvable, mem_solvable, answs)
                # input()

        assert (len(env.memory) <= cfg.memory_num)
        result = _qa_forward(env, model)
        batch = result['batch']
        _print_qa_result(cfg, tokenizer, batch, result, answs, env)
        input()
Example #12
def main():
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument('--task',
                        default='multi',
                        type=str,
                        help='Task affecting load data and vectorize feature')
    parser.add_argument(
        '--loss_type',
        default='double',
        type=str,
        help='Select loss double or single, only for multi task'
    )  # only effective for the multi task
    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-uncased,bert-large-uncased, "
        "bert-base-cased, bert-large-cased, bert-base-multilingual-uncased,bert-base-chinese,"
        "bert-base-multilingual-cased.")  # 选择预训练模型参数
    parser.add_argument("--debug",
                        default=False,
                        help="Whether run on small dataset")  # 正常情况下都应该选择false
    parser.add_argument(
        "--output_dir",
        default="./SQuAD/output/",
        type=str,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    ## Other parameters
    parser.add_argument("--train_file",
                        default="./SQuAD/version/train.json",
                        type=str,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument(
        "--predict_file",
        default="./SQuAD/version/prediction.json",
        type=str,
        help=
        "SQuAD json for predictio ns. E.g., dev-v1.1.json or test-v1.1.json")

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will be "
        "truncated to this length.")

    ## Control parameters
    parser.add_argument("--do_train",
                        default=True,
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=True,
                        help="Whether to run eval on the dev set.")

    parser.add_argument("--train_batch_size",
                        default=18,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=18,
                        type=int,
                        help="Total batch size for predictions.")

    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json file."
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated.This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        default=False,
        help=
        "If true, all of the warnings related to data processing will be printed.A number of "
        "warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        default=False,
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.Positive power of 2: static loss scaling value.\n"
    )
    parser.add_argument(
        '--version_2_with_negative',
        default=False,
        help=
        'If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument(
        '--null_score_diff_threshold',
        type=float,
        default=0.0,
        help=
        "If null_score - best_non_null is greater than the threshold predict null."
    )
    args = parser.parse_args()

    # The if branch is the single-machine setup and the else branch the distributed one; since we have no distributed cluster, we train on a single machine with multiple GPUs (10.24)
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    # The next few lines are routine logging setup (10.24)
    logging.basicConfig(
        format='%(asctime)s-%(levelname)s-%(name)s-%(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info(
        "device:{}, n_gpu:{}, distributed training:{}, 16-bits training:{}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # The following lines configure the batch size and the random seeds (10.24)
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
    random.seed(args.seed)  # set the Python random seed
    np.random.seed(args.seed)  # set the NumPy random seed
    torch.manual_seed(args.seed)  # seed the CPU RNG so that results are deterministic
    if n_gpu > 0:  # with multiple GPUs, seed all of them via torch.cuda.manual_seed_all()
        torch.cuda.manual_seed_all(args.seed)

    # More routine argument validation (10.24)
    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")
    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    # Create output_dir if it does not exist; the stricter check below is disabled because it requires an empty directory (10.24)
    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
    #     raise ValueError("Output directory () already exists and is not empty.")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Initialize the tokenizer from the tokenization module (10.24)
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    # How the data is read: single-task SQuAD reading vs. multi-task reading (10.24)
    if args.task == 'squad':
        read_examples = read_squad_examples
    elif args.task == 'multi':
        read_examples = read_multi_examples
    else:
        raise ValueError('Unsupported task: {}'.format(args.task))

    # Load the training examples and compute the number of optimization steps (10.24)
    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = read_examples(
            input_file=args.train_file,
            is_training=True,
            version_2_with_negative=args.version_2_with_negative)
        if args.debug:
            train_examples = train_examples[:100]
        num_train_optimization_steps = \
            int(len(train_examples)/args.train_batch_size/args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )
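
        # Worked example with hypothetical numbers: 90,000 training examples,
        # train_batch_size=18 and gradient_accumulation_steps=1 give
        # int(90000 / 18 / 1) = 5000 steps per epoch, so num_train_epochs=3
        # yields 15000 optimizer updates (divided by world size when distributed).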

    # Prepare the model (10.24)
    model = BertForQuestionAnswering.from_pretrained(
        args.bert_model,
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                               'distributed_{}'.format(args.local_rank)))

    # model = torch.nn.DataParallel(model).cuda()
    # Decide whether to use float16 precision (10.24)
    if args.fp16:
        # model.half().cuda()
        model.half()
    # Move the model to the CPU or GPU (10.24)
    model.to(device)

    # Configure the optimizer (10.24)
    if args.do_train:
        param_optimizer = list(model.named_parameters())

        # hack to remove pooler, which is not used
        # thus it produces None grads that break apex
        param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        if args.fp16:
            try:
                # from apex.optimizers import FP16_Optimizer
                from apex.fp16_utils import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=True)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)
            warmup_linear = WarmupLinearSchedule(
                warmup=args.warmup_proportion,
                t_total=num_train_optimization_steps)
        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

    # Train the model (10.24)
    global_step = 0
    if args.do_train:
        # Extract features from the training corpus
        train_features = convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=True)

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_start_positions = torch.tensor(
            [f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor(
            [f.end_position for f in train_features], dtype=torch.long)
        all_start_vector = torch.tensor(
            [f.start_vector for f in train_features], dtype=torch.float)
        all_end_vector = torch.tensor([f.end_vector for f in train_features],
                                      dtype=torch.float)
        all_content_vector = torch.tensor(
            [f.content_vector for f in train_features], dtype=torch.float)

        # # Disabled alternative construction of all_start_positions and all_end_positions
        # all1_start_positions = []
        # for i in range(len(train_features)):
        #     for j in range(len(train_features[i].start_position)):
        #         all1_start_positions.append(train_features[i].start_position[j])
        # all_start_positions = torch.tensor([k for k in all1_start_positions], dtype=torch.long)
        # all1_end_positions = []
        # for i in range(len(train_features)):
        #     for j in range(len(train_features[i].end_position)):
        #         all1_end_positions.append(train_features[i].end_position[j])
        # all_end_positions = torch.tensor([k for k in all1_end_positions], dtype=torch.long)
        # ####################################################################

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_start_positions,
                                   all_end_positions, all_start_vector,
                                   all_end_vector, all_content_vector)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)  # random sampler
        else:
            train_sampler = DistributedSampler(train_data)

        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for ep in trange(int(args.num_train_epochs), desc="Epoch"):
            # Re-wrap with DataParallel every epoch so multi-GPU training keeps working; it is unwrapped again before saving below
            model = torch.nn.DataParallel(model).cuda()
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Iteration",
                         disable=args.local_rank not in [-1, 0])):

                if n_gpu == 1:
                    batch = tuple(
                        t.to(device)
                        for t in batch)  # multi-gpu does scattering itself
                input_ids, input_mask, segment_ids, start_positions, end_positions, start_vector, end_vector, content_vector = batch

                loss = model(input_ids, segment_ids, input_mask,
                             start_positions, end_positions, start_vector,
                             end_vector, content_vector, args.loss_type)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                    print("loss率为:{}".format(loss))
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used and handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(
                            global_step, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            print("\n")
            print(ep)
            output_model_file = os.path.join(args.output_dir,
                                             str(ep) + WEIGHTS_NAME)
            output_config_file = os.path.join(args.output_dir,
                                              str(ep) + CONFIG_NAME)

            torch.save(model.state_dict(), output_model_file)
            if isinstance(model, torch.nn.DataParallel):
                model = model.module
            model.config.to_json_file(output_config_file)
            tokenizer.save_vocabulary(args.output_dir)
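            # Note: str(ep) + WEIGHTS_NAME produces names like "0pytorch_model.bin",
            # while from_pretrained() only looks for the bare WEIGHTS_NAME, so these
            # per-epoch checkpoints must be reloaded manually via torch.load().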

    # Reload the fine-tuned model for prediction (10.25)
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned
        model = BertForQuestionAnswering.from_pretrained(args.output_dir)
        tokenizer = BertTokenizer.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
    else:
        model = BertForQuestionAnswering.from_pretrained(args.output_dir)
        tokenizer = BertTokenizer.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)

    # Move the model to the device again (10.25)
    model.to(device)

    # Run prediction and write the prediction files
    if args.do_predict and (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
        eval_examples = \
            read_examples(input_file=args.predict_file, is_training=False, version_2_with_negative=args.version_2_with_negative)
        if args.debug:
            eval_examples = eval_examples[:100]
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0),
                                         dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_example_index)

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(
                eval_dataloader,
                desc="Evaluating",
                disable=args.local_rank not in [-1, 0]):
            if len(all_results) % 1000 == 0:
                logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(
                    input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(
                    RawResult(unique_id=unique_id,
                              start_logits=start_logits,
                              end_logits=end_logits))

        middle_result = os.path.join(args.output_dir, 'middle_result.pkl')
        pickle.dump([eval_examples, eval_features, all_results],
                    open(middle_result, 'wb'))
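
        # The pickled triple lets the write_predictions* post-processing below
        # be rerun offline without re-running the model (a sketch):
        #
        # with open(middle_result, 'rb') as f:
        #     eval_examples, eval_features, all_results = pickle.load(f)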

        output_prediction_file = os.path.join(args.output_dir,
                                              "predictions.json")
        output_nbest_file = os.path.join(args.output_dir,
                                         "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(args.output_dir,
                                                 "null_odds.json")

        if (args.loss_type == 'double'):
            write_predictions_couple_labeling(
                eval_examples, eval_features, all_results, args.n_best_size,
                args.max_answer_length, args.do_lower_case,
                output_prediction_file, output_nbest_file,
                output_null_log_odds_file, args.verbose_logging,
                args.version_2_with_negative, args.null_score_diff_threshold)
        elif (args.loss_type == 'single'):
            write_predictions_single_labeling(
                eval_examples, eval_features, all_results, args.n_best_size,
                args.max_answer_length, args.do_lower_case,
                output_prediction_file, output_nbest_file,
                output_null_log_odds_file, args.verbose_logging,
                args.version_2_with_negative, args.null_score_diff_threshold)
        elif (args.loss_type == 'origin') or (args.task == 'multi'
                                              and args.loss_type == 'squad'):
            write_predictions(eval_examples, eval_features, all_results,
                              args.n_best_size, args.max_answer_length,
                              args.do_lower_case, output_prediction_file,
                              output_nbest_file, output_null_log_odds_file,
                              args.verbose_logging,
                              args.version_2_with_negative,
                              args.null_score_diff_threshold)
        else:
            raise ValueError('{} dataset and {} loss is not supported'.format(
                args.task, args.loss_type))
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--word_embedding_file",
                        default='./emb/wiki-news-300d-1M.vec',
                        type=str,
                        help="The input directory of word embeddings.")
    parser.add_argument("--index_path",
                        default='./emb/p_index.bin',
                        type=str,
                        help="The input directory of word embedding index.")
    parser.add_argument("--word_embedding_info",
                        default='./emb/vocab_info.txt',
                        type=str,
                        help="The input directory of word embedding info.")
    parser.add_argument("--data_file",
                        default='',
                        type=str,
                        help="The input directory of input data file.")

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--max_ngram_length",
                        default=16,
                        type=int,
                        help="The maximum total ngram sequence")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--embedding_size",
                        default=300,
                        type=int,
                        help="Total batch size for embeddings.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--num_eval_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of eval epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--single',
                        action='store_true',
                        help="Whether only evaluate a single epoch")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()
    # Comment out the if/else block below to run without CUDA
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #device = torch.device("cpu") # uncomment this for no GPU
    logger.info(
        "device: {} , distributed training: {}, 16-bits training: {}".format(
            device, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:  # comment this out for no GPU
        torch.cuda.manual_seed_all(args.seed)  # comment this out for no GPU

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    w2i, i2w, vocab_size = {}, {}, 1
    if args.do_train:

        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

        train_features, w2i, i2w, vocab_size = convert_examples_to_features_disc_train(
            train_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Num token vocab = %d", vocab_size)
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_tokens = torch.tensor([f.token_ids for f in train_features],
                                  dtype=torch.long)
        all_label_id = torch.tensor([f.label_id for f in train_features],
                                    dtype=torch.long)

    # Load word embeddings and build (or load) the embedding index
    if args.do_train:
        logger.info("Loading word embeddings ... ")
        emb_dict, emb_vec, vocab_list, emb_vocab_size = load_vectors(
            args.word_embedding_file)
        if not os.path.exists(args.index_path):

            write_vocab_info(args.word_embedding_info, emb_vocab_size,
                             vocab_list)
            p = load_embeddings_and_save_index(range(emb_vocab_size), emb_vec,
                                               args.index_path)
        else:
            p = load_embedding_index(args.index_path,
                                     emb_vocab_size,
                                     num_dim=args.embedding_size)

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(
            args.local_rank))
    model = BertForDiscriminator.from_pretrained(args.bert_model,
                                                 cache_dir=cache_dir,
                                                 num_labels=num_labels)
    model.to(device)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
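    # Bias and LayerNorm parameters are conventionally excluded from weight
    # decay when fine-tuning BERT; the two groups below implement that split.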

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)
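    # BertAdam applies the linear warmup/decay schedule itself, driven by
    # warmup_proportion and t_total, so no separate scheduler is needed.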

    global_step = 0
    nb_tr_steps = 1
    tr_loss = 0
    if args.do_train:

        train_data = TensorDataset(all_tokens, all_label_id)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for ind in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            flaw_eval_f1 = []
            flaw_eval_recall = []
            flaw_eval_precision = []
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                tokens, _ = batch  # the label ids are not used in this step

                # module1: learn a discriminator
                tokens = tokens.to('cpu').numpy()
                train_features = convert_examples_to_features_flaw(
                    tokens, args.max_seq_length, args.max_ngram_length,
                    tokenizer, i2w, emb_dict, p, vocab_list)

                flaw_mask = torch.tensor([f.flaw_mask for f in train_features],
                                         dtype=torch.long).to(
                                             device)  # [1, 1, 1, 1, 0,0,0,0]
                flaw_ids = torch.tensor([f.flaw_ids for f in train_features],
                                        dtype=torch.long).to(
                                            device)  # [12,25,37,54,0,0,0,0]
                flaw_labels = torch.tensor(
                    [f.flaw_labels for f in train_features],
                    dtype=torch.long).to(device)  # [0, 1, 1, 1, 0,0,0,0]

                loss, logits = model(flaw_ids, flaw_mask, flaw_labels)
                logits = logits.detach().cpu().numpy()

                if n_gpu > 1:
                    loss = loss.mean()  # average the per-GPU losses from DataParallel

                if args.gradient_accumulation_steps > 1:
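                    # Average the loss over the accumulation window so the
                    # effective gradient matches a single large batch.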
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()

                nb_tr_examples += flaw_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:

                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                # eval during training
                flaw_labels = flaw_labels.to('cpu').numpy()

                flaw_tmp_eval_f1, flaw_tmp_eval_recall, flaw_tmp_eval_precision = f1_3d(
                    logits, flaw_labels)
                flaw_eval_f1.append(flaw_tmp_eval_f1)
                flaw_eval_recall.append(flaw_tmp_eval_recall)
                flaw_eval_precision.append(flaw_tmp_eval_precision)

                nb_eval_examples += flaw_ids.size(0)
                nb_eval_steps += 1

            flaw_f1 = sum(flaw_eval_f1) / len(flaw_eval_f1)
            flaw_recall = sum(flaw_eval_recall) / len(flaw_eval_recall)
            flaw_precision = sum(flaw_eval_precision) / len(
                flaw_eval_precision)
            loss = tr_loss / nb_tr_steps if args.do_train else None
            result = {
                'flaw_f1': flaw_f1,
                "flaw_recall": flaw_recall,
                "flaw_precision": flaw_precision,
                'loss': loss,
            }

            output_eval_file = os.path.join(args.output_dir,
                                            "train_results.txt")
            with open(output_eval_file, "a") as writer:
                writer.write("epoch" + str(ind) + '\n')
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
                writer.write('\n')

            model_to_save = model.module if hasattr(model, 'module') else model
            output_model_file = os.path.join(args.output_dir,
                                             "epoch" + str(ind) + WEIGHTS_NAME)
            torch.save(model_to_save.state_dict(), output_model_file)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
            with open(output_config_file, 'w') as f:
                f.write(model_to_save.config.to_json_string())
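            # The state_dict plus config.json saved here are what the eval
            # branch below reloads via BertConfig and load_state_dict.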

        os.rename(
            output_model_file,
            os.path.join(args.output_dir, "disc_trained_" + WEIGHTS_NAME))
        current_path = os.path.join(args.output_dir,
                                    "disc_trained_" + WEIGHTS_NAME)
        new_path = os.path.join('./models', "disc_trained_" + WEIGHTS_NAME)
        new_path_config = os.path.join('./models', CONFIG_NAME)
        shutil.move(current_path, new_path)
        shutil.move(output_config_file, new_path_config)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):

        eval_examples = processor.get_disc_dev_examples(args.data_file)
        eval_features, w2i, i2w, vocab_size = convert_examples_to_features_disc_eval(
            eval_examples, label_list, args.max_seq_length, tokenizer, w2i,
            i2w, vocab_size)

        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Num token vocab = %d", vocab_size)
        logger.info("  Batch size = %d", args.eval_batch_size)

        all_token_ids = torch.tensor([f.token_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_flaw_labels = torch.tensor([f.flaw_labels for f in eval_features],
                                       dtype=torch.long)
        all_flaw_ids = torch.tensor([f.flaw_ids for f in eval_features],
                                    dtype=torch.long)
        all_label_id = torch.tensor([f.label_id for f in eval_features],
                                    dtype=torch.long)
        all_chunks = torch.tensor([f.chunks for f in eval_features],
                                  dtype=torch.long)

        eval_data = TensorDataset(all_token_ids, all_input_ids, all_input_mask,
                                  all_flaw_ids, all_flaw_labels, all_label_id,
                                  all_chunks)

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # Load a trained model and config that you have fine-tuned
        if args.single:
            eval_range = trange(int(args.num_eval_epochs),
                                int(args.num_eval_epochs + 1),
                                desc="Epoch")
        else:
            eval_range = trange(int(args.num_eval_epochs), desc="Epoch")

        attack_type = 'rand'
        for epoch in eval_range:

            output_file = os.path.join(
                args.data_dir, "epoch" + str(epoch) + "disc_eval_outputs_" +
                attack_type + ".tsv")
            with open(output_file, "w") as csv_file:
                writer = csv.writer(csv_file, delimiter='\t')
                writer.writerow(["sentence", "label", "ids"])

            #output_model_file = os.path.join(args.output_dir, "epoch"+str(epoch)+WEIGHTS_NAME)
            output_model_file = os.path.join(args.output_dir,
                                             "disc_trained_" + WEIGHTS_NAME)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
            config = BertConfig(output_config_file)
            model = BertForDiscriminator(config, num_labels=num_labels)
            model.load_state_dict(torch.load(output_model_file))

            model.to(device)
            model.eval()
            predictions, truths = [], []
            eval_loss, nb_eval_steps, nb_eval_examples = 0, 0, 0
            eval_accuracy = 0

            for token_ids, input_ids, input_mask, flaw_ids, flaw_labels, label_id, chunks in tqdm(
                    eval_dataloader, desc="Evaluating"):

                token_ids = token_ids.to(device)
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                flaw_labels = flaw_labels.to(device)
                flaw_ids = flaw_ids.to(device)

                with torch.no_grad():
                    tmp_eval_loss, s = model(input_ids, input_mask,
                                             flaw_labels)

                    logits = model(input_ids, input_mask)

                    flaw_logits = torch.argmax(logits, dim=2)
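                    # argmax over dim=2 gives a per-token decision; 1 marks a
                    # token the discriminator flags as flawed.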

                logits = logits.detach().cpu().numpy()
                flaw_logits = flaw_logits.detach().cpu().numpy()
                flaw_ids = flaw_ids.to('cpu').numpy()
                label_id = label_id.to('cpu').numpy()
                chunks = chunks.to('cpu').numpy()
                token_ids = token_ids.to('cpu').numpy()

                flaw_logits = logit_converter(
                    flaw_logits, chunks)  # each word only has one '1'
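                # flaw_logits is now aligned to words rather than subword
                # pieces, using the per-example chunk sizes.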

                true_logits = []
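                # flaw_ids stores the positions of the flawed tokens; expand them
                # into 0/1 indicator vectors aligned with flaw_logits for scoring.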

                for i in range(len(flaw_ids)):
                    tmp = [0] * len(flaw_logits[i])

                    for j in range(len(flaw_ids[0])):
                        if flaw_ids[i][j] == 0: break
                        if flaw_ids[i][j] >= len(tmp): continue
                        tmp[flaw_ids[i][j]] = 1

                    true_logits.append(tmp)

                tmp_eval_accuracy = accuracy_2d(flaw_logits, true_logits)
                eval_accuracy += tmp_eval_accuracy

                predictions += flaw_logits
                truths += true_logits
                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1

                with open(output_file, "a") as csv_file:
                    for i in range(len(label_id)):
                        token = ' '.join(
                            [i2w[x] for x in token_ids[i] if x != 0])
                        flaw_logit = flaw_logits[i]
                        label = str(label_id[i])
                        logit = ','.join([
                            str(i) for i, x in enumerate(flaw_logit) if x == 1
                        ])
                        logit = '-1' if logit == '' else logit
                        writer = csv.writer(csv_file, delimiter='\t')
                        writer.writerow([token, label, logit])

            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = eval_accuracy / nb_eval_steps
            eval_f1_score, eval_recall_score, eval_precision_score = f1_2d(
                truths, predictions)
            loss = tr_loss / nb_tr_steps if args.do_train else None
            result = {
                'eval_loss': eval_loss,
                'eval_f1': eval_f1_score,
                'eval_recall': eval_recall_score,
                'eval_precision': eval_precision_score,
                'eval_acc': eval_accuracy
            }

            output_eval_file = os.path.join(
                args.output_dir,
                "disc_eval_results_" + attack_type + "_attacks.txt")
            with open(output_eval_file, "a") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))

            # Rename the epoch-specific outputs file for the Embedding Estimator
            new_path = os.path.join(
                args.data_dir, "disc_eval_outputs_" + attack_type + ".tsv")
            current_path = os.path.join(
                args.data_dir, "epoch" + str(epoch) + "disc_eval_outputs_" +
                attack_type + ".tsv")
            os.rename(current_path, new_path)
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--input_text',
                        type=str,
                        default='',
                        help="Input text.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    args.device = device

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
    #     raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()
    num_labels = len(label_list)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model, num_labels=num_labels)
    if args.local_rank == 0:
        torch.distributed.barrier()
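    # Rank 0 hits the barrier only after loading, so the other ranks can then
    # read the cached model and vocab instead of re-downloading.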

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0

    # Load a trained model and vocabulary that you have fine-tuned
    model = BertForSequenceClassification.from_pretrained(
        args.output_dir, num_labels=num_labels)
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    # model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)

    model.to(device)

    ### Evaluation
    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        print("input text: ", args.input_text)
        eval_examples = processor.get_test_example(args.input_text)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer, output_mode)

        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.float)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        if args.local_rank == -1:
            eval_sampler = SequentialSampler(eval_data)
        else:
            eval_sampler = DistributedSampler(
                eval_data)  # Note that this sampler samples randomly
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss = 0
        nb_eval_steps = 0
        preds = []
        out_label_ids = None

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids,
                               token_type_ids=segment_ids,
                               attention_mask=input_mask)

            # create eval loss and other metric required by the task
            if output_mode == "classification":
                loss_fct = CrossEntropyLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels),
                                         label_ids.view(-1))
            elif output_mode == "regression":
                loss_fct = MSELoss()
                tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))

            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if len(preds) == 0:
                preds.append(logits.detach().cpu().numpy())
                out_label_ids = label_ids.detach().cpu().numpy()
            else:
                preds[0] = np.append(preds[0],
                                     logits.detach().cpu().numpy(),
                                     axis=0)
                out_label_ids = np.append(out_label_ids,
                                          label_ids.detach().cpu().numpy(),
                                          axis=0)

        eval_loss = eval_loss / nb_eval_steps
        preds = preds[0]

        print("preds shape:", preds.shape)
        print("preds:", preds)

        if output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(preds)
        for feature in features:
            for i in feature['inputs']:
                fout.write(i + ' ')
            fout.write('\n')
    with open("./train_ytrans.data", 'w', encoding="utf-8") as fout:
        for feature in features:
            # fout.write(feature['answer'] + '\n')
            for i in feature['answer']:
                fout.write(i + ' ')
            fout.write('\n')

    return features


if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained('../roberta_wwm_ext',
                                              do_lower_case=True)
    # Generate the training data, train.data
    # print(len(tokenizer.vocab))     # 21128, the size of the vocab
    examples = read_squad_examples(zhidao_input_file=args.zhidao_input_file,
                                   search_input_file=args.search_input_file)
    features = convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        max_query_length=args.max_query_length,
        max_ans_length=args.max_ans_length)

    # Generate the dev data, dev.data. Remember to comment out the training-data
    # code above and change train.data to dev.data at line 196.
    # examples = read_squad_examples(zhidao_input_file=args.dev_zhidao_input_file,
    #                                search_input_file=args.dev_search_input_file)
    # features = convert_examples_to_features(examples=examples, tokenizer=tokenizer,
Beispiel #16
            "input_mask": input_mask,
            "segment_ids": segment_ids,
            "start_position": start_position,
            "end_position": end_position
        })

    with open(filepath, 'w', encoding="utf-8") as fout:
        for feature in features:
            fout.write(json.dumps(feature, ensure_ascii=False) + '\n')
    print("len(features):", len(features))
    return features


if __name__ == "__main__":

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese',
                                              do_lower_case=True)
    # Generate the training data, train.data
    examples = read_squad_examples(zhidao_input_file=args.zhidao_input_file,
                                   search_input_file=args.search_input_file)
    features = convert_examples_to_features(
        filepath=TRAINSET_PATH,
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        max_query_length=args.max_query_length)

    # Generate the dev data, dev.data. Remember to comment out the training-data
    # code above and change train.data to dev.data at line 196.
    examples = read_squad_examples(
        zhidao_input_file=args.dev_zhidao_input_file,
        search_input_file=args.dev_search_input_file)
    features = convert_examples_to_features(
def main():

    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--vocab_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The vocabulary the BERT model will train on.")
    parser.add_argument(
        "--input_file",
        default=None,
        type=str,
        required=True,
        help=
        "The input train corpus. can be directory with .txt files or a path to a single file"
    )
    parser.add_argument(
        "--output_file",
        default=None,
        type=str,
        required=True,
        help="The output file where the training instances will be written.")

    ## Other parameters

    # str
    parser.add_argument(
        "--bert_model",
        default="bert-large-uncased",
        type=str,
        required=False,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )

    # int
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--dupe_factor",
        default=10,
        type=int,
        help=
        "Number of times to duplicate the input data (with different masks).")
    parser.add_argument("--max_predictions_per_seq",
                        default=20,
                        type=int,
                        help="Maximum number of masked LM predictions per sequence.")

    # floats

    parser.add_argument("--masked_lm_prob",
                        default=0.15,
                        type=float,
                        help="Masked LM probability.")

    parser.add_argument(
        "--short_seq_prob",
        default=0.1,
        type=float,
        help=
        "Probability to create a sequence shorter than maximum sequence length"
    )

    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. Set this flag for uncased models; omit it for cased models."
    )
    parser.add_argument('--random_seed',
                        type=int,
                        default=12345,
                        help="random seed for initialization")

    args = parser.parse_args()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    input_files = []
    if os.path.isfile(args.input_file):
        input_files.append(args.input_file)
    elif os.path.isdir(args.input_file):
        input_files = [
            os.path.join(args.input_file, f)
            for f in os.listdir(args.input_file)
            if (os.path.isfile(os.path.join(args.input_file, f))
                and f.endswith('.txt'))
        ]
    else:
        raise ValueError("{} is not a valid path".format(args.input_file))

    rng = random.Random(args.random_seed)
    instances = create_training_instances(input_files, tokenizer,
                                          args.max_seq_length,
                                          args.dupe_factor,
                                          args.short_seq_prob,
                                          args.masked_lm_prob,
                                          args.max_predictions_per_seq, rng)
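    # dupe_factor controls how many times the corpus is re-masked, so the same
    # sentences appear multiple times with different masked positions.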

    output_file = args.output_file

    write_instance_to_example_file(instances, tokenizer, args.max_seq_length,
                                   args.max_predictions_per_seq, output_file)
Beispiel #18
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--src_file", default=None, type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file", default=None, type=str,
                        help="The output data file name.")

    parser.add_argument("--dev_src_file", default=None, type=str,
                        help="The dev input data file name.")
    parser.add_argument("--dev_tgt_file", default=None, type=str,
                        help="The dev output data file name.")

    parser.add_argument("--ks_src_file", default=None, type=str,
                        help="The KS input data file name.")
    parser.add_argument("--ks_tgt_file", default=None, type=str,
                        help="The KS output data file name.")

    parser.add_argument("--ks_dev_src_file", default=None, type=str,
                        help="The KS dev input data file name.")
    parser.add_argument("--ks_dev_tgt_file", default=None, type=str,
                        help="The KS dev output data file name.")

    parser.add_argument("--predict_input_file", default=None, type=str,
                        help="predict_input_file")
    parser.add_argument("--predict_output_file", default=None, type=str,
                        help="predict_output_file")


    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir",
                        default='',
                        type=str,
                        required=True,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")

    parser.add_argument("--predict_bleu",
                        default=0.5,
                        type=float,
                        help="The predicted BLEU for KS predict.")

    parser.add_argument("--train_vae",
                        action='store_true',
                        help="Whether to train vae.")
    # Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run ks predict.")

    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="Label smoothing coefficient.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp32_embedding', action='store_true',
                        help="Whether to use 32-bit float precision instead of 16-bit for embeddings")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--from_scratch', action='store_true',
                        help="Initialize parameters with random values (i.e., training from scratch).")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a', type=int, default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', action='store_true',
                        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument("--mask_prob", default=0.15, type=float,
                        help="Masking probability; the number of predictions is sometimes less than max_pred when the sequence is short.")
    parser.add_argument("--mask_prob_eos", default=0, type=float,
                        help="Masking probability for EOS; the number of predictions is sometimes less than max_pred when the sequence is short.")
    parser.add_argument('--max_pred', type=int, default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=0, type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words', action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb', type=float, default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size', type=int, default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word', action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training', action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument('--has_sentence_oracle', action='store_true',
                        help="Whether to have sentence level oracle for training. "
                             "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings', type=int, default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection', action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type', default=0, type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv', default=0, type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb', action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument('--s2s_special_token', action='store_true',
                        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument('--s2s_share_segment', action='store_true',
                        help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    assert Path(args.model_recover_path).exists(), \
        "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace(
        '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace(
        '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    handler = logging.FileHandler(os.path.join(args.log_dir, "train.log"), encoding='UTF-8')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)

    logger.addHandler(handler)
    logger.addHandler(console)


    json.dump(args.__dict__, open(os.path.join(
        args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()


    print("Loading QKR Train Dataset", args.data_dir)
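    # Preprocess4Seq2seq masks target-side tokens for the s2s LM objective and
    # applies the truncation config assembled from the CLI flags.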
    bi_uni_pipeline = [seq2seq_loader.Preprocess4Seq2seq(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys(
    )), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={'max_len_a': args.max_len_a, 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail}, mask_source_words=args.mask_source_words, skipgram_prb=args.skipgram_prb, skipgram_size=args.skipgram_size, mask_whole_word=args.mask_whole_word, mode="s2s", has_oracle=args.has_sentence_oracle, num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token, s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift)]
    file_oracle = None
    if args.has_sentence_oracle:
        file_oracle = os.path.join(args.data_dir, 'train.oracle')
    fn_src = os.path.join(
        args.data_dir, args.src_file if args.src_file else 'train.src')
    fn_tgt = os.path.join(
        args.data_dir, args.tgt_file if args.tgt_file else 'train.tgt')
    train_dataset = seq2seq_loader.Seq2SeqDataset(
        fn_src, fn_tgt, args.train_batch_size, data_tokenizer, args.max_seq_length, file_oracle=file_oracle, bi_uni_pipeline=bi_uni_pipeline)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset, replacement=False)
        _batch_size = args.train_batch_size
    else:
        train_sampler = DistributedSampler(train_dataset)
        _batch_size = args.train_batch_size // dist.get_world_size()
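        # Each DDP process sees a disjoint shard, so the per-process batch is
        # the global batch divided by the world size.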
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=_batch_size, sampler=train_sampler,
                                                   num_workers=args.num_workers, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False)

    print("Loading KS Train Dataset", args.data_dir)
    ks_fn_src = os.path.join(
        args.data_dir, args.ks_src_file)
    ks_fn_tgt = os.path.join(
        args.data_dir, args.ks_tgt_file)
    ks_train_dataset = seq2seq_loader.Seq2SeqDataset(
        ks_fn_src, ks_fn_tgt, args.train_batch_size, data_tokenizer, args.max_seq_length, file_oracle=file_oracle,
        bi_uni_pipeline=bi_uni_pipeline)
    if args.local_rank == -1:
        ks_train_sampler = RandomSampler(ks_train_dataset, replacement=False)
        _batch_size = args.train_batch_size
    else:
        ks_train_sampler = DistributedSampler(ks_train_dataset)
        _batch_size = args.train_batch_size // dist.get_world_size()
    ks_train_dataloader = torch.utils.data.DataLoader(ks_train_dataset, batch_size=_batch_size, sampler=ks_train_sampler,
                                                   num_workers=args.num_workers,
                                                   collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
                                                   pin_memory=False)


    logger.info("Loading QKR Eval Dataset from {}".format(args.data_dir))

    fn_src = os.path.join(
        args.data_dir, args.dev_src_file)
    fn_tgt = os.path.join(
        args.data_dir, args.dev_tgt_file)
    dev_reddit_dataset = seq2seq_loader.Seq2SeqDataset(
        fn_src, fn_tgt, args.eval_batch_size, data_tokenizer, args.max_seq_length, file_oracle=file_oracle,
        bi_uni_pipeline=bi_uni_pipeline)
    if args.local_rank == -1:
        dev_reddit_sampler = RandomSampler(dev_reddit_dataset, replacement=False)
        _batch_size = args.eval_batch_size
    else:
        dev_reddit_sampler = DistributedSampler(dev_reddit_dataset)
        _batch_size = args.eval_batch_size // dist.get_world_size()
    dev_reddit_dataloader = torch.utils.data.DataLoader(dev_reddit_dataset, batch_size=_batch_size,
                                                        sampler=dev_reddit_sampler,
                                                        num_workers=args.num_workers,
                                                        collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
                                                        pin_memory=False)

    logger.info("Loading KS Eval Dataset from {}".format(args.data_dir))

    ks_dev_fn_src = os.path.join(
        args.data_dir, args.ks_dev_src_file)
    ks_dev_fn_tgt = os.path.join(
        args.data_dir, args.ks_dev_tgt_file)
    ks_dev_reddit_dataset = seq2seq_loader.Seq2SeqDataset(
        ks_dev_fn_src, ks_dev_fn_tgt, args.eval_batch_size, data_tokenizer, args.max_seq_length, file_oracle=file_oracle,
        bi_uni_pipeline=bi_uni_pipeline)
    if args.local_rank == -1:
        ks_dev_reddit_sampler = RandomSampler(ks_dev_reddit_dataset, replacement=False)
        _batch_size = args.eval_batch_size
    else:
        ks_dev_reddit_sampler = DistributedSampler(ks_dev_reddit_dataset)
        _batch_size = args.eval_batch_size // dist.get_world_size()
    ks_dev_reddit_dataloader = torch.utils.data.DataLoader(ks_dev_reddit_dataset, batch_size=_batch_size,
                                                        sampler=ks_dev_reddit_sampler,
                                                        num_workers=args.num_workers,
                                                        collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
                                                        pin_memory=False)


    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(len(train_dataloader) * args.num_train_epochs /
                  args.gradient_accumulation_steps)
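    # t_total is the number of optimizer updates: batches per epoch * epochs, divided
    # by gradient_accumulation_steps because several batches share one optimizer.step().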

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels,
            num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path,
            task_idx=3, num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection, new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv, seg_emb=args.seg_emb)

        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)), map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(
                recover_step * t_total / args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(
                args.model_recover_path, map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=model_recover, num_labels=cls_num_labels,
            num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path,
            task_idx=3, num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection, new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv, seg_emb=args.seg_emb)
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model, device_ids=[
                    args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(
            nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
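    # Standard BERT fine-tuning practice: biases and LayerNorm parameters are excluded
    # from weight decay; all other parameters use a weight decay of 0.01.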
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(
                optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(
                optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)), map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        KL_weight = 0.0

        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step+1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch, int(args.num_train_epochs)+1, desc="Epoch", disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)


            step = 0
            for batch, ks_batch in zip(train_dataloader, ks_train_dataloader):
                batch = [
                    t.to(device) if t is not None else None for t in batch]

                input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, labels, ks_labels = batch
                oracle_pos, oracle_weights, oracle_labels = None, None, None
                loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, masked_pos_2=oracle_pos, masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels, mask_qkv=mask_qkv, labels=labels, ks_labels=ks_labels, train_vae=args.train_vae)

                if args.train_vae:
                    masked_lm_loss, next_sentence_loss, KL_loss = loss_tuple
                    if n_gpu > 1:    # mean() to average on multi-gpu.
                        masked_lm_loss = masked_lm_loss.mean()
                        next_sentence_loss = next_sentence_loss.mean()
                        KL_loss = KL_loss.mean()
                else:
                    masked_lm_loss, next_sentence_loss, _ = loss_tuple
                    if n_gpu > 1:    # mean() to average on multi-gpu.
                        masked_lm_loss = masked_lm_loss.mean()
                        next_sentence_loss = next_sentence_loss.mean()

                KL_weight += 1.0 / float(len(ks_train_dataloader))
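                # Linear KL annealing: the weight grows by 1/len(ks_train_dataloader) per
                # step, so it reaches 1.0 after roughly one pass over the KS data and keeps
                # increasing in later epochs.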

                if args.train_vae:
                    loss = masked_lm_loss + next_sentence_loss + KL_weight * KL_loss
                else:
                    loss = masked_lm_loss + next_sentence_loss

                logger.info("In{}step, masked_lm_loss:{}".format(step, masked_lm_loss))
                logger.info("In{}step, KL_weight:{}".format(step, KL_weight))
                #logger.info("In{}step, KL_loss:{}".format(step, KL_loss))
                logger.info("******************************************* ")

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                # random.randint(0, 0) is always 0, so the knowledge-selection (KS) update
                # below runs on every batch; widen the range to subsample it.
                if random.randint(0, 0) == 0:
                    ks_batch = [
                        t.to(device) if t is not None else None for t in ks_batch]

                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, labels, ks_labels = ks_batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                    loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next, masked_pos=masked_pos,
                                       masked_weights=masked_weights, task_idx=task_idx, masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels, mask_qkv=mask_qkv, labels=labels, ks_labels=ks_labels, train_ks=True, train_vae=args.train_vae)
                    if args.train_vae:
                        ks_loss, KS_KL_loss = loss_tuple
                        if n_gpu > 1:  # mean() to average on multi-gpu.
                            ks_loss = ks_loss.mean()
                            KS_KL_loss = KS_KL_loss.mean()
                        loss = ks_loss + KL_weight * KS_KL_loss
                    else:
                        ks_loss, _ = loss_tuple
                        if n_gpu > 1:  # mean() to average on multi-gpu.
                            ks_loss = ks_loss.mean()
                        loss = ks_loss


                    logger.info("In{}step, ks_loss:{}".format(step, ks_loss))
                    #logger.info("In{}step, KS_KL_loss:{}".format(step, KS_KL_loss))
                    logger.info("******************************************* ")

                    # ensure that accumulated gradients are normalized
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        optimizer.backward(loss)
                        if amp_handle:
                            amp_handle._clear_cache()
                    else:
                        loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        lr_this_step = args.learning_rate * \
                                       warmup_linear(global_step / t_total,
                                                     args.warmup_proportion)
                        if args.fp16:
                            # modify learning rate with special warm up BERT uses
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()
                        global_step += 1

                step += 1
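                # Every 200 steps: evaluate on the QKR and KS dev sets, then save a checkpoint.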
                if (step + 1) % 200 == 0:
                    logger.info("***** Running QKR evaling *****")
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    if args.local_rank != -1:
                        train_sampler.set_epoch(i_epoch)
                    dev_iter_bar = tqdm(dev_reddit_dataloader, desc='Iter (loss=X.XXX)',
                                    disable=args.local_rank not in (-1, 0))
                    total_lm_loss = 0
                    total_kl_loss = 0
                    for qkr_dev_step, batch in enumerate(dev_iter_bar):
                        batch = [
                            t.to(device) if t is not None else None for t in batch]
                        if args.has_sentence_oracle:
                            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch
                        else:
                            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, labels, ks_labels = batch
                            oracle_pos, oracle_weights, oracle_labels = None, None, None
                        with torch.no_grad():
                            loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next,
                                               masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx,
                                               masked_pos_2=oracle_pos, masked_weights_2=oracle_weights,
                                               masked_labels_2=oracle_labels, mask_qkv=mask_qkv, labels=labels, ks_labels=ks_labels, train_vae=args.train_vae)
                            masked_lm_loss, next_sentence_loss, KL_loss = loss_tuple
                            if n_gpu > 1:  # mean() to average on multi-gpu.
                                # loss = loss.mean()
                                masked_lm_loss = masked_lm_loss.mean()
                                next_sentence_loss = next_sentence_loss.mean()
                                KL_loss = KL_loss.mean()

                            # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                            dev_iter_bar.set_description('Iter (loss=%5.3f)' % masked_lm_loss.item())
                            total_lm_loss += masked_lm_loss.item()
                            total_kl_loss += KL_loss.item()

                    total_mean_lm_loss = total_lm_loss / (qkr_dev_step + 1)
                    total_mean_kl_loss = total_kl_loss / (qkr_dev_step + 1)

                    logger.info("** ** * Evaling mean loss ** ** * ")
                    logger.info("In{}epoch,dev_lm_loss:{}".format(i_epoch, total_mean_lm_loss))
                    logger.info("In{}epoch,dev_kl_loss:{}".format(i_epoch, total_mean_kl_loss))
                    logger.info("******************************************* ")


                    logger.info("***** Running KS evaling *****")
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    ks_dev_iter_bar = tqdm(ks_dev_reddit_dataloader, desc='Iter (loss=X.XXX)',
                                        disable=args.local_rank not in (-1, 0))
                    total_ks_loss = 0
                    total_ks_kl_loss = 0
                    for ks_dev_step, batch in enumerate(ks_dev_iter_bar):
                        batch = [
                            t.to(device) if t is not None else None for t in batch]

                        input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, labels, ks_labels = batch
                        oracle_pos, oracle_weights, oracle_labels = None, None, None
                        with torch.no_grad():
                            loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next,
                                               masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx,
                                               masked_pos_2=oracle_pos, masked_weights_2=oracle_weights,
                                               masked_labels_2=oracle_labels, mask_qkv=mask_qkv, labels=labels, ks_labels=ks_labels, train_ks=True, train_vae=args.train_vae)
                            ks_loss, KS_KL_loss = loss_tuple

                            if n_gpu > 1:  # mean() to average on multi-gpu.
                                # loss = loss.mean()
                                ks_loss = ks_loss.mean()
                                KS_KL_loss = KS_KL_loss.mean()

                            # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                            ks_dev_iter_bar.set_description('Iter (loss=%5.3f)' % ks_loss.item())
                            total_ks_loss += ks_loss.item()
                            total_ks_kl_loss += KS_KL_loss.item()

                    total_mean_ks_loss = total_ks_loss / (ks_dev_step + 1)
                    total_mean_ks_kl_loss = total_ks_kl_loss / (ks_dev_step + 1)

                    logger.info("** ** * Evaling mean loss ** ** * ")
                    logger.info("In{}epoch,dev_ks_loss:{}".format(i_epoch, total_mean_ks_loss))
                    logger.info("In{}epoch,dev_ks_kl_loss:{}".format(i_epoch, total_mean_ks_kl_loss))

                    total_mean_loss = total_mean_lm_loss + total_mean_kl_loss + total_mean_ks_loss + total_mean_ks_kl_loss
                    logger.info("In{}epoch,dev_loss:{}".format(i_epoch, total_mean_loss))
                    logger.info("******************************************* ")

                    # Save a trained model
                    if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                        logger.info(
                            "** ** * Saving fine-tuned model and optimizer ** ** * ")
                        model_to_save = model.module if hasattr(
                            model, 'module') else model  # Only save the model itself
                        output_model_file = os.path.join(
                            args.output_dir, "model.{}_{}_{}.bin".format(i_epoch,step,round(total_mean_loss,4)))
                        torch.save(model_to_save.state_dict(), output_model_file)
                        output_optim_file = os.path.join(
                            args.output_dir, "optim.bin")
                        torch.save(optimizer.state_dict(), output_optim_file)

                        logger.info("***** CUDA.empty_cache() *****")
                        torch.cuda.empty_cache()

    if args.do_predict:

        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq_predict(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys(
            )), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids,
                                                      truncate_config={'max_len_a': args.max_len_a,
                                                                       'max_len_b': args.max_len_b,
                                                                       'trunc_seg': args.trunc_seg,
                                                                       'always_truncate_tail': args.always_truncate_tail},
                                                      mask_source_words=args.mask_source_words,
                                                      skipgram_prb=args.skipgram_prb, skipgram_size=args.skipgram_size,
                                                      mask_whole_word=args.mask_whole_word, mode="s2s",
                                                      has_oracle=args.has_sentence_oracle, num_qkv=args.num_qkv,
                                                      s2s_special_token=args.s2s_special_token,
                                                      s2s_add_segment=args.s2s_add_segment,
                                                      s2s_share_segment=args.s2s_share_segment,
                                                      pos_shift=args.pos_shift)]

        next_i = 0
        model.eval()

        with open(os.path.join(args.data_dir, args.predict_input_file), "r", encoding="utf-8") as file:
            src_file = file.readlines()
        with open("train_tgt_pad.empty", "r", encoding="utf-8") as file:
            tgt_file = file.readlines()
        with open(os.path.join(args.data_dir, args.predict_output_file), "w", encoding="utf-8") as out:
            while next_i < len(src_file):
                print(next_i)
                batch_src = src_file[next_i:next_i + args.eval_batch_size]
                batch_tgt = tgt_file[next_i:next_i + args.eval_batch_size]

                next_i += args.eval_batch_size

                ex_list = []
                for src, tgt in zip(batch_src, batch_tgt):
                    src_tk = data_tokenizer.tokenize(src.strip())
                    tgt_tk = data_tokenizer.tokenize(tgt.strip())
                    ex_list.append((src_tk, tgt_tk))

                batch = []
                for idx in range(len(ex_list)):
                    instance = ex_list[idx]
                    for proc in bi_uni_pipeline:
                        instance = proc(instance)
                    # append once per example after all preprocessors have run; with the
                    # single-element pipeline used here this matches the original behaviour
                    batch.append(instance)

                batch_tensor = seq2seq_loader.batch_list_to_batch_tensors(batch)
                batch = [
                    t.to(device) if t is not None else None for t in batch_tensor]

                input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch

                predict_bleu = args.predict_bleu * torch.ones([input_ids.shape[0]], device=input_ids.device)  # B
                oracle_pos, oracle_weights, oracle_labels = None, None, None
                with torch.no_grad():
                    logits = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next,
                                   masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx,
                                   masked_pos_2=oracle_pos, masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels, mask_qkv=mask_qkv, labels=predict_bleu, train_ks=True,train_vae=args.train_vae)

                    logits = torch.nn.functional.softmax(logits, dim=1)
                    labels = logits[:, 1].cpu().numpy()
                    # print(labels)
                    for i in range(len(labels)):
                        line = batch_src[i].strip()
                        line += "\t"
                        line += str(labels[i])
                        out.write(line)
                        out.write("\n")
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir", default="/Users/lifuh/Documents/Research/squad2.0/output/", type=str,
                        help="The output directory where the model checkpoints and predictions will be written.")

    ## Other parameters
    parser.add_argument("--train_file", default="/Users/lifuh/Documents/Research/squad2.0/train-v2.0.json",
                        type=str, help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", default="/Users/lifuh/Documents/Research/squad2.0/dev-v2.0.json", type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.")
    parser.add_argument("--do_predict", default=True, action='store_true', help="Whether to run eval on the dev set.")

    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
                             "of training.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", default=False, action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--fp16',
                        default=False,
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--null_score_diff_threshold',
                        type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory () already exists and is not empty.")
    os.makedirs(args.output_dir, exist_ok=True)

    # initialize a tokenizer from a pretrained model
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = read_squad_examples(
            input_file=args.train_file, is_training=True)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
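        # One optimizer update is performed every gradient_accumulation_steps batches,
        # so the step count is examples / batch_size / accumulation_steps * epochs.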

    # Prepare model
    model = BertForQuestionAnswering.from_pretrained(args.bert_model,
                cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # hack to remove pooler, which is not used
    # and thus produces None gradients that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    global_step = 0
    if args.do_train:
        cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
            args.bert_model, str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
        train_features = None
        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except Exception:  # no cached features yet (or cache unreadable): rebuild them
            train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s", cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
        all_is_impossibles = torch.tensor([int(f.is_impossible) for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                   all_start_positions, all_end_positions, all_is_impossibles)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                if n_gpu == 1:
                    batch = tuple(t.to(device) for t in batch)  # multi-gpu does scattering itself
                input_ids, input_mask, segment_ids, start_positions, end_positions, _ = batch
                loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    # Save a trained model
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned
    model_state_dict = torch.load(output_model_file)
    model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
    model.to(device)

    if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = read_squad_examples(
            input_file=args.predict_file, is_training=False)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
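        # all_example_index keeps each feature's position so the logits produced in
        # batches can be mapped back to the corresponding eval_feature afterwards.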
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
            if len(all_results) % 1000 == 0:
                logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
        output_prediction_file = os.path.join(args.output_dir, "predictions.json")
        output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
        write_predictions(eval_examples, eval_features, all_results,
                          args.n_best_size, args.max_answer_length,
                          args.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file, args.verbose_logging, True, args.null_score_diff_threshold)
Beispiel #20
0
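# The evaluate() function below calls find_best_answer(example, start_probs, end_probs),
# which is not shown in this snippet. The helper below is only a hypothetical sketch of
# the usual span-selection rule: pick the passage and [start, end] span with the highest
# joint start/end probability, up to a maximum answer length. The real helper must also
# map token positions back past the question and special tokens that precede the passage
# in the model input; that offset handling is omitted here.
def find_best_answer_sketch(example, start_probs, end_probs, max_answer_len=64):
    best_score, best_answer, best_doc = 0.0, '', -1
    for p_num, (start_prob, end_prob) in enumerate(zip(start_probs, end_probs)):
        doc_tokens = example['doc_tokens'][p_num]['doc_tokens']
        for start in range(len(start_prob)):
            end_limit = min(start + max_answer_len, len(end_prob))
            for end in range(start, end_limit):
                score = float(start_prob[start]) * float(end_prob[end])
                if score > best_score:
                    best_score = score
                    best_doc = p_num
                    best_answer = ''.join(doc_tokens[start:end + 1])
    return best_answer, best_doc
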
def evaluate(models, result_file):
    with open(args.source_file_name,'rb') as f:
        eval_examples = pickle.load(f)

    with torch.no_grad():
        tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', do_lower_case=True)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        for model in models:
            model.to(device)
            model.eval()
        pred_answers, ref_answers = [], []

        for step, example in enumerate(tqdm(eval_examples)):
            start_probs, end_probs = [], []
            question_text = example['question_text']

            if len(example['doc_tokens']) != 0:
                for p_num, doc_tokens in enumerate(example['doc_tokens'][:args.max_para_num]):
                    (input_ids, input_mask, segment_ids) = \
                        predict_data.predict_data(question_text, doc_tokens['doc_tokens'], tokenizer, args.max_seq_length, args.max_query_length)
                    (input_ids, input_mask, segment_ids) = \
                        input_ids.to(device), input_mask.to(device), segment_ids.to(device)

                    # Ensemble: per-model probabilities are summed to obtain the ensemble's probabilities
                    start_prob, end_prob = None, None
                    for model in models:
                        if start_prob is None and end_prob is None:
                            start_prob, end_prob = model(input_ids, segment_ids, attention_mask=input_mask)
                        else:
                            tmp_start_prob, tmp_end_prob = model(input_ids, segment_ids, attention_mask=input_mask)
                            start_prob += tmp_start_prob
                            end_prob += tmp_end_prob

                    start_prob = start_prob.squeeze(0)
                    end_prob = end_prob.squeeze(0)
                    start_probs.append(start_prob)
                    end_probs.append(end_prob)
                best_answer, docs_index = find_best_answer(example, start_probs, end_probs)
#                best_answer = find_multiple_answers(example, start_probs, end_probs)
            else:
                best_answer = ''

            pred_answer = {'question_id': example['id'],
                                 'question':example['question_text'],
                                 'question_type': example['question_type'],
                                 'answers': [best_answer],
                                 'entity_answers': [[]],
                                 'yesno_answers': []}
            pred_answers.append(pred_answer)
            if 'answers' in example:
                ref_answers.append({'question_id': example['id'],
                                    'question_type': example['question_type'],
                                    'answers': example['answers'],
                                    'entity_answers': [[]],
                                    'yesno_answers': []})
        with open(result_file, 'w', encoding='utf-8') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
            print('save predicted data: ', result_file)
        with open("../metric/ref.json", 'w', encoding='utf-8') as fout:
            for pred_answer in ref_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
Beispiel #21
0
def main():
    global logger
    args, task_config = parse_input_parameter()

    random.seed(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    device, n_gpu = init_device(args)

    data_name = args.data_name.lower()

    if data_name in DATASET_DICT:
        args.train_file = DATASET_DICT[data_name]["train_file"]

        if args.do_eval:
            args.valid_file = DATASET_DICT[data_name]["valid_file"]

        if args.do_test:
            args.test_file = DATASET_DICT[data_name]["test_file"]
    else:
        assert args.train_file is not None

        if args.do_eval:
            assert args.valid_file is not None

        if args.do_test:
            assert args.test_file is not None

    task_name = args.task_name.lower()
    if task_name not in DATALOADER_DICT:
        raise ValueError("Task not found: %s" % (task_name))

    if n_gpu > 1 and (args.use_ghl):
        logger.warning("Multi-GPU make the results not reproduce.")

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    # Generate label list from training dataset
    file_path = os.path.join(args.data_dir, args.train_file)
    train_dataloader, train_examples, label_tp_list = DATALOADER_DICT[
        task_name]["train"](args, tokenizer, file_path)

    logging.info("AT Labels are = %s:",
                 "[" + ", ".join(label_tp_list[0]) + "]")
    logging.info("AS Labels are = %s:",
                 "[" + ", ".join(label_tp_list[1]) + "]")
    at_num_labels = len(label_tp_list[0])
    as_num_labels = len(label_tp_list[1])
    num_tp_labels = (at_num_labels, as_num_labels)

    task_config["at_labels"] = label_tp_list[0]
    model = init_model(args, num_tp_labels, task_config, device, n_gpu)

    # Generate test dataset
    logger.info("***** Running test *****")

    if data_name.startswith('joint'):
        # May need to perform testing on more than one domain
        test_dataloaders_list = []
        test_examples_list = []

        for td, test_file_name in enumerate(args.test_file):
            file_path = os.path.join(args.data_dir, test_file_name)
            test_dataloader, test_examples = DATALOADER_DICT[task_name][
                "eval"](args,
                        tokenizer,
                        file_path,
                        label_tp_list=label_tp_list,
                        set_type="test")

            test_dataloaders_list.append(test_dataloader)
            test_examples_list.append(test_examples)

            logger.info("  Domain %d", td)
            logger.info("    Num examples = %d", len(test_examples))
            logger.info("    Batch size = %d", args.eval_batch_size)
    else:
        file_path = os.path.join(args.data_dir, args.test_file)
        test_dataloader, test_examples = DATALOADER_DICT[task_name]["eval"](
            args,
            tokenizer,
            file_path,
            label_tp_list=label_tp_list,
            set_type="test")

        logger.info("  Num examples = %d", len(test_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

    if args.do_train:
        num_train_optimization_steps = (
            (len(train_dataloader) + args.gradient_accumulation_steps - 1) //
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

        optimizer = prep_optimizer(args, model, num_train_optimization_steps)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.do_eval:
            logger.info("***** Running evaluation *****")
            if data_name.startswith('joint'):
                # May need to perform evaluation on more than one domain
                eval_dataloaders_list = []
                eval_examples_list = []

                for ed, eval_file_name in enumerate(args.valid_file):
                    file_path = os.path.join(args.data_dir, eval_file_name)
                    eval_dataloader, eval_examples = DATALOADER_DICT[
                        task_name]["eval"](args,
                                           tokenizer,
                                           file_path,
                                           label_tp_list=label_tp_list,
                                           set_type="val")
                    eval_dataloaders_list.append(eval_dataloader)
                    eval_examples_list.append(eval_examples)

                    logger.info("  Domain %d", ed)
                    logger.info("    Num examples = %d", len(eval_examples))
                    logger.info("    Batch size = %d", args.eval_batch_size)
            else:
                file_path = os.path.join(args.data_dir, args.valid_file)
                eval_dataloader, eval_examples = DATALOADER_DICT[task_name][
                    "eval"](args,
                            tokenizer,
                            file_path,
                            label_tp_list=label_tp_list,
                            set_type="val")

                logger.info("  Num examples = %d", len(eval_examples))
                logger.info("  Batch size = %d", args.eval_batch_size)

        global_step = 0

        for epoch in range(args.num_train_epochs):
            tr_loss, global_step = train_epoch(epoch, args, model,
                                               train_dataloader, device, n_gpu,
                                               tokenizer, optimizer,
                                               global_step,
                                               num_train_optimization_steps)
            logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1,
                        args.num_train_epochs, tr_loss)
            save_model(epoch, args, model)

            if args.do_eval:
                if data_name.startswith('joint'):
                    for ed, eval_dataloader in enumerate(
                            eval_dataloaders_list):
                        logger.info("  Domain %d", ed)
                        eval_epoch(model, eval_dataloader, label_tp_list,
                                   device)
                else:
                    eval_epoch(model, eval_dataloader, label_tp_list, device)

        if args.do_test:
            logger.info("***Results on test***")
            if data_name.startswith('joint'):
                for td, test_dataloader in enumerate(test_dataloaders_list):
                    logger.info("  Domain %d", td)
                    eval_epoch(model, test_dataloader, label_tp_list, device)
            else:
                eval_epoch(model, test_dataloader, label_tp_list, device)

    elif args.do_test:
        if args.init_model:
            if data_name.startswith('joint'):
                for td, test_dataloader in enumerate(test_dataloaders_list):
                    logger.info("  Domain %d", td)
                    eval_epoch(model, test_dataloader, label_tp_list, device)
            else:
                eval_epoch(model, test_dataloader, label_tp_list, device)
        else:
            for epoch in range(args.num_train_epochs):
                # Load a trained model that you have fine-tuned
                model = load_model(epoch, args, num_tp_labels, task_config,
                                   device)
                if not model:
                    break

                if data_name.startswith('joint'):
                    for td, test_dataloader in enumerate(
                            test_dataloaders_list):
                        logger.info("  Domain %d", td)
                        eval_epoch(model, test_dataloader, label_tp_list,
                                   device)
                else:
                    eval_epoch(model, test_dataloader, label_tp_list, device)
Beispiel #22
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--model_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--output_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The output results will be written.")
    parser.add_argument("--data_file",
                        default='',
                        type=str,
                        help="The input directory of input data file.")

    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")

    parser.add_argument("--weights",
                        default='',
                        type=int,
                        help="The output results will be written.")
    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    eval_examples = processor.get_dev_examples(args.data_file)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length,
                                                 tokenizer)
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                 dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                              all_label_ids)
    # Run prediction for full data
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    for i in range(args.weights):
        print("epoch:{}".format(i))
        output_model_file = os.path.join(args.model_dir,
                                         "epoch" + str(i) + "_" + WEIGHTS_NAME)

        output_config_file = os.path.join(args.model_dir, CONFIG_NAME)
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        #model = BertForLMClassification(config, num_labels=num_labels)
        #model = BertForDClassifier(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
        model.to(device)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}

        output_eval_file = args.output_file
        with open(output_eval_file, "a") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Beispiel #23
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--sample",
        default=None,
        type=str,
        required=True,
        help="The sample task to run (e.g. 'developing').")
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--word_embedding_file",
                        default='./emb/wiki-news-300d-1M.vec',
                        type=str,
                        help="The input directory of word embeddings.")
    parser.add_argument("--index_path",
                        default='./emb/p_index.bin',
                        type=str,
                        help="The input directory of word embedding index.")
    parser.add_argument("--word_embedding_info",
                        default='./emb/vocab_info.txt',
                        type=str,
                        help="The input directory of word embedding info.")
    parser.add_argument("--data_file",
                        default='',
                        type=str,
                        help="The input directory of input data file.")

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--max_ngram_length",
                        default=16,
                        type=int,
                        help="The maximum total ngram sequence")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--embedding_size",
                        default=300,
                        type=int,
                        help="Total batch size for embeddings.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--num_eval_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of eval epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--single',
                        action='store_true',
                        help="Whether to evaluate only a single epoch")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                              do_lower_case=args.do_lower_case)

    # Parameters assigned directly in the code
    #DATA_DIR = 'data'
    DATA_DIR = './data/sst-2/train.tsv'  # For debugger
    IN_TEXT = 'cleaned_haiku.data'
    IN_W2V = 'w2v_haiku.model'

    # to-do: think of an open vocabulary system
    WORD_LIMIT = 9999  # remaining 1 for <PAD> (this is inclusive of UNK)
    #task_name = ""
    TARGET_PAD_IDX = -1
    INPUT_PAD_IDX = 0

    keyboard_mappings = None

    text = encoder = None
    print("args.num_train_epochs  :", args.num_train_epochs)

    def _create_examples(lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            #print("line :", line)
            #print("line[0]: ", line[0])
            #print("line[1]", line[1])
            flaw_labels = None
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[0]
            label = line[1]
            if len(line) > 2: flaw_labels = line[2]
            examples.append(
                InputExample(guid=guid,
                             text_a=text_a,
                             text_b=None,
                             label=label,
                             flaw_labels=flaw_labels))
        return examples

    def _create_examples_clean(lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            flaw_labels = None
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_raw = line[0]
            text_a = clean_text(text_raw)
            label = line[1]
            if len(line) > 2: flaw_labels = line[2]
            examples.append(
                InputExample(guid=guid,
                             text_a=text_a,
                             text_b=None,
                             label=label,
                             flaw_labels=flaw_labels))
        return examples

    def _read_tsv(input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                if sys.version_info[0] == 2:
                    line = list(unicode(cell, 'utf-8') for cell in line)
                lines.append(line)
            return lines

    def get_train_examples(DATA_DIR):
        """See base class."""
        if 'tsv' in DATA_DIR:
            return _create_examples(_read_tsv(DATA_DIR), "train")
        else:
            return _create_examples(
                _read_tsv(os.path.join(DATA_DIR, "train.tsv")), "train")

    def get_train_examples_clean(DATA_DIR):
        """See base class."""
        if 'tsv' in DATA_DIR:
            return _create_examples_clean(_read_tsv(DATA_DIR), "train")
        else:
            return _create_examples_clean(
                _read_tsv(os.path.join(DATA_DIR, "train.tsv")), "train")

    def get_text_from_train_examples(train_examples):
        all_text_train_examples = []
        for example in train_examples:
            all_text_train_examples.append(example.text_a)

        return all_text_train_examples

    def get_text_and_labels_train_examples(train_examples):
        all_text_train_examples = []
        all_labels_train_examples = []
        label_map = {label: i for i, label in enumerate(label_list)}
        for example in train_examples:
            all_text_train_examples.append(example.text_a)
            all_labels_train_examples.append(label_map[example.label])

        return all_text_train_examples, all_labels_train_examples


    def get_data():

        #train_examples = get_train_examples(args.data_dir)
        train_examples = get_train_examples_clean(args.data_dir)
        text, labels = get_text_and_labels_train_examples(train_examples)

        logger.info("Loading word embeddings ...in Gensim format ")

        text_new = [txt.split() for txt in text]

        encoder = FastText(text_new, min_count=1, size=128)

        return text_new, text, encoder, labels

    def get_lines(start, end, text, encoder):

        seq_lens = []
        sentences = []
        longest = 0
        text_batch = []
        for i in range(start, end):
            text_batch.append(text[i])
        for l in text_batch:
            seq_lens.append(len(l))
            longest = len(l) if len(l) > longest else longest
            sentence = []
            for txt in l:
                sentence.append(torch.tensor(encoder.wv[txt]))
            sentences.append(torch.stack(sentence).unsqueeze(0))

        # Pad input
        d_size = sentences[0].size(2)
        print("sentences: ", type(sentences))
        print("sentences len: ", len(sentences))
        for i in range(len(sentences)):
            sl = sentences[i].size(1)

            if sl < longest:
                sentences[i] = torch.cat(
                    [sentences[i],
                     torch.zeros(1, longest - sl, d_size)],
                    dim=1)

        # Need to squish sentences into [0,1] domain
        seq = torch.cat(sentences, dim=0)
        # seq = torch.sigmoid(seq)
        start_words = seq[:, 0:1, :]
        packer = pack_padded_sequence(seq,
                                      seq_lens,
                                      batch_first=True,
                                      enforce_sorted=False)
        return packer, start_words
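
    # A usage sketch (an assumption, not from the original): the PackedSequence
    # returned above can be fed directly to a recurrent module and unpacked
    # afterwards, e.g.
    #
    #   gru = nn.GRU(input_size=128, hidden_size=128, batch_first=True)
    #   packed_out, h_n = gru(packer)
    #   out, lens = pad_packed_sequence(packed_out, batch_first=True)
    #
    # (pad_packed_sequence lives in torch.nn.utils.rnn, alongside the
    # pack_padded_sequence used above.)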

    """ word representation from bag of chars
    """

    def get_boc_word_representation(word):
        return zero_vector() + bag_of_chars(word) + zero_vector()

    def one_hot(char):
        return [1.0 if ch == char else 0.0 for ch in CHAR_VOCAB]

    def bag_of_chars(chars):
        return [float(chars.count(ch)) for ch in CHAR_VOCAB]

    def zero_vector():
        return [0.0 for _ in CHAR_VOCAB]

    """ word representation from individual chars
        one hot (first char) + bag of chars (middle chars) + one hot (last char)
    """

    def get_swap_word_representation(word):

        # dirty case
        if len(word) == 1 or len(word) == 2:
            rep = one_hot(word[0]) + zero_vector() + one_hot(
                word[-1])  # Return value used
            return rep, word

        rep = one_hot(word[0]) + bag_of_chars(word[1:-1]) + one_hot(
            word[-1])  # Return value used
        if len(word) > 3:
            idx = random.randint(1, len(word) - 3)
            word = word[:idx] + word[idx + 1] + word[idx] + word[
                idx + 2:]  # return value not used

        return rep, word
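
    # Worked example (illustrative): for word = "cat" (length 3, so no swap),
    # get_swap_word_representation returns
    #   one_hot('c') + bag_of_chars("a") + one_hot('t'),
    # a vector of length 3 * len(CHAR_VOCAB). For words longer than 3
    # characters, two adjacent middle characters are also swapped, but the
    # representation is computed before the swap.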

    def get_closest(sentences):
        scores = []
        wv = encoder.wv
        for s in sentences.detach().numpy():
            st = [
                wv[wv.most_similar([s[i]], topn=1)[0][0]]
                for i in range(s.shape[0])
            ]
            scores.append(torch.tensor(st))

        return torch.stack(scores, dim=0)

    torch.set_num_threads(16)
    device = torch.device("cpu")

    sample_task = args.sample.lower()

    def adversarial_attacks(start, end, encoder):

        train_examples_batch = processor.get_train_examples_for_attacks(
            args.data_dir, start, end)
        features_for_attacks, w2i_disp, i2w_disp, vocab_size = convert_examples_to_features_gan2vec(
            train_examples_batch, label_list, tokenizer=None, max_seq_length=6)

        all_tokens = torch.tensor([f.token_ids for f in features_for_attacks],
                                  dtype=torch.long)
        all_label_id = torch.tensor([f.label_id for f in features_for_attacks],
                                    dtype=torch.long)

        data_for_attacks = TensorDataset(all_tokens, all_label_id)
        sampler_for_attacks = SequentialSampler(data_for_attacks)

        dataloader_for_attack = DataLoader(data_for_attacks,
                                           sampler=sampler_for_attacks)

        all_batch_flaw_tokens = []
        all_batch_flaw_labels = []
        all_batch_flaw_labels_truth = []
        flaw_labels_lst = []
        for step, batch in enumerate(
                tqdm(dataloader_for_attack, desc="Iteration")):
            batch = tuple(t.to(device) for t in batch)
            tokens, _ = batch  # , label_id, ngram_ids, ngram_labels, ngram_masks

            tokens = tokens.to('cpu').numpy()

            features_with_flaws, all_flaw_tokens, all_token_idx, all_truth_tokens, all_flaw_labels_truth = convert_examples_to_features_flaw_attacks_gr(
                tokens,
                args.max_seq_length,
                args.max_ngram_length,
                i2w_disp,
                tokenizer,
                embeddings=None,
                emb_index=None,
                words=None)

            all_token_idx = ",".join(
                [str(id) for tok in all_token_idx for id in tok])
            all_truth_tokens_flat = ' '.join(
                [str(id) for tok in all_truth_tokens for id in tok])

            flaw_ids = torch.tensor([f.flaw_ids for f in features_with_flaws])
            flaw_labels = torch.tensor(
                [f.flaw_labels for f in features_with_flaws])

            print("all_flaw_tokens : before before ", all_flaw_tokens)
            all_flaw_tokens = ' '.join(
                [str(y) for x in all_flaw_tokens for y in x])

            flaw_ids_ar = flaw_ids.detach().cpu().numpy()
            flaw_ids_lst = flaw_ids.tolist()
            flaw_labels_ar = flaw_labels.detach().cpu().numpy()
            flaw_labels_lst = flaw_labels.tolist()
            all_flaw_tokens = all_flaw_tokens.strip("''").strip("``")
            all_truth_tokens_flat = all_truth_tokens_flat.strip("''").strip(
                "``")
            all_batch_flaw_tokens.append(all_flaw_tokens)
            all_batch_flaw_labels.append(flaw_labels_lst)
            all_batch_flaw_labels_truth.append(all_flaw_labels_truth)

        batch_tx = []
        BATCH_SEQ_LEN = []
        Xtype = torch.FloatTensor
        for line in all_batch_flaw_tokens:
            SEQ_LEN = len(line.split())
            line = line.lower()
            X = get_target_representation(line, encoder)
            batch_tx.append(X)
            BATCH_SEQ_LEN.append(SEQ_LEN)
        X_t = torch.tensor(batch_tx, dtype=torch.float)
        real_adv = pack_padded_sequence(X_t, BATCH_SEQ_LEN, batch_first=True)
        all_batch_flaw_labels_truth_t = torch.tensor(
            all_batch_flaw_labels_truth, dtype=torch.long)
        all_batch_flaw_labels_truth_t_s = torch.squeeze(
            all_batch_flaw_labels_truth_t)
        return X_t, all_batch_flaw_labels_truth_t_s

    def get_loss():
        return loss_nll, loss_nll
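
    # A hedged sketch of the `loss_nll` criterion returned above (its definition
    # is not in this excerpt). RobGAN-style training combines a real/fake term
    # with a per-token classification term weighted by `lam`; one plausible form
    # consistent with the call sites below (soft binary targets of shape
    # [bs, 1], long per-token labels of shape [bs, seq_len]) is:
    #
    #   def loss_nll(bin_logits, bin_targets, multi_logits, multi_targets, lam=0.5):
    #       bce = nn.BCEWithLogitsLoss()(bin_logits, bin_targets)
    #       ce = nn.CrossEntropyLoss()(
    #           multi_logits.reshape(-1, multi_logits.size(-1)),
    #           multi_targets.reshape(-1))
    #       return lam * bce + (1.0 - lam) * ce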

    def train(epochs, batch_size=256, latent_size=256, K=1):
        text, text_orig, encoder, labels = get_data()
        num_samples = len(text)
        create_vocab(args.data_dir, text_orig)
        G = Generator(128, 128)
        D = Discriminator(128, len(CHAR_VOCAB), encoder)

        l2 = nn.MSELoss()
        loss = get_loss()
        opt_d = Adam(D.parameters(), lr=0.002, betas=(0.5, 0.999))
        opt_g = Adam(G.parameters(), lr=0.002, betas=(0.5, 0.999))

        max_seq_len = args.max_seq_length
        for e in range(epochs):
            i = 0

            while batch_size * i < num_samples:
                stime = time.time()

                start = batch_size * i
                end = min(batch_size * (i + 1), num_samples)
                bs = end - start
                # Fixed labels
                test_seq_length = 6
                zeros = torch.zeros(bs, test_seq_length, dtype=torch.long)
                ones = torch.ones(bs, test_seq_length, dtype=torch.long)

                tl = torch.full((bs, 1), 0.9)
                fl = torch.full((bs, 1), 0.1)

                real, greal = get_lines(start, end, text, encoder)

                # Train Generator as per RobGAN
                for _ in range(K):
                    opt_g.zero_grad()

                    # GAN fooling ability
                    fake = G(greal)
                    d_fake_bin, d_fake_multi = D(fake)
                    g_loss = loss_nll(d_fake_bin,
                                      tl,
                                      d_fake_multi,
                                      zeros,
                                      lam=0.5)
                    g_loss.backward()
                    opt_g.step()

                g_loss = g_loss.item()

                # Train discriminator
                opt_d.zero_grad()
                fake = G(greal)

                d_fake_bin_d, d_fake_multi_d = D(fake)
                d_f_loss = loss_nll(d_fake_bin_d,
                                    fl,
                                    d_fake_multi_d,
                                    ones,
                                    lam=0.5)
                real_adv, flaw_labels = adversarial_attacks(
                    start, end, encoder)

                d_adv_bin, d_adv_multi = D(real_adv)
                d_adv_loss = loss_nll(d_adv_bin,
                                      tl,
                                      d_adv_multi,
                                      flaw_labels,
                                      lam=0.5)
                d_loss_total = d_adv_loss + d_f_loss

                d_loss_total.backward()
                opt_d.step()

                i += 1

        torch.save(D, 'Discriminator.model')
        torch.save(G, 'generator.model')

    if sample_task == 'developing':
        train(10, batch_size=256)
Beispiel #24
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--word_embedding_file",
                        default='emb/crawl-300d-2M.vec',
                        type=str,
                        help="The input directory of word embeddings.")
    parser.add_argument("--index_path",
                        default='emb/p_index.bin',
                        type=str,
                        help="The input directory of word embedding index.")
    parser.add_argument("--word_embedding_info",
                        default='emb/vocab_info.txt',
                        type=str,
                        help="The input directory of word embedding info.")
    parser.add_argument("--data_file",
                        default='',
                        type=str,
                        help="The input directory of input data file.")

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--max_ngram_length",
                        default=16,
                        type=int,
                        help="The maximum total ngram sequence")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--embedding_size",
                        default=300,
                        type=int,
                        help="Total batch size for embeddings.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--num_eval_epochs',
        type=int,
        default=0,
        help="Total number of eval epochs to perform.")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--single',
        action='store_true',
        help="Whether to evaluate only a single epoch")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    logger.info("loading embeddings ... ")
    if args.do_train:
        emb_dict, emb_vec, vocab_list, emb_vocab_size = load_vectors(
            args.word_embedding_file)
        write_vocab_info(args.word_embedding_info, emb_vocab_size, vocab_list)
    if args.do_eval:
        emb_vocab_size, vocab_list = load_vocab_info(args.word_embedding_info)
        #emb_dict, emb_vec, vocab_list, emb_vocab_size = load_vectors(args.word_embedding_file)
        #write_vocab_info(args.word_embedding_info, emb_vocab_size, vocab_list)
    logger.info("loading p index ...")
    if not os.path.exists(args.index_path):
        p = load_embeddings_and_save_index(range(emb_vocab_size), emb_vec,
                                           args.index_path)
    else:
        p = load_embedding_index(args.index_path,
                                 emb_vocab_size,
                                 num_dim=args.embedding_size)

    train_examples = None
    num_train_optimization_steps = None
    w2i, i2w, vocab_size = {}, {}, 1
    if args.do_train:

        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )
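        # Worked example: with 10,000 training examples, train_batch_size 32,
        # gradient_accumulation_steps 1 and num_train_epochs 3.0, this yields
        # int(10000 / 32 / 1) * 3.0 = 312 * 3.0 = 936.0 optimization steps
        # (then floor-divided by the world size under distributed training).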

        train_features, w2i, i2w, vocab_size = convert_examples_to_features_gnrt_train(\
            train_examples, label_list, args.max_seq_length, args.max_ngram_length, tokenizer, emb_dict)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Num token vocab = %d", vocab_size)
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_ngram_ids = torch.tensor([f.ngram_ids for f in train_features],
                                     dtype=torch.long)
        all_ngram_labels = torch.tensor(
            [f.ngram_labels for f in train_features], dtype=torch.long)
        all_ngram_masks = torch.tensor([f.ngram_masks for f in train_features],
                                       dtype=torch.long)
        all_ngram_embeddings = torch.tensor(
            [f.ngram_embeddings for f in train_features], dtype=torch.float)

        # Prepare model
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(
            PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(
                args.local_rank))
        model = BertForNgramClassification.from_pretrained(
            args.bert_model,
            cache_dir=cache_dir,
            num_labels=num_labels,
            embedding_size=args.embedding_size,
            max_seq_length=args.max_seq_length,
            max_ngram_length=args.max_ngram_length)
        model.to(device)
        if args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            model = DDP(model)
        elif n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
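        # Two parameter groups: 0.01 weight decay for most parameters, and no
        # decay for biases and LayerNorm weights (the standard BERT fine-tuning
        # setup).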
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

        global_step = 0
        nb_tr_steps = 0
        tr_loss = 0


        train_data = TensorDataset(all_ngram_ids, all_ngram_labels,
                                   all_ngram_masks, all_ngram_embeddings)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for ind in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_steps = 0

            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                ngram_ids, ngram_labels, ngram_masks, ngram_embeddings = batch
                loss = model(ngram_ids, ngram_masks, ngram_embeddings)
                if n_gpu > 1:
                    loss = loss.mean()

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()
                tr_loss += loss.item()

                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:

                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            loss = tr_loss / nb_tr_steps if args.do_train else None
            result = {
                'loss': loss,
            }

            output_eval_file = os.path.join(args.output_dir,
                                            "train_results.txt")
            with open(output_eval_file, "a") as writer:
                #logger.info("***** Training results *****")
                writer.write("epoch" + str(ind) + '\n')
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
                writer.write('\n')

            model_to_save = model.module if hasattr(model, 'module') else model
            output_model_file = os.path.join(args.output_dir,
                                             "epoch" + str(ind) + WEIGHTS_NAME)
            torch.save(model_to_save.state_dict(), output_model_file)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
            with open(output_config_file, 'w') as f:
                f.write(model_to_save.config.to_json_string())

    # Load a trained model and config that you have fine-tuned
    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):

        eval_examples = processor.get_gnrt_dev_examples(args.data_file)
        eval_features, w2i, i2w, vocab_size = convert_examples_to_features_gnrt_eval(
            eval_examples, label_list, args.max_seq_length,
            args.max_ngram_length, tokenizer, w2i, i2w, vocab_size)

        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Num token vocab = %d", vocab_size)
        logger.info("  Batch size = %d", args.eval_batch_size)

        all_token_ids = torch.tensor([f.token_ids for f in eval_features],
                                     dtype=torch.long)
        # all_flaw_labels: indexes of wrong words predicted by disc
        all_flaw_labels = torch.tensor([f.flaw_labels for f in eval_features],
                                       dtype=torch.long)
        all_ngram_ids = torch.tensor([f.ngram_ids for f in eval_features],
                                     dtype=torch.long)
        all_ngram_mask = torch.tensor([f.ngram_mask for f in eval_features],
                                      dtype=torch.long)
        all_ngram_labels = torch.tensor(
            [f.ngram_labels for f in eval_features], dtype=torch.long)
        all_label_id = torch.tensor([f.label_id for f in eval_features],
                                    dtype=torch.long)

        eval_data = TensorDataset(all_token_ids, all_ngram_ids, all_ngram_mask,
                                  all_ngram_labels, all_label_id,
                                  all_flaw_labels)

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        if args.single:
            eval_range = trange(int(args.num_eval_epochs),
                                int(args.num_eval_epochs + 1),
                                desc="Epoch")
        else:
            eval_range = trange(int(args.num_eval_epochs), desc="Epoch")

        for epoch in eval_range:

            output_file = os.path.join(
                args.data_dir, "epoch" + str(epoch) + "gnrt_outputs.tsv")
            with open(output_file, "w") as csv_file:
                writer = csv.writer(csv_file, delimiter='\t')
                writer.writerow(["sentence", "label"])

            output_model_file = os.path.join(
                args.output_dir, "epoch" + str(epoch) + WEIGHTS_NAME)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
            config = BertConfig(output_config_file)
            model = BertForNgramClassification(
                config,
                num_labels=num_labels,
                embedding_size=args.embedding_size,
                max_seq_length=args.max_seq_length,
                max_ngram_length=args.max_ngram_length)
            model.load_state_dict(torch.load(output_model_file))
            model.to(device)
            model.eval()

            for token_ids, ngram_ids, ngram_mask, ngram_labels, label_id, flaw_labels in tqdm(
                    eval_dataloader, desc="Evaluating"):

                ngram_ids = ngram_ids.to(device)
                ngram_mask = ngram_mask.to(device)

                with torch.no_grad():
                    logits = model(ngram_ids, ngram_mask)

                logits = logits.detach().cpu().numpy()
                flaw_labels = flaw_labels.to('cpu').numpy()
                label_id = label_id.to('cpu').numpy()
                token_ids = token_ids.to('cpu').numpy()
                masks = ngram_mask.to('cpu').numpy()

                with open(output_file, "a") as csv_file:

                    for i in range(len(label_id)):

                        correct_tokens = look_up_words(logits[i], masks[i],
                                                       vocab_list, p)
                        token_new = replace_token(token_ids[i], flaw_labels[i],
                                                  correct_tokens, i2w)
                        token_new = ' '.join(token_new)
                        label = str(label_id[i])
                        writer = csv.writer(csv_file, delimiter='\t')
                        writer.writerow([token_new, label])
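
# `look_up_words` and `replace_token` are helpers not defined in this excerpt.
# Given the comment above that flaw_labels hold the indexes of words flagged by
# the discriminator, `replace_token` plausibly substitutes the corrected tokens
# at those positions; a rough sketch (index conventions are assumptions):
#
#   def replace_token(token_ids, flaw_labels, correct_tokens, i2w):
#       tokens = [i2w[idx] for idx in token_ids if idx != 0]  # 0 assumed padding
#       for pos, new_tok in zip([f for f in flaw_labels if f >= 0],
#                               correct_tokens):
#           if pos < len(tokens):
#               tokens[pos] = new_tok
#       return tokens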
Beispiel #25
def main():
    print(torch.cuda.is_available())
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--input_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input train corpus. can be directory with .txt files or a path to a single file"
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output file where the model checkpoints will be written.")

    ## Other parameters

    # str
    parser.add_argument(
        "--mode",
        type=str,
        help="Masking mode: 'rand', 'rule', or 'model'.")

    # str
    parser.add_argument(
        "--bert_model",
        default="bert-large-uncased",
        type=str,
        required=False,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--task_name", default="", type=str, required=False)
    parser.add_argument("--local_rank", default=0, type=int)

    # int
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--dupe_factor",
        default=1,
        type=int,
        help=
        "Number of times to duplicate the input data (with different masks).")
    parser.add_argument("--max_predictions_per_seq",
                        default=20,
                        type=int,
                        help="Maximum sequence length.")
    parser.add_argument("--sentence_batch_size", default=32, type=int)
    parser.add_argument("--top_sen_rate", default=0.8, type=float)
    parser.add_argument("--threshold", default=0.2, type=float)

    # floats

    parser.add_argument("--masked_lm_prob",
                        default=0.15,
                        type=float,
                        help="Masked LM probability.")

    parser.add_argument(
        "--short_seq_prob",
        default=0.1,
        type=float,
        help=
        "Probability to create a sequence shorter than maximum sequence length"
    )

    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument('--random_seed',
                        type=int,
                        default=12345,
                        help="random seed for initialization")
    parser.add_argument('--part', type=int, default=0)
    parser.add_argument('--max_proc', type=int, default=1)
    parser.add_argument('--with_rand', action='store_true')
    parser.add_argument('--split_part', type=int)

    args = parser.parse_args()
    print(args)
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    rng = random.Random(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    print("creating instance from {}".format(args.input_dir))
    processor = processors[args.task_name]()
    eval_examples = processor.get_pretrain_examples(args.input_dir, args.part,
                                                    args.max_proc)
    if args.task_name == "absa" or args.task_name == "absa_term":
        data = eval_examples
        all_labels = None
    else:
        data = [example.text_a for example in eval_examples]
        all_labels = [example.label for example in eval_examples]

    del eval_examples

    label_list = processor.get_labels()
    logger.info("Bert Model: {}".format(args.bert_model))

    if args.mode == "rand":
        print("Mode: rand")
        generator = RandMask(args.masked_lm_prob, args.bert_model,
                             args.do_lower_case, args.max_seq_length)
    elif args.mode == "rule":
        print("Mode: rule")
        if args.task_name == "absa" or args.task_name == "absa_term":
            generator = ASC(args.masked_lm_prob, args.top_sen_rate,
                            args.threshold, args.bert_model,
                            args.do_lower_case, args.max_seq_length,
                            label_list, args.sentence_batch_size)
        else:
            generator = SC(args.masked_lm_prob, args.top_sen_rate,
                           args.threshold, args.bert_model, args.do_lower_case,
                           args.max_seq_length, label_list,
                           args.sentence_batch_size)
    else:
        print("Mode: model")
        generator = ModelGen(args.masked_lm_prob,
                             args.bert_model,
                             args.do_lower_case,
                             args.max_seq_length,
                             args.sentence_batch_size,
                             with_rand=args.with_rand)

    if args.with_rand:
        instances, rand_instances, labeled_data = create_training_instances(
            data,
            all_labels,
            args.task_name,
            generator,
            args.max_seq_length,
            args.dupe_factor,
            args.short_seq_prob,
            args.masked_lm_prob,
            args.max_predictions_per_seq,
            rng,
            with_rand=args.with_rand)
    else:
        instances, labeled_data = create_training_instances(
            data,
            all_labels,
            args.task_name,
            generator,
            args.max_seq_length,
            args.dupe_factor,
            args.short_seq_prob,
            args.masked_lm_prob,
            args.max_predictions_per_seq,
            rng,
            with_rand=args.with_rand)

    if args.part >= 0:
        output_file = os.path.join(args.output_dir, "model",
                                   "{}.hdf5".format(args.part))
        if args.with_rand:
            rand_output_file = os.path.join(args.output_dir, "rand",
                                            "{}.hdf5".format(args.part))
        labeled_output_file = os.path.join(args.output_dir,
                                           "{}.pkl".format(args.part))
    else:
        output_file = os.path.join(args.output_dir, "model", "0.hdf5")
        if args.with_rand:
            rand_output_file = os.path.join(args.output_dir, "rand", "0.hdf5")
        labeled_output_file = os.path.join(args.output_dir, "0.pkl")

    if args.mode == "rule":
        print("Writing labeled data(.pkl) for rule mode")
        write_labeled_data(labeled_data, labeled_output_file)
    else:
        print("Writing masked data(.hdf5) for model mode")
        if args.with_rand:
            print("Num instances: {}. Num rand instance: {}".format(
                len(instances), len(rand_instances)))
            write_instance_to_example_file(instances, tokenizer,
                                           args.max_seq_length,
                                           args.max_predictions_per_seq,
                                           output_file)
            write_instance_to_example_file(rand_instances, tokenizer,
                                           args.max_seq_length,
                                           args.max_predictions_per_seq,
                                           rand_output_file)
        else:
            print("Num instances: {}.".format(len(instances)))
            write_instance_to_example_file(instances, tokenizer,
                                           args.max_seq_length,
                                           args.max_predictions_per_seq,
                                           labeled_output_file)
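
# A hypothetical invocation of the script above (the script name and paths are
# assumptions, not from the original):
#
#   python create_pretraining_data.py \
#       --input_dir ./data/absa \
#       --output_dir ./out \
#       --task_name absa \
#       --mode rule \
#       --bert_model bert-base-uncased \
#       --do_lower_case \
#       --part 0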
Beispiel #26
def main():
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        choices=["WSD"],
                        help="The name of the task to train.")
    parser.add_argument("--train_data_dir",
                        default=None,
                        type=str,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--eval_data_dir",
                        default=None,
                        type=str,
                        help="The label data dir. (./wordnet)")
    parser.add_argument("--label_data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The label data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help='''a path or url to a pretrained model archive containing:
                        'bert_config.json' a configuration file for the model
                        'pytorch_model.bin' a PyTorch dump of a BertForPreTraining instance''')
    
    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")        
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run test on the test set.")            
    parser.add_argument("--do_lower_case",
                        default=False,
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', 
                        type=int, 
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumualte before performing a backward/update pass.")                       
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')


    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    
    if not args.do_train and not args.do_test:
        raise ValueError("At least one of `do_train` or `do_test` must be True.")
    if args.do_train:
        assert args.train_data_dir is not None, "train_data_dir cannot be None"
    if args.do_eval:
        assert args.eval_data_dir is not None, "eval_data_dir cannot be None"

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    # prepare dataloaders
    processors = {
        "WSD":WSDProcessor
    }

    output_modes = {
        "WSD": "classification"
    }

    processor = processors[args.task_name]()
    output_mode = output_modes[args.task_name]
    label_list = processor.get_labels(args.label_data_dir)
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    # training set
    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.train_data_dir, args.label_data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
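        # Worked example (hypothetical numbers): 10,000 training examples, batch
        # size 32, no gradient accumulation, 3 epochs ->
        # int(10000 / 32 / 1) * 3.0 = 936 optimizer steps.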
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()


    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
    model = BertForTokenClassification.from_pretrained(args.bert_model,
              cache_dir=cache_dir,
              num_labels=num_labels)
 
    if args.fp16:
        model.half()
    model.to(device)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    
    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)
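
    # `warmup_linear`, used in the fp16 branch of the training loop below, is
    # assumed to be the classic pytorch-pretrained-bert schedule: linear warmup
    # over the first `warmup` fraction of training, then linear decay to zero.
    # A minimal sketch under that assumption:
    #
    #     def warmup_linear(x, warmup=0.002):
    #         if x < warmup:
    #             return x / warmup
    #         return 1.0 - x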



    # load data
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_target_mask = torch.tensor([f.target_mask for f in train_features], dtype=torch.long)
        all_index_start = torch.tensor([f.index_start for f in train_features], dtype=torch.long)
        all_index_end = torch.tensor([f.index_end for f in train_features], dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)

        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_target_mask, all_index_start, all_index_end)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)


    if args.do_eval:
        eval_examples = processor.get_dev_examples(args.eval_data_dir, args.label_data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_target_mask = torch.tensor([f.target_mask for f in eval_features], dtype=torch.long)
        all_label_mask = torch.tensor([f.label_mask for f in eval_features], dtype=torch.float)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)

        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_target_mask, all_label_mask)
        eval_dataloader = DataLoader(eval_data, batch_size=args.eval_batch_size, shuffle=False)




    # train
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0

    if args.do_train:
        model.train()
        epoch = 0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            epoch += 1
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids, target_mask, index_start, index_end = batch

                # Build a per-example label mask: candidate senses in
                # [index_start, index_end) stay at 0, every other label is set
                # to -inf so it vanishes under softmax.
                all_label_mask = []
                for i in range(len(index_start)):
                    label_mask = [float("-inf")] * len(label_list)
                    # The inner loop must not reuse `i`: doing so clobbers the
                    # batch index and corrupts the remaining iterations.
                    for j in range(index_start[i][0].item(), index_end[i][0].item()):
                        label_mask[j] = 0
                    all_label_mask.append(label_mask)

                all_label_mask = torch.tensor(all_label_mask, dtype=torch.float).to(device)

                logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None, target_mask=target_mask)

                logits = logits + all_label_mask
                # Pass the masked raw logits straight to the loss below:
                # CrossEntropyLoss applies log-softmax internally, so an extra
                # softmax here would double-normalize and flatten the gradients
                # (the eval loop also computes its loss on masked raw logits).
                
                
                if output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
                elif output_mode == "regression":
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                

            # Save a trained model, configuration and tokenizer
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself

            # If we save using the predefined names, we can load using `from_pretrained`
            model_output_dir = os.path.join(args.output_dir, str(epoch))
            if not os.path.exists(model_output_dir):
                os.makedirs(model_output_dir)
            output_model_file = os.path.join(model_output_dir, WEIGHTS_NAME)
            output_config_file = os.path.join(model_output_dir, CONFIG_NAME)

            torch.save(model_to_save.state_dict(), output_model_file)
            model_to_save.config.to_json_file(output_config_file)
            tokenizer.save_vocabulary(model_output_dir)
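            # Because the checkpoint uses the predefined WEIGHTS_NAME/CONFIG_NAME
            # files plus the saved vocabulary, it can be reloaded later with
            # `from_pretrained` (a sketch, assuming the same label set):
            #
            #     model = BertForTokenClassification.from_pretrained(
            #         model_output_dir, num_labels=num_labels)
            #     tokenizer = BertTokenizer.from_pretrained(model_output_dir)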



            if args.do_eval:
                model.eval()
                eval_loss, eval_accuracy = 0, 0
                nb_eval_steps, nb_eval_examples = 0, 0

                with open(os.path.join(args.output_dir, "results_"+str(epoch)+".txt"),"w") as f:
                    for input_ids, input_mask, segment_ids, label_ids, target_mask, label_mask in tqdm(eval_dataloader, desc="Evaluating"):
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        label_ids = label_ids.to(device)
                        target_mask = target_mask.to(device)
                        label_mask = label_mask.to(device)

                        with torch.no_grad():
                            logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None, target_mask=target_mask)

                        logits = logits + label_mask
                        logits_ = F.softmax(logits, dim=-1)
                        logits_ = logits_.detach().cpu().numpy()
                        label_ids_ = label_ids.to('cpu').numpy()
                        outputs = np.argmax(logits_, axis=1)
                        for output_i in range(len(outputs)):
                            f.write(str(outputs[output_i]))
                            f.write("\n")
                        tmp_eval_accuracy = np.sum(outputs == label_ids_)

                        # create eval loss and other metric required by the task
                        if output_mode == "classification":
                            loss_fct = CrossEntropyLoss()
                            tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
                        elif output_mode == "regression":
                            loss_fct = MSELoss()
                            tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))
                        
                        eval_loss += tmp_eval_loss.mean().item()
                        eval_accuracy += tmp_eval_accuracy
                        nb_eval_examples += input_ids.size(0)
                        nb_eval_steps += 1

                eval_loss = eval_loss / nb_eval_steps
                eval_accuracy = eval_accuracy / nb_eval_examples
                loss = tr_loss/nb_tr_steps if args.do_train else None

                result = OrderedDict()
                result['eval_loss'] = eval_loss
                result['eval_accuracy'] = eval_accuracy
                result['global_step'] = global_step
                result['loss'] = loss

                output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
                with open(output_eval_file, "a+") as writer:
                    writer.write("epoch=%s\n"%str(epoch))
                    logger.info("***** Eval results *****")
                    for key in result.keys():
                        logger.info("  %s = %s", key, str(result[key]))
                        writer.write("%s = %s\n" % (key, str(result[key])))




    if args.do_test and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
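        # NOTE: despite --do_test, this block loads the dev split via
        # get_dev_examples; point --eval_data_dir at the test files to evaluate
        # on an actual test set.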
        eval_examples = processor.get_dev_examples(args.eval_data_dir, args.label_data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_target_mask = torch.tensor([f.target_mask for f in eval_features], dtype=torch.long)
        all_label_mask = torch.tensor([f.label_mask for f in eval_features], dtype=torch.float)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)

        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_target_mask, all_label_mask)
        eval_dataloader = DataLoader(eval_data, batch_size=args.eval_batch_size, shuffle=False)



        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        with open(os.path.join(args.output_dir, "results.txt"),"w") as f:
            for input_ids, input_mask, segment_ids, label_ids, target_mask, label_mask in tqdm(eval_dataloader, desc="Evaluating"):
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)
                target_mask = target_mask.to(device)
                label_mask = label_mask.to(device)

                with torch.no_grad():
                    logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None, target_mask=target_mask)

                logits = logits + label_mask
                logits_ = F.softmax(logits, dim=-1)
                logits_ = logits_.detach().cpu().numpy()
                label_ids_ = label_ids.to('cpu').numpy()
                outputs = np.argmax(logits_, axis=1)
                for output_i in range(len(outputs)):
                    f.write(str(outputs[output_i]))
                    f.write("\n")
                tmp_eval_accuracy = np.sum(outputs == label_ids_)

                # create eval loss and other metric required by the task
                if output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
                elif output_mode == "regression":
                    loss_fct = MSELoss()
                    tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))
                
                eval_loss += tmp_eval_loss.mean().item()
                eval_accuracy += tmp_eval_accuracy
                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss/nb_tr_steps if args.do_train else None

        result = OrderedDict()
        result['eval_loss'] = eval_loss
        result['eval_accuracy'] = eval_accuracy
        result['global_step'] = global_step
        result['loss'] = loss

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "a+") as writer:
            logger.info("***** Eval results *****")
            for key in result.keys():
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Beispiel #27
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")

    parser.add_argument("--train_vae",
                        action='store_true',
                        help="Whether to train vae.")

    parser.add_argument('--bleu', type=float, default=0.2,
                        help="Coefficient forwarded to the decoder as its `bleu` argument.")

    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset',
                        type=int,
                        default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Ignore the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=1, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=1)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument('--not_predict_token',
                        type=str,
                        default=None,
                        help="Do not predict the tokens during decoding.")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode="s2s",
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    def _get_token_id_set(s):
        r = None
        if s:
            w_list = []
            for w in s.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            r = set(tokenizer.convert_tokens_to_ids(w_list))
        return r
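
    # Example (hypothetical input): --forbid_ignore_word ".|[sep]" yields the
    # token ids of "." and "[SEP]" (bracketed names are upper-cased so they
    # match BERT's special tokens).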

    forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word)
    not_predict_set = _get_token_id_set(args.not_predict_token)
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=pair_num_relation,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            not_predict_set=not_predict_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            ffn_type=args.ffn_type,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            pos_shift=args.pos_shift)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
        next_i = 0
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        with open(args.input_file, encoding="utf-8") as fin:
            input_lines = [x.strip() for x in fin.readlines()]
            if args.subset > 0:
                logger.info("Decoding subset: %d", args.subset)
                input_lines = input_lines[:args.subset]
        data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
        input_lines = [
            data_tokenizer.tokenize(x)[:max_src_length] for x in input_lines
        ]
        input_lines = sorted(list(enumerate(input_lines)),
                             key=lambda x: -len(x[1]))
        output_lines = [""] * len(input_lines)
        score_trace_list = [None] * len(input_lines)
        total_batch = math.ceil(len(input_lines) / args.batch_size)
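        # Inputs were sorted longest-first above so each decoding batch is
        # roughly uniform in length; `buf_id` keeps the original positions so
        # `output_lines` is written back in input order.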

        with tqdm(total=total_batch) as pbar:
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[1] for x in _chunk]
                next_i += args.batch_size
                max_a_len = max([len(x) for x in buf])
                instances = []
                for instance in [(x, max_a_len) for x in buf]:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance))
                with torch.no_grad():
                    batch = seq2seq_loader.batch_list_to_batch_tensors(
                        instances)
                    batch = [
                        t.to(device) if t is not None else None for t in batch
                    ]
                    input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                    traces = model(input_ids,
                                   token_type_ids,
                                   position_ids,
                                   input_mask,
                                   task_idx=task_idx,
                                   mask_qkv=mask_qkv,
                                   bleu=args.bleu)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()
                    for i in range(len(buf)):
                        w_ids = output_ids[i]
                        output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        output_sequence = ' '.join(detokenize(output_tokens))
                        output_lines[buf_id[i]] = output_sequence
                        if args.need_score_traces:
                            score_trace_list[buf_id[i]] = {
                                'scores': traces['scores'][i],
                                'wids': traces['wids'][i],
                                'ptrs': traces['ptrs'][i]
                            }
                pbar.update(1)
        if args.output_file:
            fn_out = args.output_file
        else:
            fn_out = model_recover_path + '.' + args.split
        with open(fn_out, "w", encoding="utf-8") as fout:
            for line in output_lines:
                fout.write(line)
                fout.write("\n")

        if args.need_score_traces:
            with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                pickle.dump({
                    "version": 0.0,
                    "num_samples": len(input_lines)
                }, fout_trace)
                for x in score_trace_list:
                    pickle.dump(x, fout_trace)
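    # The trace pickle written above is a stream of pickled objects: a header
    # dict followed by one trace per input line. A minimal reader sketch,
    # assuming exactly the layout produced by the loop above:
    #
    #     import pickle
    #     def read_score_traces(path):
    #         with open(path, "rb") as f:
    #             header = pickle.load(f)  # {"version": ..., "num_samples": ...}
    #             return header, [pickle.load(f) for _ in range(header["num_samples"])]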
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--input_file", default=None, type=str, required=True)
    parser.add_argument("--output_file", default=None, type=str, required=True)
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")

    ## Other parameters
    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
                            "than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help = "local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1)))

    layer_indexes = [int(x) for x in args.layers.split(",")]

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    examples = read_examples(args.input_file)

    features = convert_examples_to_features(
        examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer)

    unique_id_to_feature = {}
    for feature in features:
        unique_id_to_feature[feature.unique_id] = feature

    model = BertModel.from_pretrained(args.bert_model)
    model.to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

    eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
    if args.local_rank == -1:
        eval_sampler = SequentialSampler(eval_data)
    else:
        eval_sampler = DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)

    model.eval()
    with open(args.output_file, "w", encoding='utf-8') as writer:
        for input_ids, input_mask, example_indices in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)

            all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)

            for b, example_index in enumerate(example_indices):
                feature = features[example_index.item()]
                unique_id = int(feature.unique_id)
                # feature = unique_id_to_feature[unique_id]
                output_json = collections.OrderedDict()
                output_json["linex_index"] = unique_id
                all_out_features = []
                for (i, token) in enumerate(feature.tokens):
                    all_layers = []
                    for (j, layer_index) in enumerate(layer_indexes):
                        layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()
                        layer_output = layer_output[b]
                        layers = collections.OrderedDict()
                        layers["index"] = layer_index
                        layers["values"] = [
                            round(x.item(), 6) for x in layer_output[i]
                        ]
                        all_layers.append(layers)
                    out_features = collections.OrderedDict()
                    out_features["token"] = token
                    out_features["layers"] = all_layers
                    all_out_features.append(out_features)
                output_json["features"] = all_out_features
                writer.write(json.dumps(output_json) + "\n")
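# Each output line is one JSON object of the form {"linex_index": ...,
# "features": [{"token": ..., "layers": [{"index": ..., "values": [...]}]}]}.
# A minimal reader sketch (field names follow the writer loop above):
#
#     import json
#     def load_embeddings(path):
#         with open(path, encoding="utf-8") as f:
#             return [json.loads(line) for line in f]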
Beispiel #29
def main():
    # args = parse_arguments()
    # del args.local_rank
    # print(args)
    # args_to_yaml(args, 'config_finetune_train_glue_mrpc.yaml')
    # exit(0)

    config_yaml, local_rank = parse_my_arguments()
    args = args_from_yaml(config_yaml)
    args.local_rank = local_rank
    """ Experiment Setup """

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        print(
            "WARNING: Output directory ({}) already exists and is not empty.".
            format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
    }

    num_labels_task = {
        "cola": 2,
        "mnli": 3,
        "mrpc": 2,
    }

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
    """ Prepare Model """

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(
            args.local_rank))
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    state_dict = torch.load(args.init_checkpoint, map_location='cpu')
    state_dict = state_dict.get(
        'model', state_dict
    )  # in a full checkpoint weights are saved in state_dict['model']
    model.load_state_dict(state_dict, strict=False)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    plain_model = getattr(model, 'module', model)

    with open(args.sparsity_config, 'r') as f:
        raw_dict = yaml.load(f, Loader=yaml.SafeLoader)
        masks = dict.fromkeys(raw_dict['prune_ratios'].keys())
        for param_name in list(masks.keys()):
            if get_parameter_by_name(plain_model, param_name) is None:
                print(f'[WARNING] Cannot find {param_name}')
                del masks[param_name]

    for param_name in masks:
        param = get_parameter_by_name(plain_model, param_name)
        non_zero_mask = torch.ne(param, 0).to(param.dtype)
        masks[param_name] = non_zero_mask
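    # The masks record which weights are already zero in the pruned checkpoint
    # (torch.ne(param, 0)); re-applying them after every optimizer step in the
    # training loop below keeps the sparsity pattern fixed during fine-tuning.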
    """ Prepare Optimizer"""

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.fp16_utils.fp16_optimizer import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        """ Prepare Dataset """

        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        """ Training Loop """

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                if args.max_steps > 0 and global_step > args.max_steps:
                    break
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                    plain_model = getattr(model, 'module', model)
                    for param_name, mask in masks.items():
                        get_parameter_by_name(plain_model,
                                              param_name).data *= mask
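                    # The in-place multiplication above re-applies the 0/1
                    # pruning masks, so weights that were zero in the sparse
                    # checkpoint stay zero after each optimizer update.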
    """ Load Model for Evaluation """

    if args.do_train:
        # Save a trained model and the associated configuration
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        if is_main_process():  # only the main process should save the trained model
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
            torch.save(model_to_save.state_dict(), output_model_file)
            with open(output_config_file, 'w') as f:
                f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    else:
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, num_labels=num_labels)
        state_dict = torch.load(args.init_checkpoint, map_location='cpu')
        state_dict = state_dict.get('model', state_dict)
        model.load_state_dict(state_dict, strict=False)
    model.to(device)
    """ Run Evaluation """

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': loss
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "commonsenseqa": CommonsenseQaProcessor,
    }

    num_labels_task = {
        "cola": 2,
        "mnli": 3,
        "mrpc": 2,
        "commonsenseqa":4,
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))
    print("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

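    # Each optimizer update accumulates gradients over several smaller forward/backward passes,
    # so the per-pass batch size is the requested batch size divided by the accumulation steps.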
    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

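    # Seed every RNG (Python, NumPy, PyTorch CPU and GPU) for reproducibility.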
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    print("current task is " + str(task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

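    # num_train_steps counts optimizer updates: examples / batch size / accumulation steps, times epochs.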
    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    model = BertForMultipleChoice.from_pretrained(args.bert_model,
                                                  cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.
                                                  format(args.local_rank),
                                                  num_choices=num_labels)

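    # Optionally cast to half precision, then wrap for apex distributed training or single-node DataParallel.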
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
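    # Following the original BERT recipe, biases and LayerNorm parameters are excluded from weight decay.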
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
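    # t_total is the total number of optimizer updates, used by the linear warmup/decay schedule.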
    t_total = num_train_steps if args.do_train else -1  # -1 skips the schedule in eval-only runs
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0

    best_eval_accuracy = 0.0

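    # Training: featurize the examples once, then run epochs with gradient accumulation,
    # manual linear LR warmup, and a dev-set evaluation at the end of every epoch.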
    if args.do_train:
        train_features = convert_examples_to_features_mc(
            train_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()

        # Save a trained model
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")

        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

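            # End-of-epoch evaluation; the checkpoint with the best dev accuracy is kept.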
            if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                eval_examples = processor.get_dev_examples(args.data_dir)
                eval_features = convert_examples_to_features_mc(
                    eval_examples, label_list, args.max_seq_length, tokenizer)
                logger.info("***** Running evaluation *****")
                logger.info("  Num examples = %d", len(eval_examples))
                logger.info("  Batch size = %d", args.eval_batch_size)
                all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
                all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
                all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
                all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
                eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
                # Run prediction for full data
                eval_sampler = SequentialSampler(eval_data)
                eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

                model.eval()
                eval_loss, eval_accuracy = 0, 0
                nb_eval_steps, nb_eval_examples = 0, 0

                for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    segment_ids = segment_ids.to(device)
                    label_ids = label_ids.to(device)

                    with torch.no_grad():
                        tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)

                    logits = logits.detach().cpu().numpy()
                    label_ids = label_ids.to('cpu').numpy()
                    tmp_eval_accuracy = accuracy(logits, label_ids)

                    eval_loss += tmp_eval_loss.mean().item()
                    eval_accuracy += tmp_eval_accuracy

                    nb_eval_examples += input_ids.size(0)
                    nb_eval_steps += 1

                eval_accuracy = eval_accuracy / nb_eval_examples
                print("the current eval accuracy is: " + str(eval_accuracy))
                if eval_accuracy > best_eval_accuracy:
                    best_eval_accuracy = eval_accuracy

                    if args.do_train:
                        torch.save(model_to_save.state_dict(), output_model_file)

                model.train()

    # Load the fine-tuned checkpoint (defined here as well so eval-only runs do not hit a NameError)
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    model_state_dict = torch.load(output_model_file)
    model = BertForMultipleChoice.from_pretrained(args.bert_model,
                                                  state_dict=model_state_dict,
                                                  num_choices=num_labels)
    model.to(device)

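    # Final evaluation: score the dev set with the reloaded checkpoint and dump per-example predictions.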
    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features_mc(
            eval_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

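        # Collect predictions, gold labels, and raw logits for the per-example dump below.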
        all_pred_labels = []
        all_anno_labels = []
        all_logits = []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()

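            # The predicted answer is the argmax over the num_choices logits.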
            output_labels = np.argmax(logits, axis=1)
            all_pred_labels.extend(output_labels.tolist())
            all_logits.extend(list(logits))
            all_anno_labels.extend(list(label_ids))

            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {'eval_loss': eval_loss,
                  'eval_accuracy': eval_accuracy,
                  'best_eval_accuracy': best_eval_accuracy,
                  'global_step': global_step,
                  'loss': loss}

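        # Write the aggregate metrics, then one line per example: index, gold label, predicted label, logits.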
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
            for i in range(len(all_pred_labels)):
                writer.write(str(i) + "\t" + str(all_anno_labels[i]) + "\t" +
                             str(all_pred_labels[i]) + "\t" + str(all_logits[i]) + "\n")