    def test_tokenization_roberta(self):
        # Given
        self.base_tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True,
                                                               cache_dir=self.test_dir)
        self.rust_tokenizer = PyRobertaTokenizer(
            get_from_cache(self.base_tokenizer.pretrained_vocab_files_map['vocab_file']['roberta-base']),
            get_from_cache(self.base_tokenizer.pretrained_vocab_files_map['merges_file']['roberta-base']),
            do_lower_case=True
        )
        output_baseline = []
        for example in self.examples:
            output_baseline.append(self.base_tokenizer.encode_plus(example.text_a,
                                                                   add_special_tokens=True,
                                                                   return_overflowing_tokens=True,
                                                                   return_special_tokens_mask=True,
                                                                   max_length=128))

        # When
        output_rust = self.rust_tokenizer.encode_list([example.text_a for example in self.examples],
                                                      max_len=128,
                                                      truncation_strategy='longest_first',
                                                      stride=0)

        # Then
        for idx, (rust, baseline) in enumerate(zip(output_rust, output_baseline)):
            assert rust.token_ids == baseline[
                'input_ids'], f'Difference in tokenization for {self.rust_tokenizer.__class__}: \n ' \
                              f'Sentence a: {self.examples[idx].text_a} \n' \
                              f'Sentence b: {self.examples[idx].text_b} \n' \
                              f'Token mismatch: {self.get_token_diff(rust.token_ids, baseline["input_ids"])} \n' \
                              f'Rust: {rust.token_ids} \n' \
                              f' Python {baseline["input_ids"]}'
            assert (rust.special_tokens_mask == baseline['special_tokens_mask'])
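For quick spot checks outside the test harness, the same comparison can be run on a single sentence. A minimal sketch, assuming the transformers and rust_tokenizers versions exercised by the tests above, with placeholder paths for the cached vocabulary files:

from transformers import RobertaTokenizer
from rust_tokenizers import PyRobertaTokenizer  # assumed import path for the bindings used above

# Placeholder paths: point these at the cached 'roberta-base' vocab.json and merges.txt files.
rust_tokenizer = PyRobertaTokenizer('vocab.json', 'merges.txt', do_lower_case=True)
base_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

sentence = 'Hello, world!'
baseline = base_tokenizer.encode_plus(sentence, add_special_tokens=True, max_length=128)
rust = rust_tokenizer.encode_list([sentence], max_len=128,
                                  truncation_strategy='longest_first', stride=0)[0]
assert rust.token_ids == baseline['input_ids']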
Example #2
    def load(cls,
             pretrained_model_name_or_path,
             tokenizer_class=None,
             **kwargs):
        """
        Enables loading of different Tokenizer classes with a uniform interface. Either infer the class from
        `pretrained_model_name_or_path` or define it manually via `tokenizer_class`.

        :param pretrained_model_name_or_path:  The path of the saved pretrained model or its name (e.g. `bert-base-uncased`)
        :type pretrained_model_name_or_path: str
        :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
        :type tokenizer_class: str
        :param kwargs:
        :return: Tokenizer
        """

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        # guess tokenizer type from name
        if tokenizer_class is None:
            if "albert" in pretrained_model_name_or_path.lower():
                tokenizer_class = "AlbertTokenizer"
            elif "xlm-roberta" in pretrained_model_name_or_path.lower():
                tokenizer_class = "XLMRobertaTokenizer"
            elif "roberta" in pretrained_model_name_or_path.lower():
                tokenizer_class = "RobertaTokenizer"
            elif "distilbert" in pretrained_model_name_or_path.lower():
                tokenizer_class = "DistilBertTokenizer"
            elif "bert" in pretrained_model_name_or_path.lower():
                tokenizer_class = "BertTokenizer"
            elif "xlnet" in pretrained_model_name_or_path.lower():
                tokenizer_class = "XLNetTokenizer"
            else:
                raise ValueError(
                    f"Could not infer tokenizer_class from name '{pretrained_model_name_or_path}'. "
                    "Set arg `tokenizer_class` in Tokenizer.load() to one of: 'AlbertTokenizer', "
                    "'XLMRobertaTokenizer', 'RobertaTokenizer', 'DistilBertTokenizer', "
                    "'BertTokenizer', 'XLNetTokenizer'."
                )
            logger.info(f"Loading tokenizer of type '{tokenizer_class}'")
        # return appropriate tokenizer object
        ret = None
        if tokenizer_class == "AlbertTokenizer":
            ret = AlbertTokenizer.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
        elif tokenizer_class == "XLMRobertaTokenizer":
            ret = XLMRobertaTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "RobertaTokenizer":
            ret = RobertaTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "DistilBertTokenizer":
            ret = DistilBertTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "BertTokenizer":
            ret = BertTokenizer.from_pretrained(pretrained_model_name_or_path,
                                                **kwargs)
        elif tokenizer_class == "XLNetTokenizer":
            ret = XLNetTokenizer.from_pretrained(pretrained_model_name_or_path,
                                                 keep_accents=True,
                                                 **kwargs)
        if ret is None:
            raise Exception("Unable to load tokenizer")
        else:
            return ret
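A brief usage sketch of the loader above (assuming it is defined as a classmethod on the surrounding `Tokenizer` class, as the error message suggests):

# Class inferred from the name: "roberta" in the name selects RobertaTokenizer.
tokenizer = Tokenizer.load(pretrained_model_name_or_path="roberta-base")

# Or bypass inference and name the class explicitly, e.g. for a local checkpoint directory.
tokenizer = Tokenizer.load(pretrained_model_name_or_path="saved_models/my_model",
                           tokenizer_class="BertTokenizer")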
Example #3
    def __init__(self, args, device='cpu'):
        print(args.bert_model)
        self.tokenizer = RobertaTokenizer.from_pretrained(args.bert_model)
        self.data_dir = args.data_dir
        file_list = get_json_file_list(args.data_dir)
        self.data = []
        #max_article_len = 0
        for file_name in file_list:
            with open(file_name, 'r') as f:
                data = json.loads(f.read())
            data['high'] = 1 if 'high' in file_name else 0
            self.data.append(data)
            #max_article_len = max(max_article_len, len(nltk.word_tokenize(data['article'])))
        self.data_objs = []
        high_cnt = 0
        middle_cnt = 0
        for sample in self.data:
            high_cnt += sample['high']
            middle_cnt += (1 - sample['high'])
            self.data_objs += self._create_sample(sample)
            #break
        print('high school sample:', high_cnt)
        print('middle school sample:', middle_cnt)
        for i in range(len(self.data_objs)):
            self.data_objs[i].convert_tokens_to_ids(self.tokenizer)
            #break
        torch.save(self.data_objs, args.save_name)
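The constructor above persists the preprocessed samples with torch.save, so a downstream training script can restore them directly. A minimal sketch ('race_data.pt' stands in for whatever args.save_name was used; the sample class produced by _create_sample must be importable when unpickling):

import torch

# Restore the preprocessed sample objects written by the dataset constructor above.
data_objs = torch.load('race_data.pt')
print('loaded samples:', len(data_objs))

Example #4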
    def test_tokenization_roberta(self):
        # Given
        self.base_tokenizer = RobertaTokenizer.from_pretrained(
            'roberta-base', do_lower_case=True, cache_dir=self.test_dir)
        self.rust_tokenizer = PyRobertaTokenizer(
            get_from_cache(
                self.base_tokenizer.pretrained_vocab_files_map['vocab_file']
                ['roberta-base']),
            get_from_cache(
                self.base_tokenizer.pretrained_vocab_files_map['merges_file']
                ['roberta-base']))
        output_baseline = []
        for example in self.examples:
            output_baseline.append(
                self.base_tokenizer.encode_plus(
                    example.text_a,
                    add_special_tokens=True,
                    return_overflowing_tokens=True,
                    return_special_tokens_mask=True,
                    max_length=128))

        # When
        output_rust = self.rust_tokenizer.encode_list(
            [example.text_a for example in self.examples],
            max_len=128,
            truncation_strategy='longest_first',
            stride=0)

        # Then
        for rust, baseline in zip(output_rust, output_baseline):
            assert (rust.token_ids == baseline['input_ids'])
            assert (rust.segment_ids == baseline['token_type_ids'])
            assert (
                rust.special_tokens_mask == baseline['special_tokens_mask'])
Example #5
def get_roberta_tokenizer(pretrained_cfg_name: str,
                          do_lower_case: bool = True):
    # still uses the HuggingFace tokenizer code, since the tokenizers are the same
    if "camembert" in pretrained_cfg_name:
        return CamembertTokenizer.from_pretrained(pretrained_cfg_name,
                                                  do_lower_case=do_lower_case)
    return RobertaTokenizer.from_pretrained(pretrained_cfg_name,
                                            do_lower_case=do_lower_case)
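Usage is driven purely by the configuration name, for example (a sketch, assuming the pretrained files can be downloaded):

roberta_tok = get_roberta_tokenizer('roberta-base')        # falls through to RobertaTokenizer
camembert_tok = get_roberta_tokenizer('camembert-base')    # matches the "camembert" branch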
Example #6
    def test_sequence_builders(self):
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

        text = tokenizer.encode("sequence builders")
        text_2 = tokenizer.encode("multi-sequence build")

        encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
        encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        assert encoded_sentence == encoded_text_from_decode
        assert encoded_pair == encoded_pair_from_decode
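The assertions above hold because RoBERTa wraps a single sequence as <s> A </s> and a pair as <s> A </s></s> B </s>. A small sketch making that layout explicit (for the standard roberta-base vocabulary, <s> has id 0 and </s> has id 2):

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
ids_a = tokenizer.encode("sequence builders", add_special_tokens=False)
ids_b = tokenizer.encode("multi-sequence build", add_special_tokens=False)

# <s> A </s> for a single sequence, <s> A </s></s> B </s> for a pair
assert tokenizer.build_inputs_with_special_tokens(ids_a) == [0] + ids_a + [2]
assert tokenizer.build_inputs_with_special_tokens(ids_a, ids_b) == [0] + ids_a + [2, 2] + ids_b + [2]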
Example #7
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--round_name",
        default="",
        type=str,
        help=
        "Which round to run: one of 'base', 'r1', 'r2', 'r3', 'r4' or 'r5'.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    round_name_2_rounds = {
        'base': ['base', 'ood'],
        'r1': ['base', 'n1', 'ood'],
        'r2': ['base', 'n1', 'n2', 'ood'],
        'r3': ['base', 'n1', 'n2', 'n3', 'ood'],
        'r4': ['base', 'n1', 'n2', 'n3', 'n4', 'ood'],
        'r5': ['base', 'n1', 'n2', 'n3', 'n4', 'n5', 'ood']
    }

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    banking77_class_list, ood_class_set, class_2_split = load_class_names()

    round_list = round_name_2_rounds.get(args.round_name)
    train_examples, base_class_list = processor.load_train(
        ['base'])  #train on base only
    '''train the first stage'''
    model = RobertaForSequenceClassification(len(base_class_list))
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    train_dataloader = examples_to_features(train_examples,
                                            base_class_list,
                                            args,
                                            tokenizer,
                                            args.train_batch_size,
                                            "classification",
                                            dataloader_mode='random')
    mean_loss = 0.0
    count = 0
    for _ in trange(int(args.num_train_epochs), desc="Stage1Epoch"):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch

            logits = model(input_ids, input_mask, output_rep=False)
            loss_fct = CrossEntropyLoss()

            loss = loss_fct(logits.view(-1, len(base_class_list)),
                            label_ids.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            mean_loss += loss.item()
            count += 1
            # if count % 50 == 0:
            #     print('mean loss:', mean_loss/count)
    print('stage 1, train supervised classification on base is over.')
    '''now, train the second stage'''
    model_stage_2 = ModelStageTwo(len(base_class_list), model)
    model_stage_2.to(device)

    param_optimizer = list(model_stage_2.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer_stage_2 = AdamW(optimizer_grouped_parameters,
                              lr=args.learning_rate)
    mean_loss = 0.0
    count = 0
    best_threshold = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Stage2Epoch"):
        '''first, select some base classes as fake novel classes'''
        fake_novel_size = 15
        fake_novel_support_size = 5
        '''for convenience, shuffle the base classes and treat the last fake_novel_size of them as fake novel classes'''
        original_base_class_idlist = list(range(len(base_class_list)))
        # random.shuffle(original_base_class_idlist)
        shuffled_base_class_list = [
            base_class_list[idd] for idd in original_base_class_idlist
        ]
        fake_novel_classlist = shuffled_base_class_list[-fake_novel_size:]
        '''load their support examples'''
        base_support_examples = processor.load_base_support_examples(
            fake_novel_classlist, fake_novel_support_size)
        base_support_dataloader = examples_to_features(
            base_support_examples,
            fake_novel_classlist,
            args,
            tokenizer,
            fake_novel_support_size,
            "classification",
            dataloader_mode='sequential')

        novel_class_support_reps = []
        for _, batch in enumerate(base_support_dataloader):
            input_ids, input_mask, segment_ids, label_ids = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            model.eval()
            with torch.no_grad():
                support_rep_for_novel_class = model(input_ids,
                                                    input_mask,
                                                    output_rep=True)
            novel_class_support_reps.append(support_rep_for_novel_class)
        assert len(novel_class_support_reps) == fake_novel_size
        print('Extracting support reps for fake novel classes is over.')
        '''retrain on query set to optimize the weight generator'''
        train_dataloader = examples_to_features(train_examples,
                                                shuffled_base_class_list,
                                                args,
                                                tokenizer,
                                                args.train_batch_size,
                                                "classification",
                                                dataloader_mode='random')
        best_threshold_list = []
        for _ in range(10):  #repeat 10 times is important
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model_stage_2.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                logits = model_stage_2(
                    input_ids,
                    input_mask,
                    model,
                    novel_class_support_reps=novel_class_support_reps,
                    fake_novel_size=fake_novel_size,
                    base_class_mapping=original_base_class_idlist)
                # print('logits:', logits)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(logits.view(-1, len(base_class_list)),
                                label_ids.view(-1))
                loss.backward()
                optimizer_stage_2.step()
                optimizer_stage_2.zero_grad()
                mean_loss += loss.item()
                count += 1
                if count % 50 == 0:
                    print('mean loss:', mean_loss / count)
                scores_for_positive = logits[torch.arange(logits.shape[0]),
                                             label_ids.view(-1)].mean()
                best_threshold_list.append(scores_for_positive.item())

        best_threshold = sum(best_threshold_list) / len(best_threshold_list)

    print('stage 2 training over')
    '''
    start testing
    '''
    '''first, get reps for all base+novel classes'''
    '''support for all seen classes'''
    class_2_support_examples, seen_class_list = processor.load_support_all_rounds(
        round_list[:-1])  #no support set for ood
    assert seen_class_list[:len(base_class_list)] == base_class_list
    seen_class_list_size = len(seen_class_list)
    support_example_lists = [
        class_2_support_examples.get(seen_class)
        for seen_class in seen_class_list if seen_class not in base_class_list
    ]

    novel_class_support_reps = []
    for eval_support_examples_per_class in support_example_lists:
        support_dataloader = examples_to_features(
            eval_support_examples_per_class,
            seen_class_list,
            args,
            tokenizer,
            5,
            "classification",
            dataloader_mode='random')
        single_class_support_reps = []
        for _, batch in enumerate(support_dataloader):
            input_ids, input_mask, segment_ids, label_ids = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            model.eval()
            with torch.no_grad():
                support_rep_for_novel_class = model(input_ids,
                                                    input_mask,
                                                    output_rep=True)
            single_class_support_reps.append(support_rep_for_novel_class)
        single_class_support_reps = torch.cat(single_class_support_reps,
                                              axis=0)
        novel_class_support_reps.append(single_class_support_reps)
    print('len(novel_class_support_reps):', len(novel_class_support_reps))
    print('len(base_class_list):', len(base_class_list))
    print('len(seen_class_list):', len(seen_class_list))
    assert len(novel_class_support_reps) + len(base_class_list) == len(
        seen_class_list)
    print('Extracting support reps for all novel classes is over.')
    test_examples = processor.load_dev_or_test(round_list, 'test')
    test_class_list = seen_class_list + list(ood_class_set)
    print('test_class_list:', len(test_class_list))
    print('best_threshold:', best_threshold)
    test_split_list = []
    for test_class_i in test_class_list:
        test_split_list.append(class_2_split.get(test_class_i))
    test_dataloader = examples_to_features(test_examples,
                                           test_class_list,
                                           args,
                                           tokenizer,
                                           args.eval_batch_size,
                                           "classification",
                                           dataloader_mode='sequential')
    '''test on test batch '''
    preds = []
    gold_label_ids = []
    for input_ids, input_mask, segment_ids, label_ids in test_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)
        gold_label_ids += list(label_ids.detach().cpu().numpy())
        model_stage_2.eval()
        with torch.no_grad():
            logits = model_stage_2(
                input_ids,
                input_mask,
                model,
                novel_class_support_reps=novel_class_support_reps,
                fake_novel_size=None,
                base_class_mapping=None)
        # print('test logits:', logits)
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(preds[0],
                                 logits.detach().cpu().numpy(),
                                 axis=0)

    preds = preds[0]

    pred_probs = preds  #softmax(preds,axis=1)
    pred_label_ids_raw = list(np.argmax(pred_probs, axis=1))
    pred_max_prob = list(np.amax(pred_probs, axis=1))

    pred_label_ids = []
    for i, pred_max_prob_i in enumerate(pred_max_prob):
        if pred_max_prob_i < best_threshold:
            pred_label_ids.append(
                seen_class_list_size)  #seen_class_list_size means ood
        else:
            pred_label_ids.append(pred_label_ids_raw[i])

    assert len(pred_label_ids) == len(gold_label_ids)
    acc_each_round = []
    for round_name_id in round_list:
        #base, n1, n2, ood
        round_size = 0
        round_hit = 0
        if round_name_id != 'ood':
            for ii, gold_label_id in enumerate(gold_label_ids):
                if test_split_list[gold_label_id] == round_name_id:
                    round_size += 1
                    # print('gold_label_id:', gold_label_id, 'pred_label_ids[ii]:', pred_label_ids[ii])
                    if gold_label_id == pred_label_ids[ii]:
                        round_hit += 1
            acc_i = round_hit / round_size
            acc_each_round.append(acc_i)
        else:
            '''ood f1'''
            gold_binary_list = []
            pred_binary_list = []
            for ii, gold_label_id in enumerate(gold_label_ids):
                # print('gold_label_id:', gold_label_id, 'pred_label_ids[ii]:', pred_label_ids[ii])
                gold_binary_list.append(1 if test_split_list[gold_label_id] ==
                                        round_name_id else 0)
                pred_binary_list.append(1 if pred_label_ids[ii] ==
                                        seen_class_list_size else 0)
            overlap = 0
            for i in range(len(gold_binary_list)):
                if gold_binary_list[i] == 1 and pred_binary_list[i] == 1:
                    overlap += 1
            recall = overlap / (1e-6 + sum(gold_binary_list))
            precision = overlap / (1e-6 + sum(pred_binary_list))

            acc_i = 2 * recall * precision / (1e-6 + recall + precision)
            acc_each_round.append(acc_i)

    print('\n\t\t test_acc:', acc_each_round)
    final_test_performance = acc_each_round

    print('final_test_performance:', final_test_performance)
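The open-set decision above boils down to a simple rule: take the arg-max over the seen classes unless its score falls below the threshold estimated in stage 2, in which case the example is labeled OOD. A standalone sketch of that rule (seen_class_list_size plays the role of the OOD label, as in the code above):

import numpy as np

def predict_with_ood_threshold(logits, threshold, ood_label):
    # logits: array of shape (num_examples, num_seen_classes)
    raw_preds = np.argmax(logits, axis=1)
    max_scores = np.amax(logits, axis=1)
    # Any example whose best score is below the threshold is treated as out-of-distribution.
    return np.where(max_scores < threshold, ood_label, raw_preds)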
Example #8
def main():
    parser = argparse.ArgumentParser()


    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="random seed for initialization")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")


    args = parser.parse_args()



    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


    mctest_path = '/export/home/Dataset/MCTest/Statements/'
    target_kshot_entail_examples, target_kshot_nonentail_examples = get_MCTest_train(mctest_path+'mc500.train.statements.pairs', args.kshot) #train_pu_half_v1.txt
    target_dev_examples, target_test_examples = get_MCTest_dev_and_test(mctest_path+'mc500.dev.statements.pairs', mctest_path+'mc500.test.statements.pairs')


    source_kshot_entail, source_kshot_neural, source_kshot_contra, source_remaining_examples = get_MNLI_train('/export/home/Dataset/glue_data/MNLI/train.tsv', args.kshot)
    source_examples = source_kshot_entail+ source_kshot_neural+ source_kshot_contra+ source_remaining_examples
    target_label_list = ["ENTAILMENT", "UNKNOWN"]
    source_label_list = ["entailment", "neutral", "contradiction"]
    source_num_labels = len(source_label_list)
    target_num_labels = len(target_label_list)
    print('training size:', len(source_examples), 'dev size:', len(target_dev_examples), 'test size:', len(target_test_examples))

    num_train_optimization_steps = None
    num_train_optimization_steps = int(
        len(source_remaining_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    roberta_model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)
    roberta_model.load_state_dict(torch.load('/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'),strict=False)
    roberta_model.to(device)
    roberta_model.eval()

    protonet = PrototypeNet(bert_hidden_dim)
    protonet.to(device)

    param_optimizer = list(protonet.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    optimizer = AdamW(optimizer_grouped_parameters,
                             lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    retrieve_batch_size = 5

    source_kshot_entail_dataloader = examples_to_features(source_kshot_entail, source_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    source_kshot_neural_dataloader = examples_to_features(source_kshot_neural, source_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    source_kshot_contra_dataloader = examples_to_features(source_kshot_contra, source_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    source_remain_ex_dataloader = examples_to_features(source_remaining_examples, source_label_list, args, tokenizer, args.train_batch_size, "classification", dataloader_mode='random')

    target_kshot_entail_dataloader = examples_to_features(target_kshot_entail_examples, target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    target_kshot_nonentail_dataloader = examples_to_features(target_kshot_nonentail_examples, target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    target_dev_dataloader = examples_to_features(target_dev_examples, target_label_list, args, tokenizer, args.eval_batch_size, "classification", dataloader_mode='sequential')
    target_test_dataloader = examples_to_features(target_test_examples, target_label_list, args, tokenizer, args.eval_batch_size, "classification", dataloader_mode='sequential')

    '''starting to train'''
    iter_co = 0
    final_test_performance = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(tqdm(source_remain_ex_dataloader, desc="Iteration")):
            protonet.train()
            batch = tuple(t.to(device) for t in batch)
            _, input_ids, input_mask, segment_ids, label_ids_batch = batch

            roberta_model.eval()
            with torch.no_grad():
                last_hidden_batch, _ = roberta_model(input_ids, input_mask)
            '''
            retrieve rep for support examples
            '''
            kshot_entail_reps = []
            for entail_batch in source_kshot_entail_dataloader:
                entail_batch = tuple(t.to(device) for t in entail_batch)
                _, input_ids, input_mask, segment_ids, label_ids = entail_batch
                roberta_model.eval()
                with torch.no_grad():
                    last_hidden_entail, _ = roberta_model(input_ids, input_mask)
                kshot_entail_reps.append(last_hidden_entail)
            kshot_entail_rep = torch.mean(torch.cat(kshot_entail_reps, dim=0), dim=0, keepdim=True)
            kshot_neural_reps = []
            for neural_batch in source_kshot_neural_dataloader:
                neural_batch = tuple(t.to(device) for t in neural_batch)
                _, input_ids, input_mask, segment_ids, label_ids = neural_batch
                roberta_model.eval()
                with torch.no_grad():
                    last_hidden_neural, _ = roberta_model(input_ids, input_mask)
                kshot_neural_reps.append(last_hidden_neural)
            kshot_neural_rep = torch.mean(torch.cat(kshot_neural_reps, dim=0), dim=0, keepdim=True)
            kshot_contra_reps = []
            for contra_batch in source_kshot_contra_dataloader:
                contra_batch = tuple(t.to(device) for t in contra_batch)
                _, input_ids, input_mask, segment_ids, label_ids = contra_batch
                roberta_model.eval()
                with torch.no_grad():
                    last_hidden_contra, _ = roberta_model(input_ids, input_mask)
                kshot_contra_reps.append(last_hidden_contra)
            kshot_contra_rep = torch.mean(torch.cat(kshot_contra_reps, dim=0), dim=0, keepdim=True)

            class_prototype_reps = torch.cat([kshot_entail_rep, kshot_neural_rep, kshot_contra_rep], dim=0) #(3, hidden)

            '''forward to model'''
            batch_logits = protonet(class_prototype_reps, last_hidden_batch)

            loss_fct = CrossEntropyLoss()

            loss = loss_fct(batch_logits.view(-1, source_num_labels), label_ids_batch.view(-1))

            if n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            iter_co+=1
            # if iter_co %20==0:
            if iter_co % len(source_remain_ex_dataloader)==0:
                '''
                start evaluating on the dev set after this epoch
                '''
                protonet.eval()
                '''first get representations for support examples'''
                kshot_entail_reps = []
                for entail_batch in target_kshot_entail_dataloader:
                    entail_batch = tuple(t.to(device) for t in entail_batch)
                    _, input_ids, input_mask, segment_ids, label_ids = entail_batch
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_entail, _ = roberta_model(input_ids, input_mask)
                    kshot_entail_reps.append(last_hidden_entail)
                kshot_entail_rep = torch.mean(torch.cat(kshot_entail_reps, dim=0), dim=0, keepdim=True)
                kshot_nonentail_reps = []
                for nonentail_batch in target_kshot_nonentail_dataloader:
                    nonentail_batch = tuple(t.to(device) for t in nonentail_batch)
                    _, input_ids, input_mask, segment_ids, label_ids = nonentail_batch
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_nonentail, _ = roberta_model(input_ids, input_mask)
                    kshot_nonentail_reps.append(last_hidden_nonentail)
                kshot_nonentail_rep = torch.mean(torch.cat(kshot_nonentail_reps, dim=0), dim=0, keepdim=True)
                target_class_prototype_reps = torch.cat([kshot_entail_rep, kshot_nonentail_rep], dim=0) #(2, hidden)

                for idd, dev_or_test_dataloader in enumerate([target_dev_dataloader, target_test_dataloader]):

                    if idd == 0:
                        logger.info("***** Running dev *****")
                        logger.info("  Num examples = %d", len(target_dev_examples))
                    else:
                        logger.info("***** Running test *****")
                        logger.info("  Num examples = %d", len(target_test_examples))


                    eval_loss = 0
                    nb_eval_steps = 0
                    preds = []
                    gold_label_ids = []
                    gold_pair_ids = []
                    for input_pair_ids, input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        gold_pair_ids+= list(input_pair_ids.numpy())
                        label_ids = label_ids.to(device)
                        gold_label_ids+=list(label_ids.detach().cpu().numpy())
                        roberta_model.eval()
                        with torch.no_grad():
                            last_hidden_target_batch, _ = roberta_model(input_ids, input_mask)

                        with torch.no_grad():
                            logits = protonet(target_class_prototype_reps, last_hidden_target_batch)
                        if len(preds) == 0:
                            preds.append(logits.detach().cpu().numpy())
                        else:
                            preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)
                    preds = preds[0]
                    pred_probs = list(softmax(preds,axis=1)[:,0]) #entail prob

                    assert len(gold_pair_ids) == len(pred_probs)
                    assert len(gold_pair_ids) == len(gold_label_ids)

                    pairID_2_predgoldlist = {}
                    for pair_id, prob, gold_id in zip(gold_pair_ids, pred_probs, gold_label_ids):
                        predgoldlist = pairID_2_predgoldlist.get(pair_id)
                        if predgoldlist is None:
                            predgoldlist = []
                        predgoldlist.append((prob, gold_id))
                        pairID_2_predgoldlist[pair_id] = predgoldlist
                    total_size = len(pairID_2_predgoldlist)
                    hit_size = 0
                    for pair_id, predgoldlist in pairID_2_predgoldlist.items():
                        predgoldlist.sort(key=lambda x:x[0]) #sort by prob
                        assert len(predgoldlist) == 4
                        if predgoldlist[-1][1] == 0:
                            hit_size+=1
                    test_acc= hit_size/total_size

                    if idd == 0: # this is dev
                        if test_acc > max_dev_acc:
                            max_dev_acc = test_acc
                            print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')

                        else:
                            print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')
                            break
                    else: # this is test
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc

                        final_test_performance = test_acc
                        print('\n\t\t test acc:', test_acc, ' max_test_acc:', max_test_acc, '\n')

    print('final_test_performance:', final_test_performance)
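The dev/test accuracy above treats the four statements sharing a pair id as one multiple-choice question: the candidate with the highest entailment probability is chosen, and the question counts as a hit when that candidate is the gold entailment (label id 0). A standalone sketch of that scoring step:

def multiple_choice_accuracy(pair_ids, entail_probs, gold_label_ids):
    # Group the candidate statements of each question by their shared pair id.
    pair_to_candidates = {}
    for pair_id, prob, gold in zip(pair_ids, entail_probs, gold_label_ids):
        pair_to_candidates.setdefault(pair_id, []).append((prob, gold))
    hits = 0
    for candidates in pair_to_candidates.values():
        candidates.sort(key=lambda x: x[0])  # ascending by entailment probability
        if candidates[-1][1] == 0:           # top-ranked candidate is the gold entailment
            hits += 1
    return hits / len(pair_to_candidates)

Example #9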
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default='/hdd/lujunyu/dataset/multi_turn_corpus/ubuntu/',
                        type=str,
                        required=False,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--task_name",
                        default='ubuntu',
                        type=str,
                        required=False,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default='/hdd/lujunyu/model/chatbert/check/',
                        type=str,
                        required=False,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--data_augmentation",
                        default=False,
                        action='store_true',
                        help="Whether to use augmentation")
    parser.add_argument("--max_seq_length",
                        default=256,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        default=True,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_test",
                        default=True,
                        action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument("--train_batch_size",
                        default=400,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=100,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=20.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_steps",
                        default=0.0,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--weight_decay",
                        default=1e-3,
                        type=float,
                        help="weight_decay")
    parser.add_argument("--save_checkpoints_steps",
                        default=3125,
                        type=int,
                        help="How often to save the model checkpoint.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=5,
                        help="Number of updates steps to accumualte before performing a backward/update pass.")
    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    bert_config = RobertaConfig.from_pretrained('roberta-base', num_labels=2)

    if args.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
            args.max_seq_length, bert_config.max_position_embeddings))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        if args.do_train:
            raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    else:
        os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    if args.data_augmentation:
        train_dataset = UbuntuDatasetForRoberta(
            file_path=os.path.join(args.data_dir, "train_augment_ubuntu.txt"),
            max_seq_length=args.max_seq_length,
            tokenizer=tokenizer
        )
    else:
        train_dataset = UbuntuDatasetForRoberta(
            file_path=os.path.join(args.data_dir, "train.txt"),
            max_seq_length=args.max_seq_length,
            tokenizer=tokenizer
        )
    eval_dataset = UbuntuDatasetForRoberta(
        file_path=os.path.join(args.data_dir, "valid.txt"),  ### TODO:change
        max_seq_length=args.max_seq_length,
        tokenizer=tokenizer
    )

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size,
                                                sampler=RandomSampler(train_dataset), num_workers=8)
    eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=args.eval_batch_size,
                                                sampler=SequentialSampler(eval_dataset), num_workers=8)

    model = RobertaForSequenceClassification.from_pretrained('roberta-base',config=bert_config)
    model.to(device)

    num_train_steps = None
    if args.do_train:
        num_train_steps = int(
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        # remove pooler, which is not used and thus produces None grads that break apex
        param_optimizer = [n for n in param_optimizer]

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay}, {
            'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0}]

        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_steps)
    else:
        optimizer = None
        scheduler = None

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.data)

    global_step = 0
    best_metric = 0.0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, label_ids = batch
                loss, _ = model(input_ids, labels=label_ids)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()    # We have accumulated enough gradients
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                if (step + 1) % args.save_checkpoints_steps == 0:
                    model.eval()
                    f = open(os.path.join(args.output_dir, 'logits_dev.txt'), 'w')
                    eval_loss = 0
                    nb_eval_steps, nb_eval_examples = 0, 0
                    logits_all = []
                    for input_ids, label_ids in eval_dataloader:
                        input_ids = input_ids.to(device)
                        label_ids = label_ids.to(device)

                        with torch.no_grad():
                            tmp_eval_loss, logits = model(input_ids, labels=label_ids)

                        logits = logits.detach().cpu().numpy()
                        logits_all.append(logits)
                        label_ids = label_ids.cpu().numpy()

                        for logit, label in zip(logits, label_ids):
                            logit = '{},{}'.format(logit[0], logit[1])
                            f.write('_\t{}\t{}\n'.format(logit, label))

                        eval_loss += tmp_eval_loss.mean().item()

                        nb_eval_examples += input_ids.size(0)
                        nb_eval_steps += 1

                    f.close()
                    logits_all = np.concatenate(logits_all,axis=0)
                    eval_loss = eval_loss / nb_eval_steps

                    result = evaluate(os.path.join(args.output_dir, 'logits_dev.txt'))
                    result.update({'eval_loss': eval_loss})

                    output_eval_file = os.path.join(args.output_dir, "eval_results_dev.txt")
                    with open(output_eval_file, "a") as writer:
                        logger.info("***** Eval results *****")
                        for key in sorted(result.keys()):
                            logger.info("  %s = %s", key, str(result[key]))
                            writer.write("%s = %s\n" % (key, str(result[key])))

                    ### Save the best checkpoint
                    if best_metric < result['R10@1'] + result['R10@2']:
                        try:  ### Remove 'module' prefix when using DataParallel
                            state_dict = model.module.state_dict()
                        except AttributeError:
                            state_dict = model.state_dict()
                        torch.save(state_dict, os.path.join(args.output_dir, "model.pt"))
                        best_metric = result['R10@1'] + result['R10@2']
                        logger.info('Saving the best model in {}'.format(os.path.join(args.output_dir, "model.pt")))

                        ### visualize bad cases of the best model
                        logger.info('Saving Bad cases...')
                        visualize_bad_cases(
                            logits=logits_all,
                            input_file_path=os.path.join(args.data_dir, 'valid.txt'),
                            output_file_path=os.path.join(args.output_dir, 'valid_bad_cases.txt')
                        )

                    model.train()
Exemplo n.º 10
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--round_name",
        default="",
        type=str,
        help=
        "Name of the training round to use (e.g. r1, r2, ..., r5)")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=50,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    round_name_2_rounds = {
        'r1': ['n1', 'ood'],
        'r2': ['n1', 'n2', 'ood'],
        'r3': ['n1', 'n2', 'n3', 'ood'],
        'r4': ['n1', 'n2', 'n3', 'n4', 'ood'],
        'r5': ['n1', 'n2', 'n3', 'n4', 'n5', 'ood']
    }
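    # Each round name maps to the cumulative list of data rounds to use,
    # plus an 'ood' bucket for out-of-distribution classes.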

    model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.load_state_dict(torch.load('../../data/MNLI_pretrained.pt'),
                          strict=False)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
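    # Exclude biases and LayerNorm parameters from weight decay.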
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    banking77_class_list, ood_class_set, class_2_split = load_class_names()

    round_list = round_name_2_rounds.get(args.round_name)
    '''load training data, one list of examples per round'''
    train_examples_list, train_class_list, train_class_2_split_list, class_2_sentlist_upto_this_round = processor.load_train(
        round_list[:-1])  # no ood training examples
    assert len(train_class_list) == len(train_class_2_split_list)
    # assert len(train_class_list) ==  20+(len(round_list)-2)*10
    '''dev and test'''
    dev_examples, dev_instance_size = processor.load_dev_or_test(
        round_list, train_class_list, class_2_sentlist_upto_this_round, 'dev')
    test_examples, test_instance_size = processor.load_dev_or_test(
        round_list, train_class_list, class_2_sentlist_upto_this_round, 'test')
    print('train size:', [len(train_i) for train_i in train_examples_list],
          ' dev size:', len(dev_examples), ' test size:', len(test_examples))
    entail_class_list = ['entailment', 'non-entailment']
    eval_class_list = train_class_list + ['ood']
    test_split_list = train_class_2_split_list + ['ood']
    train_dataloader_list = []
    for train_examples in train_examples_list:
        train_dataloader = examples_to_features(train_examples,
                                                entail_class_list,
                                                eval_class_list,
                                                args,
                                                tokenizer,
                                                args.train_batch_size,
                                                "classification",
                                                dataloader_mode='random')
        train_dataloader_list.append(train_dataloader)
    dev_dataloader = examples_to_features(dev_examples,
                                          entail_class_list,
                                          eval_class_list,
                                          args,
                                          tokenizer,
                                          args.eval_batch_size,
                                          "classification",
                                          dataloader_mode='sequential')
    test_dataloader = examples_to_features(test_examples,
                                           entail_class_list,
                                           eval_class_list,
                                           args,
                                           tokenizer,
                                           args.eval_batch_size,
                                           "classification",
                                           dataloader_mode='sequential')
    '''training'''
    max_test_acc = 0.0
    max_dev_acc = 0.0
    for round_index, round_name in enumerate(round_list[:-1]):  # avoid shadowing built-in round()
        '''for the new examples in each round, train multiple epochs'''
        train_dataloader = train_dataloader_list[round_index]
        for epoch_i in range(args.num_train_epochs):
            for _, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="train|" + round_name + '|epoch_' + str(epoch_i))):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                _, input_ids, input_mask, _, label_ids, _, _ = batch

                logits = model(input_ids, input_mask)
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, 3), label_ids.view(-1))
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
        print('\t\t round ', round_name, ' is over...')
    '''evaluation'''
    model.eval()
    '''test'''
    acc_each_round = []
    preds = []
    gold_guids = []
    gold_premise_ids = []
    gold_hypothesis_ids = []
    for _, batch in enumerate(tqdm(test_dataloader, desc="test")):
        guids, input_ids, input_mask, _, label_ids, premise_class_ids, hypothesis_class_id = batch
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)

        gold_guids += list(guids.detach().cpu().numpy())
        gold_premise_ids += list(premise_class_ids.detach().cpu().numpy())
        gold_hypothesis_ids += list(hypothesis_class_id.detach().cpu().numpy())

        with torch.no_grad():
            logits = model(input_ids, input_mask)
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(preds[0],
                                 logits.detach().cpu().numpy(),
                                 axis=0)
    preds = softmax(preds[0], axis=1)

    pred_label_3way = np.argmax(preds,
                                axis=1)  # 0 means "entailment"
    pred_probs = list(
        preds[:, 0])  # probability of the "entailment" class, one per test pair
    assert len(pred_label_3way) == len(test_examples)
    assert len(pred_probs) == len(test_examples)
    assert len(gold_premise_ids) == len(test_examples)
    assert len(gold_hypothesis_ids) == len(test_examples)
    assert len(gold_guids) == len(test_examples)

    guid_2_premise_idlist = defaultdict(list)
    guid_2_hypoID_2_problist_labellist = {}
    for guid_i, threeway_i, prob_i, premise_i, hypo_i in zip(
            gold_guids, pred_label_3way, pred_probs, gold_premise_ids,
            gold_hypothesis_ids):
        guid_2_premise_idlist[guid_i].append(premise_i)
        hypoID_2_problist_labellist = guid_2_hypoID_2_problist_labellist.get(
            guid_i)
        if hypoID_2_problist_labellist is None:
            hypoID_2_problist_labellist = {}
        lists = hypoID_2_problist_labellist.get(hypo_i)
        if lists is None:
            lists = [[], []]
        lists[0].append(prob_i)
        lists[1].append(threeway_i)
        hypoID_2_problist_labellist[hypo_i] = lists
        guid_2_hypoID_2_problist_labellist[
            guid_i] = hypoID_2_problist_labellist

    pred_label_ids = []
    gold_label_ids = []
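    # For each test instance (guid): a hypothesis class is accepted only if a
    # majority of its pairs were predicted 'entailment'; among accepted classes
    # pick the one with the highest mean entailment probability, otherwise
    # predict the reserved OOD label (index len(train_class_list)).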
    for guid in range(test_instance_size):
        assert len(set(guid_2_premise_idlist.get(guid))) == 1
        gold_label_ids.append(guid_2_premise_idlist.get(guid)[0])
        '''infer predict label id'''
        hypoID_2_problist_labellist = guid_2_hypoID_2_problist_labellist.get(
            guid)

        final_max_mean_prob = 0.0
        final_hypo_id = -1
        for hypo_id, problist_labellist in hypoID_2_problist_labellist.items():
            problist = problist_labellist[0]
            mean_prob = np.mean(problist)
            labellist = problist_labellist[1]
            same_cluster_times = labellist.count(
                0)  # 'entailment' is the first label
            same_cluster = False
            if same_cluster_times / len(labellist) > 0.5:
                same_cluster = True

            if same_cluster and mean_prob > final_max_mean_prob:
                final_max_mean_prob = mean_prob
                final_hypo_id = hypo_id
        if final_hypo_id != -1:  # can find a class that it belongs to
            pred_label_ids.append(final_hypo_id)
        else:
            pred_label_ids.append(len(train_class_list))

    assert len(pred_label_ids) == len(gold_label_ids)
    acc_each_round = []
    for round_name_id in round_list:
        # e.g. n1, n2, ..., ood
        round_size = 0
        round_hit = 0
        if round_name_id != 'ood':
            for ii, gold_label_id in enumerate(gold_label_ids):
                if test_split_list[gold_label_id] == round_name_id:
                    round_size += 1
                    if gold_label_id == pred_label_ids[ii]:
                        round_hit += 1
            acc_i = round_hit / round_size
            acc_each_round.append(acc_i)
        else:
            '''ood acc'''
            gold_binary_list = []
            pred_binary_list = []
            for ii, gold_label_id in enumerate(gold_label_ids):
                gold_binary_list.append(1 if test_split_list[gold_label_id] ==
                                        round_name_id else 0)
                pred_binary_list.append(1 if pred_label_ids[ii] ==
                                        len(train_class_list) else 0)
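            # For the 'ood' bucket, report F1 of OOD detection (predicting the
            # reserved OOD label) instead of per-class accuracy.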
            overlap = 0
            for i in range(len(gold_binary_list)):
                if gold_binary_list[i] == 1 and pred_binary_list[i] == 1:
                    overlap += 1
            recall = overlap / (1e-6 + sum(gold_binary_list))
            precision = overlap / (1e-6 + sum(pred_binary_list))
            acc_i = 2 * recall * precision / (1e-6 + recall + precision)
            acc_each_round.append(acc_i)

    print('final_test_performance:', acc_each_round)
Exemplo n.º 11
import argparse

from transformers.tokenization_roberta import RobertaTokenizer
from transformers import BertTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str)
parser.add_argument("--output_dir", type=str)
parser.add_argument("--passage_length_limit", type=int, default=463)
parser.add_argument("--question_length_limit", type=int, default=46)
parser.add_argument("--encoder", type=str, default="bert")
parser.add_argument("--mode", type=str, default='train')

args = parser.parse_args()

if args.encoder == 'roberta':
    tokenizer = RobertaTokenizer.from_pretrained(args.input_path +
                                                 "/roberta.large")
    sep = '<s>'
elif args.encoder == 'bert':
    tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
    sep = '[SEP]'
elif args.encoder == 'finbert':
    tokenizer = BertTokenizer.from_pretrained(args.input_path + "/finbert")
    sep = '[SEP]'

if args.mode == 'test':
    data_reader = TagTaTQATestReader(tokenizer,
                                     args.passage_length_limit,
                                     args.question_length_limit,
                                     sep=sep)
    data_mode = ["test", "dev"]
else:
Exemplo n.º 12
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    mctest_path = '/export/home/Dataset/MCTest/Statements/'
    # train_examples = processor.get_SciTail_as_train_k_shot(scitail_path+'scitail_1.0_train.tsv', args.kshot) #train_pu_half_v1.txt
    _, test_examples = processor.get_MCTest_dev_and_test(
        mctest_path + 'mc500.dev.statements.pairs',
        mctest_path + 'mc500.test.statements.pairs')

    label_list = ["ENTAILMENT", "UNKNOWN"]
    # train_examples = get_data_hulu_fewshot('train', 5)

    num_labels = len(label_list)
    print('num_labels:', num_labels, 'test size:', len(test_examples))

    model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.load_state_dict(torch.load(
        '/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'
    ),
                          strict=False)
    model.to(device)
    '''load test set'''
    test_features = convert_examples_to_features(
        test_examples,
        label_list,
        args.max_seq_length,
        tokenizer,
        output_mode,
        cls_token_at_end=
        False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
        cls_token=tokenizer.cls_token,
        cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
        sep_token=tokenizer.sep_token,
        sep_token_extra=
        True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
        pad_on_left=
        False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
        pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
        pad_token_segment_id=0)  #4 if args.model_type in ['xlnet'] else 0,)

    eval_all_pair_ids = torch.tensor([f.input_pair_id for f in test_features],
                                     dtype=torch.long)
    eval_all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                      dtype=torch.long)
    eval_all_input_mask = torch.tensor([f.input_mask for f in test_features],
                                       dtype=torch.long)
    eval_all_segment_ids = torch.tensor([f.segment_ids for f in test_features],
                                        dtype=torch.long)
    eval_all_label_ids = torch.tensor([f.label_id for f in test_features],
                                      dtype=torch.long)

    eval_data = TensorDataset(eval_all_pair_ids, eval_all_input_ids,
                              eval_all_input_mask, eval_all_segment_ids,
                              eval_all_label_ids)
    eval_sampler = SequentialSampler(eval_data)
    test_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    model.eval()

    logger.info("***** Running test *****")
    logger.info("  Num examples = %d", len(test_examples))
    # logger.info("  Batch size = %d", args.eval_batch_size)

    eval_loss = 0
    nb_eval_steps = 0
    preds = []
    gold_label_ids = []
    gold_pair_ids = []
    for input_pair_ids, input_ids, input_mask, segment_ids, label_ids in test_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        gold_pair_ids += list(input_pair_ids.numpy())
        label_ids = label_ids.to(device)
        gold_label_ids += list(label_ids.detach().cpu().numpy())

        with torch.no_grad():
            logits = model(input_ids, input_mask)
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(preds[0],
                                 logits.detach().cpu().numpy(),
                                 axis=0)

    preds = preds[0]

    pred_probs = list(softmax(preds, axis=1)[:, 0])  #entail prob

    assert len(gold_pair_ids) == len(pred_probs)
    assert len(gold_pair_ids) == len(gold_label_ids)

    pairID_2_predgoldlist = {}
    for pair_id, prob, gold_id in zip(gold_pair_ids, pred_probs,
                                      gold_label_ids):
        predgoldlist = pairID_2_predgoldlist.get(pair_id)
        if predgoldlist is None:
            predgoldlist = []
        predgoldlist.append((prob, gold_id))
        pairID_2_predgoldlist[pair_id] = predgoldlist
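    # Each pair_id groups the four candidate statements for one question; count
    # a hit when the candidate with the highest entailment probability is the
    # gold ENTAILMENT one (label id 0).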
    total_size = len(pairID_2_predgoldlist)
    hit_size = 0
    for pair_id, predgoldlist in pairID_2_predgoldlist.items():
        predgoldlist.sort(key=lambda x: x[0])  #sort by prob
        assert len(predgoldlist) == 4
        if predgoldlist[-1][1] == 0:
            hit_size += 1
    acc = hit_size / total_size
    print('test acc:', acc)
Exemplo n.º 13
def main():
    parser = argparse.ArgumentParser()

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="number of support examples per class (k-shot)")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--target_train_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--update_BERT_top_layers',
                        type=int,
                        default=1,
                        help="number of top RoBERTa layers to fine-tune (lower layers are frozen)")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    target_kshot_entail_examples, target_kshot_nonentail_examples, target_dev_examples, target_test_examples = load_FewRel_GFS_Entail(
        args.kshot)

    system_seed = 42
    random.seed(system_seed)
    np.random.seed(system_seed)
    torch.manual_seed(system_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(system_seed)

    source_kshot_size = 10  # fixed k-shot size on the source (MNLI) side
    source_kshot_entail, source_kshot_neural, source_kshot_contra, source_remaining_examples = get_MNLI_train(
        '/export/home/Dataset/glue_data/MNLI/train.tsv', source_kshot_size)
    source_examples = source_kshot_entail + source_kshot_neural + source_kshot_contra + source_remaining_examples
    target_label_list = ["entailment", "non_entailment"]
    source_label_list = ["entailment", "neutral", "contradiction"]
    # entity_label_list = ["A-coref", "B-coref"]
    source_num_labels = len(source_label_list)
    target_num_labels = len(target_label_list)
    print('training size:', len(source_examples), 'dev size:',
          len(target_dev_examples), 'test size:', len(target_test_examples))

    roberta_model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    roberta_model.load_state_dict(torch.load(
        '/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'
    ),
                                  strict=False)
    '''
    the embedding layer has 5 parameter tensors;
    each transformer layer has 16 parameter tensors
    '''
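    # Freeze everything below the top `update_BERT_top_layers` transformer
    # layers (roberta-large: 24 layers x 16 tensors, plus 5 embedding tensors).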
    param_size = 0
    update_top_layer_size = args.update_BERT_top_layers
    for name, param in roberta_model.named_parameters():
        if param_size < (5 + 16 * (24 - update_top_layer_size)):
            param.requires_grad = False
        param_size += 1
    roberta_model.to(device)

    protonet = PrototypeNet(bert_hidden_dim)
    protonet.to(device)

    param_optimizer = list(protonet.named_parameters()) + list(
        roberta_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    retrieve_batch_size = 5

    source_kshot_entail_dataloader = examples_to_features(
        source_kshot_entail,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_neural_dataloader = examples_to_features(
        source_kshot_neural,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_contra_dataloader = examples_to_features(
        source_kshot_contra,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_remain_ex_dataloader = examples_to_features(
        source_remaining_examples,
        source_label_list,
        args,
        tokenizer,
        args.train_batch_size,
        "classification",
        dataloader_mode='random')

    target_kshot_entail_dataloader = examples_to_features(
        target_kshot_entail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_kshot_nonentail_dataloader = examples_to_features(
        target_kshot_nonentail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_dev_dataloader = examples_to_features(target_dev_examples,
                                                 target_label_list,
                                                 args,
                                                 tokenizer,
                                                 args.eval_batch_size,
                                                 "classification",
                                                 dataloader_mode='sequential')
    target_test_dataloader = examples_to_features(target_test_examples,
                                                  target_label_list,
                                                  args,
                                                  tokenizer,
                                                  args.eval_batch_size,
                                                  "classification",
                                                  dataloader_mode='sequential')
    '''starting to train'''
    iter_co = 0
    tr_loss = 0
    source_loss = 0
    target_loss = 0
    final_test_performance = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):

        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(
                tqdm(source_remain_ex_dataloader, desc="Iteration")):
            protonet.train()
            batch = tuple(t.to(device) for t in batch)
            _, input_ids, input_mask, segment_ids, source_label_ids_batch = batch

            roberta_model.train()
            source_last_hidden_batch, _ = roberta_model(input_ids, input_mask)
            '''
            retrieve rep for support examples in MNLI
            '''
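            # Build one prototype per MNLI class: the mean hidden representation
            # over that class's k-shot support batches.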
            kshot_entail_reps = torch.zeros(1, bert_hidden_dim).to(device)
            entail_batch_i = 0
            for entail_batch in source_kshot_entail_dataloader:
                roberta_model.train()
                last_hidden_entail, _ = roberta_model(
                    entail_batch[1].to(device), entail_batch[2].to(device))
                kshot_entail_reps += torch.mean(last_hidden_entail,
                                                dim=0,
                                                keepdim=True)
                entail_batch_i += 1
            kshot_entail_rep = kshot_entail_reps / entail_batch_i
            kshot_neural_reps = torch.zeros(1, bert_hidden_dim).to(device)
            neural_batch_i = 0
            for neural_batch in source_kshot_neural_dataloader:
                roberta_model.train()
                last_hidden_neural, _ = roberta_model(
                    neural_batch[1].to(device), neural_batch[2].to(device))
                kshot_neural_reps += torch.mean(last_hidden_neural,
                                                dim=0,
                                                keepdim=True)
                neural_batch_i += 1
            kshot_neural_rep = kshot_neural_reps / neural_batch_i
            kshot_contra_reps = torch.zeros(1, bert_hidden_dim).to(device)
            contra_batch_i = 0
            for contra_batch in source_kshot_contra_dataloader:
                roberta_model.train()
                last_hidden_contra, _ = roberta_model(
                    contra_batch[1].to(device), contra_batch[2].to(device))
                kshot_contra_reps += torch.mean(last_hidden_contra,
                                                dim=0,
                                                keepdim=True)
                contra_batch_i += 1
            kshot_contra_rep = kshot_contra_reps / contra_batch_i

            source_class_prototype_reps = torch.cat(
                [kshot_entail_rep, kshot_neural_rep, kshot_contra_rep],
                dim=0)  #(3, hidden)
            '''first get representations for support examples in target'''
            target_kshot_entail_dataloader_subset = examples_to_features(
                random.sample(target_kshot_entail_examples, 10),
                target_label_list,
                args,
                tokenizer,
                retrieve_batch_size,
                "classification",
                dataloader_mode='sequential')
            target_kshot_nonentail_dataloader_subset = examples_to_features(
                random.sample(target_kshot_nonentail_examples, 10),
                target_label_list,
                args,
                tokenizer,
                retrieve_batch_size,
                "classification",
                dataloader_mode='sequential')
            kshot_entail_reps = []
            for entail_batch in target_kshot_entail_dataloader_subset:
                roberta_model.train()
                last_hidden_entail, _ = roberta_model(
                    entail_batch[1].to(device), entail_batch[2].to(device))
                kshot_entail_reps.append(last_hidden_entail)
            all_kshot_entail_reps = torch.cat(kshot_entail_reps, dim=0)
            kshot_entail_rep = torch.mean(all_kshot_entail_reps,
                                          dim=0,
                                          keepdim=True)
            kshot_nonentail_reps = []
            for nonentail_batch in target_kshot_nonentail_dataloader_subset:
                roberta_model.train()
                last_hidden_nonentail, _ = roberta_model(
                    nonentail_batch[1].to(device),
                    nonentail_batch[2].to(device))
                kshot_nonentail_reps.append(last_hidden_nonentail)
            all_kshot_neural_reps = torch.cat(kshot_nonentail_reps, dim=0)
            kshot_nonentail_rep = torch.mean(all_kshot_neural_reps,
                                             dim=0,
                                             keepdim=True)
            target_class_prototype_reps = torch.cat(
                [kshot_entail_rep, kshot_nonentail_rep, kshot_nonentail_rep],
                dim=0)  #(3, hidden)
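            # The non-entailment prototype is repeated so the target side also
            # provides three prototype slots, matching the three source classes.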

            class_prototype_reps = torch.cat(
                [source_class_prototype_reps, target_class_prototype_reps],
                dim=0)  #(6, hidden)
            '''forward to model'''

            target_batch_size = args.target_train_batch_size  #10*3
            # print('target_batch_size:', target_batch_size)
            target_batch_size_entail = target_batch_size  #random.randrange(5)+1
            target_batch_size_neural = target_batch_size  #random.randrange(5)+1

            selected_target_entail_rep = all_kshot_entail_reps[torch.randperm(
                all_kshot_entail_reps.shape[0])[:target_batch_size_entail]]
            # print('selected_target_entail_rep:', selected_target_entail_rep.shape)
            selected_target_neural_rep = all_kshot_neural_reps[torch.randperm(
                all_kshot_neural_reps.shape[0])[:target_batch_size_neural]]
            # print('selected_target_neural_rep:', selected_target_neural_rep.shape)
            target_last_hidden_batch = torch.cat(
                [selected_target_entail_rep, selected_target_neural_rep])
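            # Randomly sampled target support representations are appended to the
            # source batch so the prototype network is trained on both domains in
            # every step.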

            last_hidden_batch = torch.cat(
                [source_last_hidden_batch, target_last_hidden_batch],
                dim=0)  #(train_batch_size+10*2)
            # print('last_hidden_batch shape:', last_hidden_batch.shape)
            batch_logits = protonet(class_prototype_reps, last_hidden_batch)
            # exit(0)
            '''source side loss'''
            # loss_fct = CrossEntropyLoss(reduction='none')
            loss_fct = CrossEntropyLoss()
            source_loss_list = loss_fct(
                batch_logits[:source_last_hidden_batch.shape[0]].view(
                    -1, source_num_labels), source_label_ids_batch.view(-1))
            '''target side loss'''
            target_label_ids_batch = torch.tensor(
                [0] * selected_target_entail_rep.shape[0] +
                [1] * selected_target_neural_rep.shape[0],
                dtype=torch.long)
            target_batch_logits = batch_logits[-target_last_hidden_batch.shape[0]:]
            target_loss_list = loss_by_logits_and_2way_labels(
                target_batch_logits, target_label_ids_batch.view(-1), device)
            # target_loss_list = loss_fct(target_batch_logits.view(-1, source_num_labels), target_label_ids_batch.to(device).view(-1))
            loss = source_loss_list + target_loss_list  #torch.mean(torch.cat([source_loss_list, target_loss_list]))
            source_loss += source_loss_list
            target_loss += target_loss_list
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            global_step += 1
            iter_co += 1
            # print('iter_co:', iter_co, 'mean loss:', tr_loss/iter_co)
            if iter_co % 20 == 0:
                # if iter_co % len(source_remain_ex_dataloader)==0:
                '''
                periodically evaluate on the dev set
                '''
                '''
                retrieve representations for support examples in MNLI
                '''
                kshot_entail_reps = torch.zeros(1, bert_hidden_dim).to(device)
                entail_batch_i = 0
                for entail_batch in source_kshot_entail_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_entail, _ = roberta_model(
                            entail_batch[1].to(device),
                            entail_batch[2].to(device))
                    kshot_entail_reps += torch.mean(last_hidden_entail,
                                                    dim=0,
                                                    keepdim=True)
                    entail_batch_i += 1
                kshot_entail_rep = kshot_entail_reps / entail_batch_i
                kshot_neural_reps = torch.zeros(1, bert_hidden_dim).to(device)
                neural_batch_i = 0
                for neural_batch in source_kshot_neural_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_neural, _ = roberta_model(
                            neural_batch[1].to(device),
                            neural_batch[2].to(device))
                    kshot_neural_reps += torch.mean(last_hidden_neural,
                                                    dim=0,
                                                    keepdim=True)
                    neural_batch_i += 1
                kshot_neural_rep = kshot_neural_reps / neural_batch_i
                kshot_contra_reps = torch.zeros(1, bert_hidden_dim).to(device)
                contra_batch_i = 0
                for contra_batch in source_kshot_contra_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_contra, _ = roberta_model(
                            contra_batch[1].to(device),
                            contra_batch[2].to(device))
                    kshot_contra_reps += torch.mean(last_hidden_contra,
                                                    dim=0,
                                                    keepdim=True)
                    contra_batch_i += 1
                kshot_contra_rep = kshot_contra_reps / contra_batch_i

                source_class_prototype_reps = torch.cat(
                    [kshot_entail_rep, kshot_neural_rep, kshot_contra_rep],
                    dim=0)  #(3, hidden)
                '''first get representations for support examples in target'''
                # target_kshot_entail_dataloader_subset = examples_to_features(random.sample(target_kshot_entail_examples, args.kshot), target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
                # target_kshot_nonentail_dataloader_subset = examples_to_features(random.sample(target_kshot_nonentail_examples, args.kshot), target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
                kshot_entail_reps = torch.zeros(1, bert_hidden_dim).to(device)
                entail_batch_i = 0
                for entail_batch in target_kshot_entail_dataloader_subset:  #target_kshot_entail_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_entail, _ = roberta_model(
                            entail_batch[1].to(device),
                            entail_batch[2].to(device))
                    kshot_entail_reps += torch.mean(last_hidden_entail,
                                                    dim=0,
                                                    keepdim=True)
                    entail_batch_i += 1
                kshot_entail_rep = kshot_entail_reps / entail_batch_i
                kshot_nonentail_reps = torch.zeros(1,
                                                   bert_hidden_dim).to(device)
                nonentail_batch_i = 0
                for nonentail_batch in target_kshot_nonentail_dataloader_subset:  #target_kshot_nonentail_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_nonentail, _ = roberta_model(
                            nonentail_batch[1].to(device),
                            nonentail_batch[2].to(device))
                    kshot_nonentail_reps += torch.mean(last_hidden_nonentail,
                                                       dim=0,
                                                       keepdim=True)
                    nonentail_batch_i += 1
                kshot_nonentail_rep = kshot_nonentail_reps / nonentail_batch_i
                target_class_prototype_reps = torch.cat([
                    kshot_entail_rep, kshot_nonentail_rep, kshot_nonentail_rep
                ],
                                                        dim=0)  #(3, hidden)

                class_prototype_reps = torch.cat(
                    [source_class_prototype_reps, target_class_prototype_reps],
                    dim=0)  #(6, hidden)

                protonet.eval()

                # dev_acc = evaluation(protonet, target_dev_dataloader,  device, flag='Dev')
                # print('class_prototype_reps:', class_prototype_reps)
                dev_acc = evaluation(protonet,
                                     roberta_model,
                                     class_prototype_reps,
                                     target_dev_dataloader,
                                     device,
                                     flag='Dev')
                if dev_acc > max_dev_acc:
                    max_dev_acc = dev_acc
                    print('\n\t dev acc:', dev_acc, ' max_dev_acc:',
                          max_dev_acc, '\n')
                    if dev_acc > 0.73:  # dev-acc threshold before running the test set (0.73 for k=10, 0.66 for k=5)
                        test_acc = evaluation(protonet,
                                              roberta_model,
                                              class_prototype_reps,
                                              target_test_dataloader,
                                              device,
                                              flag='Test')
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc

                        final_test_performance = test_acc
                        print('\n\t test acc:', test_acc, ' max_test_acc:',
                              max_test_acc, '\n')
                else:
                    print('\n\t dev acc:', dev_acc, ' max_dev_acc:',
                          max_dev_acc, '\n')

            if iter_co == 2000:
                break
    print('final_test_performance:', final_test_performance)
Exemplo n.º 14
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--round_name",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=50,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")


    args = parser.parse_args()


    processors = {
        "rte": RteProcessor
    }

    output_modes = {
        "rte": "classification"
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")


    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    round_name_2_rounds={'base':['base', 'ood'],
                         'r1':['base', 'n1', 'ood'],
                         'r2':['base', 'n1', 'n2', 'ood'],
                         'r3':['base', 'n1', 'n2', 'n3', 'ood'],
                         'r4':['base', 'n1', 'n2', 'n3','n4', 'ood'],
                         'r5':['base', 'n1', 'n2', 'n3','n4', 'n5', 'ood']}
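    # each round adds one group of novel classes (n1..n5) on top of the base classes; 'ood' is included for evaluation only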




    model = RobertaForSequenceClassification(2)  # the label count here is arbitrary; only the encoder's hidden states (first return value) are used below
    tokenizer = RobertaTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
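    # biases and LayerNorm parameters are excluded from weight decay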

    optimizer = AdamW(optimizer_grouped_parameters,
                             lr=args.learning_rate)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    banking77_class_list, ood_class_set, class_2_split = load_class_names()

    round_list = round_name_2_rounds.get(args.round_name)
    '''support for all seen classes'''
    class_2_support_examples = processor.load_support_all_rounds(round_list[:-1]) #no support set for ood
    seen_class_list = list(class_2_support_examples.keys())
    support_example_lists = [class_2_support_examples.get(seen_class)  for seen_class in seen_class_list]
    '''dev and test'''
    dev_examples, dev_class_list = processor.load_dev_or_test(round_list, 'dev')
    test_examples, test_class_list = processor.load_dev_or_test(round_list, 'test')
    assert len(set(dev_class_list) | set(test_class_list)) == len(set(dev_class_list))
    assert len(test_class_list) == len(seen_class_list)+7
    eval_class_list = seen_class_list+list(ood_class_set)

    test_split_list = []
    for test_class_i in eval_class_list:
        test_split_list.append(class_2_split.get(test_class_i))

    eval_support_dataloader_list = []

    for eval_support_examples_per_class in support_example_lists:
        support_dataloader = examples_to_features(eval_support_examples_per_class, seen_class_list, args, tokenizer, 5, "classification", dataloader_mode='random')
        eval_support_dataloader_list.append(support_dataloader)
    dev_query_dataloader = examples_to_features(dev_examples, eval_class_list, args, tokenizer, args.eval_batch_size, "classification", dataloader_mode='random')
    test_query_dataloader = examples_to_features(test_examples, eval_class_list, args, tokenizer, args.eval_batch_size, "classification", dataloader_mode='random')
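    # support examples are batched per class (batch size 5) so their mean encoding can serve as the class prototype; dev/test queries are later scored against those prototypes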


    '''training'''
    max_test_acc = 0.0
    max_dev_acc = 0.0

    for _ in range(args.num_train_epochs):
        train_support_examples_list, full_query_examples, selected_class_list = processor.load_base_train() #we do not use ood as training
        train_support_dataloader_list = []

        for train_support_examples_per_class in train_support_examples_list:
            train_support_dataloader = examples_to_features(train_support_examples_per_class, selected_class_list, args, tokenizer, 5, "classification", dataloader_mode='random')
            train_support_dataloader_list.append(train_support_dataloader)
        train_query_dataloader = examples_to_features(full_query_examples, selected_class_list, args, tokenizer, args.train_batch_size, "classification", dataloader_mode='random')


        '''for each query batch: compute query reps, build class prototypes from the support sets, and train on cosine logits'''
        best_threshold = []
        for _, batch in enumerate(tqdm(train_query_dataloader, desc="train")):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, _, label_ids = batch

            last_hidden_batch, _ = model(input_ids, input_mask) #(batch, hidden)
            query_normalized_rep = last_hidden_batch/(1e-8+torch.sqrt(torch.sum(torch.square(last_hidden_batch), axis=1, keepdim=True)))
            '''build class prototype reps from the support sets'''
            all_class_proto_reps = []
            for train_support_dataloader in train_support_dataloader_list:
                class_reps = torch.zeros(1, bert_hidden_dim).to(device)
                batch_size_accu = 0
                for batch in train_support_dataloader:
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_mask, _, _ = batch

                    last_hidden_batch, _ = model(input_ids, input_mask)
                    class_reps+=torch.mean(last_hidden_batch, axis=0, keepdim=True)
                    batch_size_accu+=1
                class_rep = class_reps/batch_size_accu
                all_class_proto_reps.append(class_rep)
            all_class_proto_reps = torch.cat(all_class_proto_reps, axis=0) #(#class, hidden)
            support_normalized_rep = all_class_proto_reps/(1e-8+torch.sqrt(torch.sum(torch.square(all_class_proto_reps), axis=1, keepdim=True)))
            #cosine

            logits = torch.mm(query_normalized_rep, torch.transpose(support_normalized_rep, 0, 1)) #(batch, class)
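            # logits: cosine similarity between each L2-normalized query rep and each L2-normalized class prototype, one column per class in the episode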

            loss_fct = CrossEntropyLoss()

            loss = loss_fct(logits.view(-1, 10), label_ids.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scores_for_positive = logits[torch.arange(logits.shape[0]), label_ids.view(-1)].mean()
            best_threshold.append(scores_for_positive.item())

        best_threshold = sum(best_threshold) / len(best_threshold)
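        # the OOD threshold is the average cosine score the model assigns to the gold class during training; at test time, queries whose best score falls below it are predicted as OOD (-1)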
        print('best_threshold:', best_threshold )

        '''evaluation'''
        model.eval()
        '''first compute class prototype reps from the support sets'''
        all_class_proto_reps = []
        for support_dataloader in eval_support_dataloader_list:
            class_reps = torch.zeros(1, bert_hidden_dim).to(device)
            batch_size_accu = 0
            for batch in support_dataloader:
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, _, _ = batch
                with torch.no_grad():
                    last_hidden_batch, _ = model(input_ids, input_mask)
                class_reps+=torch.mean(last_hidden_batch, axis=0, keepdim=True)
                batch_size_accu+=1
            class_rep = class_reps/batch_size_accu
            all_class_proto_reps.append(class_rep)
        all_class_proto_reps = torch.cat(all_class_proto_reps, axis=0) #(#class, hidden)
        support_normalized_rep = all_class_proto_reps/(1e-8+torch.sqrt(torch.sum(torch.square(all_class_proto_reps), axis=1, keepdim=True)))


        logger.info("***** Running test *****")
        logger.info("  Num examples = %d", len(test_examples))

        preds = []
        gold_label_ids = []
        # print('Evaluating...')
        for input_ids, input_mask, segment_ids, label_ids in test_query_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            # segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)
            gold_label_ids+=list(label_ids.detach().cpu().numpy())

            with torch.no_grad():
                last_hidden_batch, _ = model(input_ids, input_mask)
            query_normalized_rep = last_hidden_batch/(1e-8+torch.sqrt(torch.sum(torch.square(last_hidden_batch), axis=1, keepdim=True)))
            logits = torch.mm(query_normalized_rep, torch.transpose(support_normalized_rep, 0, 1)) #(batch, class)
            if len(preds) == 0:
                preds.append(logits.detach().cpu().numpy())
            else:
                preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)

        preds = preds[0]

        pred_probs = preds  # raw cosine scores are used directly; softmax(preds, axis=1) is not applied
        pred_label_ids_raw = list(np.argmax(pred_probs, axis=1))
        pred_max_prob = list(np.amax(pred_probs, axis=1))

        pred_label_ids = []
        for i, pred_max_prob_i in enumerate(pred_max_prob):
            if pred_max_prob_i < best_threshold:
                pred_label_ids.append(-1) #-1 means ood
            else:
                pred_label_ids.append(pred_label_ids_raw[i])


        assert len(pred_label_ids) == len(gold_label_ids)
        acc_each_round = []
        for round_name_id in round_list:
            #base, n1, n2, ood
            round_size = 0
            round_hit = 0
            if round_name_id != 'ood':
                for ii, gold_label_id in enumerate(gold_label_ids):
                    if test_split_list[gold_label_id] == round_name_id:
                        round_size += 1
                        if gold_label_id == pred_label_ids[ii]:
                            round_hit += 1
                acc_i = round_hit / round_size
                acc_each_round.append(acc_i)
            else:
                '''ood f1'''
                gold_binary_list = []
                pred_binary_list = []
                for ii, gold_label_id in enumerate(gold_label_ids):
                    gold_binary_list.append(1 if test_split_list[gold_label_id] == round_name_id else 0)
                    pred_binary_list.append(1 if pred_label_ids[ii]==-1 else 0)
                overlap = 0
                for i in range(len(gold_binary_list)):
                    if gold_binary_list[i] == 1 and pred_binary_list[i]==1:
                        overlap +=1
                recall = overlap/(1e-6+sum(gold_binary_list))
                precision = overlap/(1e-6+sum(pred_binary_list))

                acc_i = 2*recall*precision/(1e-6+recall+precision)
                acc_each_round.append(acc_i)

        print('\n\t\t test_acc:', acc_each_round)
        final_test_performance = acc_each_round

    print('final_test_performance:', final_test_performance)
    def setup_base_tokenizer(self):
        self.base_tokenizer = RobertaTokenizer.from_pretrained('distilroberta-base', do_lower_case=True,
                                                               cache_dir=self.test_dir)

    def setup_class(self):
        self.use_gpu = torch.cuda.is_available()
        self.test_dir = Path(tempfile.mkdtemp())

        self.base_tokenizer = RobertaTokenizer.from_pretrained('distilroberta-base', do_lower_case=True,
                                                               cache_dir=self.test_dir)
        self.rust_tokenizer = PyRobertaTokenizer(
            get_from_cache(self.base_tokenizer.pretrained_vocab_files_map['vocab_file']['distilroberta-base']),
            get_from_cache(self.base_tokenizer.pretrained_vocab_files_map['merges_file']['distilroberta-base']),
            do_lower_case=True
        )
        self.model = RobertaModel.from_pretrained('distilroberta-base',
                                                  output_attentions=False).eval()
        if self.use_gpu:
            self.model.cuda()
        #     Extracted from https://en.wikipedia.org/wiki/Deep_learning
        self.sentence_list = [
            'Deep learning (also known as deep structured learning or hierarchical learning) is part of a broader family of machine learning methods based on artificial neural networks.Learning can be supervised, semi-supervised or unsupervised.',
            'Deep learning is a class of machine learning algorithms that[11](pp199–200) uses multiple layers to progressively extract higher level features from the raw input.',
            'For example, in image processing, lower layers may identify edges, while higher layers may identify the concepts relevant to a human such as digits or letters or faces.',
            'Most modern deep learning models are based on artificial neural networks, specifically, Convolutional Neural Networks (CNN)s, although they can also include propositional formulas organized layer-wise in deep generative models.',
            'In deep learning, each level learns to transform its input data into a slightly more abstract and composite representation.',
            'In an image recognition application, the raw input may be a matrix of pixels; the first representational layer may abstract the pixels and encode edges; the second layer may compose and encode arrangements of edges;',
            'he third layer may encode a nose and eyes; and the fourth layer may recognize that the image contains a face. Importantly, a deep learning process can learn which features to optimally place in which level on its own.',
            '(Of course, this does not completely eliminate the need for hand-tuning; for example, varying numbers of layers and layer sizes can provide different degrees of abstraction.)[',
            'The word "deep" in "deep learning" refers to the number of layers through which the data is transformed. More precisely, deep learning systems have a substantial credit assignment path (CAP) depth. The CAP is the chain of transformations from input to output.',
            'CAPs describe potentially causal connections between input and output. For a feedforward neural network, the depth of the CAPs is that of the network and is the number of hidden layers plus one (as the output layer is also parameterized).',
            'For recurrent neural networks, in which a signal may propagate through a layer more than once, the CAP depth is potentially unlimited.[2] No universally agreed upon threshold of depth divides shallow learning from deep learning.',
            'CAP of depth 2 has been shown to be a universal approximator in the sense that it can emulate any function.[14] Beyond that, more layers do not add to the function approximator ability of the network.',
            'Deep models (CAP > 2) are able to extract better features than shallow models and hence, extra layers help in learning the features effectively. Deep learning architectures can be constructed with a greedy layer-by-layer method.',
            'Deep learning helps to disentangle these abstractions and pick out which features improve performance.[1]. For supervised learning tasks, deep learning methods eliminate feature engineering, by translating the data into compact intermediate representations',
            'Deep learning algorithms can be applied to unsupervised learning tasks. This is an important benefit because unlabeled data are more abundant than the labeled data. Examples of deep structures that can be trained in an unsupervised manner are neural history compressors and deep belief networks.',
            'Deep neural networks are generally interpreted in terms of the universal approximation theorem or probabilistic inference. The classic universal approximation theorem concerns the capacity of feedforward neural networks with a single hidden layer of finite size to approximate continuous functions.',
            'In 1989, the first proof was published by George Cybenko for sigmoid activation functions and was generalised to feed-forward multi-layer architectures in 1991 by Kurt Hornik.Recent work also showed that universal approximation also holds for non-bounded activation functions such as the rectified linear unit.',
            'he universal approximation theorem for deep neural networks concerns the capacity of networks with bounded width but the depth is allowed to grow. Lu et al. proved that if the width of a deep neural network with ReLU activation is strictly larger than the input dimension, then the network can approximate any Lebesgue integrable function',
            'The probabilistic interpretation[24] derives from the field of machine learning. It features inference, as well as the optimization concepts of training and testing, related to fitting and generalization, respectively',
            'More specifically, the probabilistic interpretation considers the activation nonlinearity as a cumulative distribution function. The probabilistic interpretation led to the introduction of dropout as regularizer in neural networks.',
            'The probabilistic interpretation was introduced by researchers including Hopfield, Widrow and Narendra and popularized in surveys such as the one by Bishop. The term Deep Learning was introduced to the machine learning community by Rina Dechter in 1986',
            'The first general, working learning algorithm for supervised, deep, feedforward, multilayer perceptrons was published by Alexey Ivakhnenko and Lapa in 1965.[32] A 1971 paper described already a deep network with 8 layers trained by the group method of data handling algorithm.',
            'Other deep learning working architectures, specifically those built for computer vision, began with the Neocognitron introduced by Kunihiko Fukushima in 1980.[34] In 1989, Yann LeCun et al. applied the standard backpropagation algorithm',
            'By 1991 such systems were used for recognizing isolated 2-D hand-written digits, while recognizing 3-D objects was done by matching 2-D images with a handcrafted 3-D object model. Weng et al. suggested that a human brain does not use a monolithic 3-D object model and in 1992 they published Cresceptron',
            'Because it directly used natural images, Cresceptron started the beginning of general-purpose visual learning for natural 3D worlds. Cresceptron is a cascade of layers similar to Neocognitron. But while Neocognitron required a human programmer to hand-merge features, Cresceptron learned an open number of features in each layer without supervision',
            'Cresceptron segmented each learned object from a cluttered scene through back-analysis through the network. Max pooling, now often adopted by deep neural networks (e.g. ImageNet tests), was first used in Cresceptron to reduce the position resolution by a factor of (2x2) to 1 through the cascade for better generalization',
            'In 1994, André de Carvalho, together with Mike Fairhurst and David Bisset, published experimental results of a multi-layer boolean neural network, also known as a weightless neural network, composed of a 3-layers self-organising feature extraction neural network module (SOFT) followed by a multi-layer classification neural network module (GSN)',
            'n 1995, Brendan Frey demonstrated that it was possible to train a network containing six fully connected layers and several hundred hidden units using the wake-sleep algorithm, co-developed with Peter Dayan and Hinton. Many factors contribute to the slow speed, including the vanishing gradient problem analyzed in 1991 by Sepp Hochreiter',
            'Simpler models that use task-specific handcrafted features such as Gabor filters and support vector machines (SVMs) were a popular choice in the 1990s and 2000s, because of artificial neural network\'s (ANN) computational cost and a lack of understanding of how the brain wires its biological networks.',
            'Both shallow and deep learning (e.g., recurrent nets) of ANNs have been explored for many years.[47][48][49] These methods never outperformed non-uniform internal-handcrafting Gaussian mixture model/Hidden Markov model (GMM-HMM) technology based on generative models of speech trained discriminatively.',
            'Key difficulties have been analyzed, including gradient diminishing[45] and weak temporal correlation structure in neural predictive models.[51][52] Additional difficulties were the lack of training data and limited computing power. Most speech recognition researchers moved away from neural nets to pursue generative modeling.',
            'An exception was at SRI International in the late 1990s. Funded by the US government\'s NSA and DARPA, SRI studied deep neural networks in speech and speaker recognition. The speaker recognition team led by Larry Heck achieved the first significant success with deep neural networks.',
            'While SRI experienced success with deep neural networks in speaker recognition, they were unsuccessful in demonstrating similar success in speech recognition. The principle of elevating "raw" features over hand-crafted optimization was first explored successfully in the architecture of deep autoencoder on the "raw" spectrogram'
        ]

        # Pre-allocate GPU memory
        tokens_list = [self.base_tokenizer.tokenize(sentence) for sentence in self.sentence_list]
        features = [self.base_tokenizer.convert_tokens_to_ids(tokens) for tokens in tokens_list]
        features = [self.base_tokenizer.prepare_for_model(input, None, add_special_tokens=True, max_length=128) for
                    input in features]
        max_len = max([len(f['input_ids']) for f in features])
        features = [f['input_ids'] + [0] * (max_len - len(f['input_ids'])) for f in features]
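        # pad every sequence with id 0 up to the longest sequence in the batch so the inputs can be stacked into one tensor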
        all_input_ids = torch.tensor(features, dtype=torch.long)

        if self.use_gpu:
            all_input_ids = all_input_ids.cuda()

        with torch.no_grad():
            _ = self.model(all_input_ids)[0].cpu().numpy()
Example No. 17
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    '''first tune threshold on validation set, then try it on test set'''
    # test_examples = processor.get_GAP_coreference_as_test('gap-validation.tsv')
    test_examples = processor.get_GAP_coreference_as_test('gap-test.tsv')
    label_list = ["A-coref", "B-coref"]
    # train_examples = get_data_hulu_fewshot('train', 5)

    num_labels = len(label_list)
    print('num_labels:', num_labels, 'test size:', len(test_examples))

    model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.load_state_dict(torch.load(
        '/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'
    ),
                          strict=False)
    model.to(device)
    '''load test set'''
    test_features = convert_examples_to_features(
        test_examples,
        label_list,
        args.max_seq_length,
        tokenizer,
        output_mode,
        cls_token_at_end=
        False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
        cls_token=tokenizer.cls_token,
        cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
        sep_token=tokenizer.sep_token,
        sep_token_extra=
        True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
        pad_on_left=
        False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
        pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
        pad_token_segment_id=0)  #4 if args.model_type in ['xlnet'] else 0,)

    eval_all_input_indices = torch.tensor([f.id for f in test_features],
                                          dtype=torch.long)
    eval_all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                      dtype=torch.long)
    eval_all_input_mask = torch.tensor([f.input_mask for f in test_features],
                                       dtype=torch.long)
    eval_all_segment_ids = torch.tensor([f.segment_ids for f in test_features],
                                        dtype=torch.long)
    eval_all_label_ids = torch.tensor([f.label_id for f in test_features],
                                      dtype=torch.long)

    eval_data = TensorDataset(eval_all_input_indices, eval_all_input_ids,
                              eval_all_input_mask, eval_all_segment_ids,
                              eval_all_label_ids)
    eval_sampler = SequentialSampler(eval_data)
    test_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    model.eval()

    logger.info("***** Running test *****")
    logger.info("  Num examples = %d", len(test_examples))
    # logger.info("  Batch size = %d", args.eval_batch_size)

    eval_loss = 0
    nb_eval_steps = 0
    preds = []
    gold_label_ids = []
    example_id_list = []
    for _, batch in enumerate(tqdm(test_dataloader, desc="test")):
        input_indices, input_ids, input_mask, segment_ids, label_ids = batch
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)
        example_ids = list(input_indices.numpy())
        example_id_list += example_ids
        gold_label_ids += list(label_ids.detach().cpu().numpy())

        with torch.no_grad():
            logits = model(input_ids, input_mask)
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(preds[0],
                                 logits.detach().cpu().numpy(),
                                 axis=0)

    preds = preds[0]

    pred_probs = softmax(preds, axis=1)
    pred_label_ids_3way = list(np.argmax(pred_probs, axis=1))
    pred_prob_entail = list(pred_probs[:, 0])

    assert len(example_id_list) == len(pred_prob_entail)
    assert len(example_id_list) == len(gold_label_ids)
    assert len(example_id_list) == len(pred_label_ids_3way)
    # writefile = codecs.open()

    for threshold in np.arange(0.99, 0.0, -0.01):
        threshold = 0.7399999999999998  # the sweep is short-circuited: this fixed threshold (best on dev) is used and exit(0) below stops after the first iteration

        id2labellist = {}
        id2scorelist = {}
        for ex_id, type, prob, entail_or_not in zip(example_id_list,
                                                    gold_label_ids,
                                                    pred_prob_entail,
                                                    pred_label_ids_3way):
            labellist = id2labellist.get(ex_id)
            scorelist = id2scorelist.get(ex_id)
            if scorelist is None:
                scorelist = [0.0, 0.0]
            scorelist[type] = prob
            if labellist is None:
                labellist = ['', '']
            if prob > threshold:
                labellist[type] = True
            else:
                labellist[type] = False
            id2labellist[ex_id] = labellist
            id2scorelist[ex_id] = scorelist
        '''remove conflict'''
        eval_output_list = []
        # prefix = 'validation-' #'test-'
        prefix = 'test-'  #'test-'
        for ex_id, labellist in id2labellist.items():
            if labellist[0] is True and labellist[1] is True:
                scorelist = id2scorelist.get(ex_id)
                # print('scorelist:', scorelist)
                if scorelist[0] > scorelist[1]:
                    eval_output_list.append([prefix + str(ex_id), True, False])
                else:
                    eval_output_list.append([prefix + str(ex_id), False, True])
            else:
                eval_output_list.append([prefix + str(ex_id)] + labellist)

        # test_acc = run_scorer('/export/home/Dataset/gap_coreference/gap-validation.tsv', eval_output_list)
        test_acc = run_scorer(
            '/export/home/Dataset/gap_coreference/gap-test.tsv',
            eval_output_list)
        print('threshold:', threshold, 'test_f1:', test_acc)
        exit(0)
Exemplo n.º 18
0
def main():
    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--bert_model",
                        type=str,
                        required=True,
                        choices=[
                            "bert-base-uncased", "bert-large-uncased",
                            "bert-base-cased",
                            "bert-base-multilingual-uncased",
                            "bert-base-chinese",
                            "bert-base-multilingual-cased", "roberta-base"
                        ])
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument(
        "--do_whole_word_mask",
        action="store_true",
        help=
        "Whether to use whole word masking rather than per-WordPiece masking.")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Reduce memory usage for large datasets by keeping data on disc rather than in memory"
    )

    parser.add_argument("--num_workers",
                        type=int,
                        default=1,
                        help="The number of workers to use to write the files")
    parser.add_argument("--epochs_to_generate",
                        type=int,
                        default=1,
                        help="Number of epochs of data to pregenerate")
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument(
        "--short_seq_prob",
        type=float,
        default=0.1,
        help="Probability of making a short sentence as a training example")
    parser.add_argument(
        "--masked_lm_prob",
        type=float,
        default=0.15,
        help="Probability of masking each token for the LM task")
    parser.add_argument(
        "--max_predictions_per_seq",
        type=int,
        default=20,
        help="Maximum number of tokens to mask in each sequence")

    args = parser.parse_args()

    if args.num_workers > 1 and args.reduce_memory:
        raise ValueError("Cannot use multiple workers while reducing memory")

    if args.bert_model != "roberta-base":
        tokenizer = BertTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)
    else:
        tokenizer = RobertaTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)
    vocab_list = list(
        tokenizer.vocab.keys()) if args.bert_model != "roberta-base" else list(
            tokenizer.encoder.keys())
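    # BERT's WordPiece tokenizers expose their vocabulary as tokenizer.vocab, while RoBERTa's BPE tokenizer stores it in tokenizer.encoder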

    with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
        with args.train_corpus.open() as f:
            doc = []
            for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
                line = line.strip()

                line_obj = json.loads(line)
                tokens = tokenizer.tokenize(line_obj['masked_sentences'][0])
                line_obj['masked_sentences_tokenized'] = tokens
                docs.add_document(line_obj)

        if len(docs) <= 1:
            exit(
                "ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
                "ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
                "indicate breaks between documents in your input file. If your dataset does not contain multiple "
                "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
                "sections or paragraphs.")

        args.output_dir.mkdir(exist_ok=True)

        if args.num_workers > 1:
            writer_workers = Pool(
                min(args.num_workers, args.epochs_to_generate))
            arguments = [(docs, vocab_list, args, idx)
                         for idx in range(args.epochs_to_generate)]
            writer_workers.starmap(create_training_file, arguments)
        else:
            for epoch in trange(args.epochs_to_generate, desc="Epoch"):
                create_training_file(docs, vocab_list, args, epoch)
def test():
    from pprint import pprint
    roberta_ner_data_processor = TransformerNerDataProcessor()
    conll_2003 = Path(
        __file__).resolve().parent.parent.parent / "test_data/conll-2003"
    roberta_ner_data_processor.set_data_dir(conll_2003)
    labels, label2idx = roberta_ner_data_processor.get_labels(
        default='roberta')
    print(labels, label2idx)

    # train_examples = roberta_ner_data_processor.get_train_examples()
    train_examples = roberta_ner_data_processor.get_test_examples()
    pprint(train_examples[:5], indent=1)
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    # tokenizer = XLNetTokenizer.from_pretrained("xlnet-base_uncased")
    features = transformer_convert_data_to_features(train_examples[:5],
                                                    label2idx,
                                                    tokenizer,
                                                    max_seq_len=10)

    model = RobertaNerModel.from_pretrained("roberta-base",
                                            num_labels=len(label2idx))
    # model = XLNetNerModel.from_pretrained("xlnet-base_uncased", num_labels=len(label2idx))

    y_trues, y_preds = [], []
    y_pred, y_true = [], []
    prev_gd = 0
    for idx, each_batch in enumerate(
            ner_data_loader(features, batch_size=5, task='test', auto=True)):
        # [idx*batch_size: (idx+1)*batch_size]
        print([(fea.input_tokens, fea.guards)
               for fea in features[idx * 5:(idx + 1) * 5]])
        print(each_batch)

        original_tkid = each_batch[0].numpy()
        original_mask = each_batch[1].numpy()
        original_labels = each_batch[3].numpy()
        guards = each_batch[4].numpy()
        print(guards)

        inputs = batch_to_model_inputs(each_batch)

        with torch.no_grad():
            logits, flatted_logits, loss = model(**inputs)
            # get softmax output of the raw logits (keep dimensions)
            raw_logits = torch.argmax(torch.nn.functional.log_softmax(logits,
                                                                      dim=2),
                                      dim=2)
            raw_logits = raw_logits.detach().cpu().numpy()

        logits = logits.numpy()
        loss = loss.numpy()

        print(logits.shape)
        # print(loss)

        # tk=token, mk=mask, lb=label, lgt=logits
        for mks, lbs, lgts, gds in zip(original_mask, original_labels,
                                       raw_logits, guards):
            connect_sent_flag = False
            for mk, lb, lgt, gd in zip(mks, lbs, lgts, gds):
                if mk == 0:  # after hit first mask, we can stop for the current sentence since all rest will be pad
                    break
                if gd == 0 or prev_gd == gd:
                    continue
                if gd == -2:
                    connect_sent_flag = True
                    break
                if prev_gd != gd:
                    y_true.append(lb)
                    y_pred.append(lgt)
                    prev_gd = gd
            if connect_sent_flag:
                continue
            y_trues.append(y_true)
            y_preds.append(y_pred)
            y_pred, y_true = [], []
            prev_gd = 0
        print(y_trues)
        print(y_preds)
Example No. 20
    def get_tokenizer(self, **kwargs):
        kwargs.update(self.special_tokens_map)
        return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
Example No. 21
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, n_gpu)
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()

    num_labels = len(["entailment", "neutral", "contradiction"])
    # pretrain_model_dir = 'roberta-large' #'roberta-large' , 'roberta-large-mnli'
    pretrain_model_dir = '/export/home/Dataset/BERT_pretrained_mine/TrainedModelReminder/RoBERTa_on_MNLI_SNLI_SciTail_RTE_ANLI_SpecialToken_epoch_2_acc_4.156359461121103'  #'roberta-large' , 'roberta-large-mnli'
    model = RobertaForSequenceClassification.from_pretrained(
        pretrain_model_dir, num_labels=num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # MNLI-SNLI-SciTail-RTE-ANLI
    train_examples_MNLI, dev_examples_MNLI = processor.get_MNLI_train_and_dev(
        '/export/home/Dataset/glue_data/MNLI/train.tsv',
        '/export/home/Dataset/glue_data/MNLI/dev_mismatched.tsv'
    )  #train_pu_half_v1.txt
    train_examples_SNLI, dev_examples_SNLI = processor.get_SNLI_train_and_dev(
        '/export/home/Dataset/glue_data/SNLI/train.tsv',
        '/export/home/Dataset/glue_data/SNLI/dev.tsv')
    train_examples_SciTail, dev_examples_SciTail = processor.get_SciTail_train_and_dev(
        '/export/home/Dataset/SciTailV1/tsv_format/scitail_1.0_train.tsv',
        '/export/home/Dataset/SciTailV1/tsv_format/scitail_1.0_dev.tsv')
    train_examples_RTE, dev_examples_RTE = processor.get_RTE_train_and_dev(
        '/export/home/Dataset/glue_data/RTE/train.tsv',
        '/export/home/Dataset/glue_data/RTE/dev.tsv')
    train_examples_ANLI, dev_examples_ANLI = processor.get_ANLI_train_and_dev(
        'train', 'dev',
        '/export/home/Dataset/para_entail_datasets/ANLI/anli_v0.1/')

    train_examples = train_examples_MNLI + train_examples_SNLI + train_examples_SciTail + train_examples_RTE + train_examples_ANLI
    dev_examples_list = [
        dev_examples_MNLI, dev_examples_SNLI, dev_examples_SciTail,
        dev_examples_RTE, dev_examples_ANLI
    ]

    dev_task_label = [0, 0, 1, 1, 0]
    task_names = ['MNLI', 'SNLI', 'SciTail', 'RTE', 'ANLI']
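    # dev_task_label marks each dev set as 3-way (0: MNLI, SNLI, ANLI) or binary (1: SciTail, RTE); binary sets have their 3-way predictions collapsed at evaluation time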
    '''filter challenging neighbors'''
    neighbor_id_list = []
    readfile = codecs.open('neighbors_indices_before_dropout_eud.v3.txt', 'r',
                           'utf-8')
    for line in readfile:
        neighbor_id_list.append(int(line.strip()))
    readfile.close()
    print('neighbor_id_list size:', len(neighbor_id_list))
    truncated_train_examples = [train_examples[i] for i in neighbor_id_list]
    train_examples = truncated_train_examples

    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    train_features = convert_examples_to_features(
        train_examples,
        label_list,
        args.max_seq_length,
        tokenizer,
        output_mode,
        cls_token_at_end=
        False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
        cls_token=tokenizer.cls_token,
        cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
        sep_token=tokenizer.sep_token,
        sep_token_extra=
        True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
        pad_on_left=
        False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
        pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
        pad_token_segment_id=0)  #4 if args.model_type in ['xlnet'] else 0,)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)
    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.long)
    all_task_label_ids = torch.tensor([f.task_label for f in train_features],
                                      dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_label_ids, all_task_label_ids)
    train_sampler = RandomSampler(train_data)

    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  drop_last=True)
    '''dev data to features'''
    valid_dataloader_list = []
    for valid_examples_i in dev_examples_list:
        valid_features = convert_examples_to_features(
            valid_examples_i,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        logger.info("***** valid_examples *****")
        logger.info("  Num examples = %d", len(valid_examples_i))
        valid_input_ids = torch.tensor([f.input_ids for f in valid_features],
                                       dtype=torch.long)
        valid_input_mask = torch.tensor([f.input_mask for f in valid_features],
                                        dtype=torch.long)
        valid_segment_ids = torch.tensor(
            [f.segment_ids for f in valid_features], dtype=torch.long)
        valid_label_ids = torch.tensor([f.label_id for f in valid_features],
                                       dtype=torch.long)
        valid_task_label_ids = torch.tensor(
            [f.task_label for f in valid_features], dtype=torch.long)

        valid_data = TensorDataset(valid_input_ids, valid_input_mask,
                                   valid_segment_ids, valid_label_ids,
                                   valid_task_label_ids)
        valid_sampler = SequentialSampler(valid_data)
        valid_dataloader = DataLoader(valid_data,
                                      sampler=valid_sampler,
                                      batch_size=args.eval_batch_size)
        valid_dataloader_list.append(valid_dataloader)

    iter_co = 0
    for epoch_i in trange(int(args.num_train_epochs), desc="Epoch"):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, task_label_ids = batch
            logits = model(input_ids, input_mask, None, labels=None)

            prob_matrix = F.log_softmax(logits[0].view(-1, num_labels), dim=1)
            '''multiplying by 1.0 creates a new tensor, so the in-place assignment below does not modify the log_softmax output that autograd still needs'''
            new_prob_matrix = prob_matrix * 1.0
            '''change the entail prob to p or 1-p'''
            changed_places = torch.nonzero(task_label_ids, as_tuple=False)
            new_prob_matrix[changed_places,
                            0] = 1.0 - prob_matrix[changed_places, 0]
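            # prob_matrix holds log-softmax scores; for examples from binary tasks (nonzero task_label) the entailment column is replaced by 1.0 minus its value before the NLL loss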

            loss = F.nll_loss(new_prob_matrix, label_ids.view(-1))

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            iter_co += 1

            # if iter_co % len(train_dataloader) ==0:
            if iter_co % (len(train_dataloader) // 5) == 0:
                '''
                evaluate on the dev sets (run every 1/5 of an epoch)
                '''
                # if n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
                #     model = torch.nn.DataParallel(model)
                model.eval()
                for m in model.modules():
                    if isinstance(m, torch.nn.BatchNorm2d):
                        m.track_running_stats = False
                # logger.info("***** Running evaluation *****")
                # logger.info("  Num examples = %d", len(valid_examples_MNLI))
                # logger.info("  Batch size = %d", args.eval_batch_size)

                dev_acc_sum = 0.0
                for idd, valid_dataloader in enumerate(valid_dataloader_list):
                    task_label = dev_task_label[idd]
                    eval_loss = 0
                    nb_eval_steps = 0
                    preds = []
                    gold_label_ids = []
                    # print('Evaluating...', task_label)
                    # for _, batch in enumerate(tqdm(valid_dataloader, desc=task_names[idd])):
                    for _, batch in enumerate(valid_dataloader):
                        batch = tuple(t.to(device) for t in batch)
                        input_ids, input_mask, segment_ids, label_ids, task_label_ids = batch
                        if task_label == 0:
                            gold_label_ids += list(
                                label_ids.detach().cpu().numpy())
                        else:
                            '''SciTail, RTE'''
                            task_label_ids_list = list(
                                task_label_ids.detach().cpu().numpy())
                            gold_label_batch_fake = list(
                                label_ids.detach().cpu().numpy())
                            for ex_id, label_id in enumerate(
                                    gold_label_batch_fake):
                                if task_label_ids_list[ex_id] == 0:
                                    gold_label_ids.append(label_id)  #0
                                else:
                                    gold_label_ids.append(1)  #1
                        with torch.no_grad():
                            logits = model(input_ids=input_ids,
                                           attention_mask=input_mask,
                                           token_type_ids=None,
                                           labels=None)
                        logits = logits[0]
                        if len(preds) == 0:
                            preds.append(logits.detach().cpu().numpy())
                        else:
                            preds[0] = np.append(preds[0],
                                                 logits.detach().cpu().numpy(),
                                                 axis=0)

                    preds = preds[0]
                    pred_probs = softmax(preds, axis=1)
                    pred_label_ids_3way = np.argmax(pred_probs, axis=1)
                    if task_label == 0:
                        '''3-way tasks MNLI, SNLI, ANLI'''
                        pred_label_ids = pred_label_ids_3way
                    else:
                        '''SciTail, RTE'''
                        pred_label_ids = []
                        for pred_label_i in pred_label_ids_3way:
                            if pred_label_i == 0:
                                pred_label_ids.append(0)
                            else:
                                pred_label_ids.append(1)
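                    # (added) any non-entailment prediction (neutral or contradiction)
                    # is collapsed to label 1 for the 2-way SciTail/RTE evaluation.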
                    assert len(pred_label_ids) == len(gold_label_ids)
                    hit_co = 0
                    for k in range(len(pred_label_ids)):
                        if pred_label_ids[k] == gold_label_ids[k]:
                            hit_co += 1
                    test_acc = hit_co / len(gold_label_ids)
                    dev_acc_sum += test_acc
                    print(task_names[idd], ' dev acc:', test_acc)
                '''store the model after each evaluation so it can be tested later once the best dev accuracy is reached'''
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                store_transformers_models(
                    model_to_save, tokenizer,
                    '/export/home/Dataset/BERT_pretrained_mine/TrainedModelReminder/',
                    'RoBERTa_on_MNLI_SNLI_SciTail_RTE_ANLI_SpecialToken_Filter_1_epoch_'
                    + str(epoch_i) + '_acc_' + str(dev_acc_sum))
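
# --- Added illustration (not part of the original script) ----------------------
# A minimal, self-contained sketch of the 3-way -> 2-way trick used in the training
# loop above: for examples from binary tasks (task_label_ids != 0) the "entailment"
# column is flipped from p to 1 - p before the NLL loss, so "non-entailment"
# absorbs both the neutral and the contradiction class.
import torch
import torch.nn.functional as F

fake_logits = torch.randn(4, 3, requires_grad=True)  # 4 examples, 3-way logits
task_label_ids = torch.tensor([0, 1, 1, 0])          # 1 marks a binary-task example
label_ids = torch.tensor([0, 1, 0, 2])

prob_matrix = F.log_softmax(fake_logits, dim=1)
new_prob_matrix = prob_matrix * 1.0                  # copy; keeps prob_matrix untouched for backward
changed_places = torch.nonzero(task_label_ids, as_tuple=False)
new_prob_matrix[changed_places, 0] = 1.0 - prob_matrix[changed_places, 0]
loss = F.nll_loss(new_prob_matrix, label_ids)
loss.backward()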
Exemplo n.º 22
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="Number of training examples per class (k-shot setting).")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bit training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    train_examples, dev_examples, test_examples = load_FewRel_data(args.kshot)

    label_list = ["entailment", "non_entailment"]
    num_labels = len(label_list)
    print('num_labels:', num_labels, 'training size:', len(train_examples),
          'dev size:', len(dev_examples), 'test size:', len(test_examples))

    num_train_optimization_steps = None
    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    model = RobertaForSequenceClassification(num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
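    # Note (added): this grouping follows the common transformers recipe --
    # biases and LayerNorm parameters are exempt from weight decay, all other
    # parameters use weight_decay=0.01.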

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        train_dataloader = examples_to_features(train_examples,
                                                label_list,
                                                args,
                                                tokenizer,
                                                args.train_batch_size,
                                                "classification",
                                                dataloader_mode='random')
        dev_dataloader = examples_to_features(dev_examples,
                                              label_list,
                                              args,
                                              tokenizer,
                                              args.eval_batch_size,
                                              "classification",
                                              dataloader_mode='sequential')
        test_dataloader = examples_to_features(test_examples,
                                               label_list,
                                               args,
                                               tokenizer,
                                               args.eval_batch_size,
                                               "classification",
                                               dataloader_mode='sequential')

        iter_co = 0
        final_test_performance = 0.0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                _, input_ids, input_mask, segment_ids, label_ids = batch

                logits = model(input_ids, input_mask)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(logits.view(-1, num_labels),
                                label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                iter_co += 1
                # if iter_co %20==0:
                if iter_co % len(train_dataloader) == 0:
                    '''
                    evaluate on the dev set at the end of each epoch
                    '''
                    model.eval()
                    dev_acc = evaluation(model,
                                         dev_dataloader,
                                         device,
                                         flag='Dev')
                    if dev_acc > max_dev_acc:
                        max_dev_acc = dev_acc
                        print('\n\t dev acc:', dev_acc, ' max_dev_acc:',
                              max_dev_acc, '\n')
                        test_acc = evaluation(model,
                                              test_dataloader,
                                              device,
                                              flag='Test')
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc

                        final_test_performance = test_acc
                        print('\n\t test acc:', test_acc, ' max_test_acc:',
                              max_test_acc, '\n')
                    else:
                        print('\n\t dev acc:', dev_acc, ' max_dev_acc:',
                              max_dev_acc, '\n')

        print('final_test_performance:', final_test_performance)
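
# --- Added illustration (not part of the original script) ----------------------
# A minimal sketch of what the `evaluation` helper used above is assumed to do
# (the real implementation is defined elsewhere in the repository): plain argmax
# accuracy over a dataloader whose batches match the 5-tuple unpacked in the
# training loop. `flag` is kept only to mirror the call signature.
import torch


def evaluation_sketch(model, dataloader, device, flag='Dev'):
    model.eval()
    hit_co, total = 0, 0
    with torch.no_grad():
        for _, input_ids, input_mask, _, label_ids in dataloader:
            logits = model(input_ids.to(device), input_mask.to(device))
            preds = logits.argmax(dim=1).detach().cpu()
            hit_co += int((preds == label_ids).sum())
            total += label_ids.size(0)
    return hit_co / total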
Exemplo n.º 23
def main():
    parser = argparse.ArgumentParser()

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="Number of training examples per class (k-shot setting).")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--target_train_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bit training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    scitail_path = '/export/home/Dataset/SciTailV1/tsv_format/'
    target_kshot_entail_examples, target_kshot_nonentail_examples = get_SciTail_as_train_k_shot(
        scitail_path + 'scitail_1.0_train.tsv', args.kshot,
        args.seed)  #train_pu_half_v1.txt
    target_dev_examples, target_test_examples = get_SciTail_dev_and_test(
        scitail_path + 'scitail_1.0_dev.tsv',
        scitail_path + 'scitail_1.0_test.tsv')

    system_seed = 42
    random.seed(system_seed)
    np.random.seed(system_seed)
    torch.manual_seed(system_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(system_seed)

    source_kshot_size = 10  # fixed number of MNLI support examples per class, independent of args.kshot
    source_kshot_entail, source_kshot_neural, source_kshot_contra, source_remaining_examples = get_MNLI_train(
        '/export/home/Dataset/glue_data/MNLI/train.tsv', source_kshot_size)
    source_examples = source_kshot_entail + source_kshot_neural + source_kshot_contra + source_remaining_examples
    target_label_list = ["entails", "neutral"]
    source_label_list = ["entailment", "neutral", "contradiction"]
    source_num_labels = len(source_label_list)
    target_num_labels = len(target_label_list)
    print('training size:', len(source_examples), 'dev size:',
          len(target_dev_examples), 'test size:', len(target_test_examples))

    roberta_model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    roberta_model.load_state_dict(torch.load(
        '/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'
    ),
                                  strict=False)
    roberta_model.to(device)
    roberta_model.eval()

    protonet = PrototypeNet(bert_hidden_dim)
    protonet.to(device)

    param_optimizer = list(protonet.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    retrieve_batch_size = 5

    source_kshot_entail_dataloader = examples_to_features(
        source_kshot_entail,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_neural_dataloader = examples_to_features(
        source_kshot_neural,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_contra_dataloader = examples_to_features(
        source_kshot_contra,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_remain_ex_dataloader = examples_to_features(
        source_remaining_examples,
        source_label_list,
        args,
        tokenizer,
        args.train_batch_size,
        "classification",
        dataloader_mode='random')

    target_kshot_entail_dataloader = examples_to_features(
        target_kshot_entail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_kshot_nonentail_dataloader = examples_to_features(
        target_kshot_nonentail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_dev_dataloader = examples_to_features(target_dev_examples,
                                                 target_label_list,
                                                 args,
                                                 tokenizer,
                                                 args.eval_batch_size,
                                                 "classification",
                                                 dataloader_mode='sequential')
    target_test_dataloader = examples_to_features(target_test_examples,
                                                  target_label_list,
                                                  args,
                                                  tokenizer,
                                                  args.eval_batch_size,
                                                  "classification",
                                                  dataloader_mode='sequential')
    '''
    retrieve representations for the k-shot support examples in MNLI
    '''
    kshot_entail_reps = []
    for entail_batch in source_kshot_entail_dataloader:
        entail_batch = tuple(t.to(device) for t in entail_batch)
        input_ids, input_mask, segment_ids, label_ids = entail_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_entail, _ = roberta_model(input_ids, input_mask)
        kshot_entail_reps.append(last_hidden_entail)
    kshot_entail_rep = torch.mean(torch.cat(kshot_entail_reps, dim=0),
                                  dim=0,
                                  keepdim=True)
    kshot_neural_reps = []
    for neural_batch in source_kshot_neural_dataloader:
        neural_batch = tuple(t.to(device) for t in neural_batch)
        input_ids, input_mask, segment_ids, label_ids = neural_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_neural, _ = roberta_model(input_ids, input_mask)
        kshot_neural_reps.append(last_hidden_neural)
    kshot_neural_rep = torch.mean(torch.cat(kshot_neural_reps, dim=0),
                                  dim=0,
                                  keepdim=True)
    kshot_contra_reps = []
    for contra_batch in source_kshot_contra_dataloader:
        contra_batch = tuple(t.to(device) for t in contra_batch)
        input_ids, input_mask, segment_ids, label_ids = contra_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_contra, _ = roberta_model(input_ids, input_mask)
        kshot_contra_reps.append(last_hidden_contra)
    kshot_contra_rep = torch.mean(torch.cat(kshot_contra_reps, dim=0),
                                  dim=0,
                                  keepdim=True)

    source_class_prototype_reps = torch.cat(
        [kshot_entail_rep, kshot_neural_rep, kshot_contra_rep],
        dim=0)  #(3, hidden)
    '''get representations for the support examples in the target domain'''
    kshot_entail_reps = []
    for entail_batch in target_kshot_entail_dataloader:
        entail_batch = tuple(t.to(device) for t in entail_batch)
        input_ids, input_mask, segment_ids, label_ids = entail_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_entail, _ = roberta_model(input_ids, input_mask)
        kshot_entail_reps.append(last_hidden_entail)
    all_kshot_entail_reps = torch.cat(kshot_entail_reps, dim=0)
    kshot_entail_rep = torch.mean(all_kshot_entail_reps, dim=0, keepdim=True)
    kshot_nonentail_reps = []
    for nonentail_batch in target_kshot_nonentail_dataloader:
        nonentail_batch = tuple(t.to(device) for t in nonentail_batch)
        input_ids, input_mask, segment_ids, label_ids = nonentail_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_nonentail, _ = roberta_model(input_ids, input_mask)
        kshot_nonentail_reps.append(last_hidden_nonentail)
    all_kshot_neural_reps = torch.cat(kshot_nonentail_reps, dim=0)
    kshot_nonentail_rep = torch.mean(all_kshot_neural_reps,
                                     dim=0,
                                     keepdim=True)
    target_class_prototype_reps = torch.cat(
        [kshot_entail_rep, kshot_nonentail_rep, kshot_nonentail_rep],
        dim=0)  #(3, hidden)

    class_prototype_reps = torch.cat(
        [source_class_prototype_reps, target_class_prototype_reps],
        dim=0)  #(6, hidden)
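    # (added) the six prototypes are ordered [MNLI entailment, neutral, contradiction,
    # target entailment, target non-entailment, target non-entailment]; duplicating the
    # non-entailment prototype lets the same 3-way scorer serve the 2-way target task.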
    '''starting to train'''
    iter_co = 0
    tr_loss = 0
    source_loss = 0
    target_loss = 0
    final_test_performance = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):

        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(
                tqdm(source_remain_ex_dataloader, desc="Iteration")):
            protonet.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, source_label_ids_batch = batch

            roberta_model.eval()
            with torch.no_grad():
                source_last_hidden_batch, _ = roberta_model(
                    input_ids, input_mask)
            '''forward to model'''
            target_batch_size = args.target_train_batch_size  #10*3
            target_batch_size_entail = target_batch_size  #random.randrange(5)+1
            target_batch_size_neural = target_batch_size  #random.randrange(5)+1

            selected_target_entail_rep = all_kshot_entail_reps[torch.randperm(
                all_kshot_entail_reps.shape[0])[:target_batch_size_entail]]
            selected_target_neural_rep = all_kshot_neural_reps[torch.randperm(
                all_kshot_neural_reps.shape[0])[:target_batch_size_neural]]
            target_last_hidden_batch = torch.cat(
                [selected_target_entail_rep, selected_target_neural_rep])

            last_hidden_batch = torch.cat(
                [source_last_hidden_batch, target_last_hidden_batch],
                dim=0)  #(train_batch_size+10*2)
            batch_logits = protonet(class_prototype_reps, last_hidden_batch)
            '''source side loss'''
            # loss_fct = CrossEntropyLoss(reduction='none')
            loss_fct = CrossEntropyLoss()
            source_loss_list = loss_fct(
                batch_logits[:source_last_hidden_batch.shape[0]].view(
                    -1, source_num_labels), source_label_ids_batch.view(-1))
            '''target side loss'''
            target_label_ids_batch = torch.tensor(
                [0] * selected_target_entail_rep.shape[0] +
                [1] * selected_target_neural_rep.shape[0],
                dtype=torch.long)
            target_batch_logits = batch_logits[-target_last_hidden_batch.
                                               shape[0]:]
            target_loss_list = loss_by_logits_and_2way_labels(
                target_batch_logits, target_label_ids_batch.view(-1), device)
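            # (added) target_batch_logits are the rows corresponding to the target support
            # reps appended at the end of last_hidden_batch; their 2-way gold labels are
            # 0 (entailment) and 1 (non-entailment).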

            loss = source_loss_list + target_loss_list  #torch.mean(torch.cat([source_loss_list, target_loss_list]))
            source_loss += source_loss_list
            target_loss += target_loss_list
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            global_step += 1
            iter_co += 1
            '''print loss'''
            # if iter_co %5==0:
            #     print('iter_co:', iter_co, ' mean loss', tr_loss/iter_co)
            #     print('source_loss_list:', source_loss/iter_co, ' target_loss_list: ', target_loss/iter_co)
            if iter_co % 1 == 0:
                # if iter_co % len(source_remain_ex_dataloader)==0:
                '''
                evaluate on the dev and test sets (with `iter_co % 1`, this runs after every training step)
                '''
                protonet.eval()

                for idd, dev_or_test_dataloader in enumerate(
                    [target_dev_dataloader, target_test_dataloader]):

                    eval_loss = 0
                    nb_eval_steps = 0
                    preds = []
                    gold_label_ids = []
                    # print('Evaluating...')
                    for input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        label_ids = label_ids.to(device)
                        gold_label_ids += list(
                            label_ids.detach().cpu().numpy())
                        roberta_model.eval()
                        with torch.no_grad():
                            last_hidden_target_batch, logits_from_source = roberta_model(
                                input_ids, input_mask)

                        with torch.no_grad():
                            logits = protonet(class_prototype_reps,
                                              last_hidden_target_batch)
                        '''combine with logits from source domain'''
                        # print('logits:', logits)
                        # print('logits_from_source:', logits_from_source)
                        # weight = 0.9
                        # logits = weight*logits+(1.0-weight)*torch.sigmoid(logits_from_source)
                        if len(preds) == 0:
                            preds.append(logits.detach().cpu().numpy())
                        else:
                            preds[0] = np.append(preds[0],
                                                 logits.detach().cpu().numpy(),
                                                 axis=0)

                    preds = preds[0]

                    pred_probs = softmax(preds, axis=1)
                    pred_label_ids_3way = list(np.argmax(pred_probs, axis=1))
                    '''change from 3-way to 2-way'''
                    pred_label_ids = []
                    for pred_id in pred_label_ids_3way:
                        if pred_id != 0:
                            pred_label_ids.append(1)
                        else:
                            pred_label_ids.append(0)

                    assert len(pred_label_ids) == len(gold_label_ids)
                    hit_co = 0
                    for k in range(len(pred_label_ids)):
                        if pred_label_ids[k] == gold_label_ids[k]:
                            hit_co += 1
                    test_acc = hit_co / len(gold_label_ids)

                    if idd == 0:  # this is dev
                        if test_acc > max_dev_acc:
                            max_dev_acc = test_acc
                            print('\niter', iter_co, '\tdev acc:', test_acc,
                                  ' max_dev_acc:', max_dev_acc, '\n')

                        else:
                            print('\niter', iter_co, '\tdev acc:', test_acc,
                                  ' max_dev_acc:', max_dev_acc, '\n')
                            break
                    else:  # this is test
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc

                        final_test_performance = test_acc
                        print('\niter', iter_co, '\ttest acc:', test_acc,
                              ' max_test_acc:', max_test_acc, '\n')
            # if iter_co == 500:#3000:
            #     break
    print('final_test_performance:', final_test_performance)
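
# --- Added illustration (not part of the original script) ----------------------
# The `PrototypeNet` used above is defined elsewhere in the repository; this is a
# minimal sketch of one plausible scorer, assuming simple dot-product similarity
# between each query representation and the class prototypes.
import torch
import torch.nn as nn


class SimplePrototypeScorer(nn.Module):
    """Score each query representation against every class prototype."""

    def forward(self, class_prototype_reps, query_reps):
        # class_prototype_reps: (num_classes, hidden); query_reps: (batch, hidden)
        # returns (batch, num_classes) similarity logits
        return query_reps @ class_prototype_reps.t()


# Example usage with random tensors (hidden size is arbitrary here):
# scorer = SimplePrototypeScorer()
# logits = scorer(torch.randn(6, 1024), torch.randn(8, 1024))  # -> shape (8, 6)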
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument('--kshot',
                        type=float,
                        default=5,
                        help="Portion of the GAP training set to use (dev and test always use 1.0).")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bit training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    train_examples = processor.get_GAP_coreference(
        'gap-development.tsv', args.kshot)  #train_pu_half_v1.txt
    dev_examples = processor.get_GAP_coreference('gap-validation.tsv', 1.0)
    test_examples = processor.get_GAP_coreference('gap-test.tsv', 1.0)
    label_list = ["entailment", "not_entailment"]
    entity_label_list = ["A-coref", "B-coref"]
    # train_examples = get_data_hulu_fewshot('train', 5)
    # train_examples, dev_examples, test_examples, label_list = load_CLINC150_with_specific_domain_sequence(args.DomainName, args.kshot, augment=False)
    num_labels = len(label_list)
    print('num_labels:', num_labels, 'training size:', len(train_examples),
          'dev size:', len(dev_examples), 'test size:', len(test_examples))

    num_train_optimization_steps = None
    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    model = RobertaForSequenceClassification(num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    max_dev_threshold = 0.0
    if args.do_train:
        train_dataloader = examples_to_features(train_examples,
                                                label_list,
                                                entity_label_list,
                                                args,
                                                tokenizer,
                                                args.train_batch_size,
                                                "classification",
                                                dataloader_mode='random')
        dev_dataloader = examples_to_features(dev_examples,
                                              label_list,
                                              entity_label_list,
                                              args,
                                              tokenizer,
                                              args.eval_batch_size,
                                              "classification",
                                              dataloader_mode='sequential')
        test_dataloader = examples_to_features(test_examples,
                                               label_list,
                                               entity_label_list,
                                               args,
                                               tokenizer,
                                               args.eval_batch_size,
                                               "classification",
                                               dataloader_mode='sequential')

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)

        iter_co = 0
        final_test_performance = 0.0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_example_ids, input_ids, input_mask, span_a_mask, span_b_mask, segment_ids, label_ids, entity_label_ids = batch

                logits = model(input_ids, input_mask, span_a_mask, span_b_mask)
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, num_labels),
                                label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                iter_co += 1
                # if iter_co %100==0:
                #     print('iter_co:', iter_co, ' mean loss:', tr_loss/iter_co)
                if iter_co % len(train_dataloader) == 0:

                    model.eval()
                    '''
                    evaluate on the dev set after this epoch
                    '''

                    logger.info("***** Running dev *****")
                    logger.info("  Num examples = %d", len(dev_examples))

                    eval_loss = 0
                    nb_eval_steps = 0
                    preds = []
                    gold_label_ids = []
                    example_id_list = []
                    for _, batch in enumerate(tqdm(dev_dataloader,
                                                   desc="dev")):
                        input_indices, input_ids, input_mask, span_a_mask, span_b_mask, segment_ids, _, label_ids = batch
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        span_a_mask = span_a_mask.to(device)
                        span_b_mask = span_b_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        label_ids = label_ids.to(device)
                        example_ids = list(input_indices.numpy())
                        example_id_list += example_ids
                        gold_label_ids += list(
                            label_ids.detach().cpu().numpy())

                        with torch.no_grad():
                            logits = model(input_ids, input_mask, span_a_mask,
                                           span_b_mask)
                        if len(preds) == 0:
                            preds.append(logits.detach().cpu().numpy())
                        else:
                            preds[0] = np.append(preds[0],
                                                 logits.detach().cpu().numpy(),
                                                 axis=0)

                    preds = preds[0]

                    pred_probs = softmax(preds, axis=1)
                    pred_label_ids_3way = list(np.argmax(pred_probs, axis=1))
                    pred_prob_entail = list(pred_probs[:, 0])

                    assert len(example_id_list) == len(pred_prob_entail)
                    assert len(example_id_list) == len(gold_label_ids)
                    assert len(example_id_list) == len(pred_label_ids_3way)

                    best_current_dev_acc = 0.0
                    best_current_threshold = -10.0
                    for threshold in np.arange(0.99, 0.0, -0.01):
                        # print('example_id_list:', example_id_list)
                        eval_output_list = build_GAP_output_format(
                            example_id_list,
                            gold_label_ids,
                            pred_prob_entail,
                            pred_label_ids_3way,
                            threshold,
                            dev_or_test='validation')
                        dev_acc = run_scorer(
                            '/export/home/Dataset/gap_coreference/gap-validation.tsv',
                            eval_output_list)
                        # print('dev_acc:', dev_acc)
                        # exit(0)
                        if dev_acc > best_current_dev_acc:
                            best_current_dev_acc = dev_acc
                            best_current_threshold = threshold
                    print('best_current_dev_threshold:',
                          best_current_threshold, 'best_current_dev_acc:',
                          best_current_dev_acc)

                    if best_current_dev_acc > max_dev_acc:
                        max_dev_acc = best_current_dev_acc
                        max_dev_threshold = best_current_threshold
                        '''eval on test set'''
                        logger.info("***** Running test *****")
                        logger.info("  Num examples = %d", len(test_examples))

                        eval_loss = 0
                        nb_eval_steps = 0
                        preds = []
                        gold_label_ids = []
                        example_id_list = []
                        for _, batch in enumerate(
                                tqdm(test_dataloader, desc="test")):
                            input_indices, input_ids, input_mask, span_a_mask, span_b_mask, segment_ids, _, label_ids = batch
                            input_ids = input_ids.to(device)
                            input_mask = input_mask.to(device)
                            span_a_mask = span_a_mask.to(device)
                            span_b_mask = span_b_mask.to(device)
                            segment_ids = segment_ids.to(device)
                            label_ids = label_ids.to(device)
                            example_ids = list(input_indices.numpy())
                            example_id_list += example_ids
                            gold_label_ids += list(
                                label_ids.detach().cpu().numpy())

                            with torch.no_grad():
                                logits = model(input_ids, input_mask,
                                               span_a_mask, span_b_mask)
                            if len(preds) == 0:
                                preds.append(logits.detach().cpu().numpy())
                            else:
                                preds[0] = np.append(
                                    preds[0],
                                    logits.detach().cpu().numpy(),
                                    axis=0)

                        preds = preds[0]

                        pred_probs = softmax(preds, axis=1)
                        pred_label_ids_3way = list(
                            np.argmax(pred_probs, axis=1))
                        pred_prob_entail = list(pred_probs[:, 0])

                        assert len(example_id_list) == len(pred_prob_entail)
                        assert len(example_id_list) == len(gold_label_ids)
                        assert len(example_id_list) == len(pred_label_ids_3way)

                        threshold = max_dev_threshold
                        eval_output_list = build_GAP_output_format(
                            example_id_list,
                            gold_label_ids,
                            pred_prob_entail,
                            pred_label_ids_3way,
                            threshold,
                            dev_or_test='test')

                        test_acc = run_scorer(
                            '/export/home/Dataset/gap_coreference/gap-test.tsv',
                            eval_output_list)
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc
                        print('current_test_acc:', test_acc, ' max_test_acc:',
                              max_test_acc)
                        final_test_performance = test_acc
        print('final_test_performance:', final_test_performance)
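
# --- Added illustration (not part of the original script) ----------------------
# A self-contained sketch of the dev-set threshold sweep above. The real pipeline
# routes predictions through build_GAP_output_format / run_scorer (defined elsewhere);
# this hypothetical helper only shows the sweep logic, assuming column 0 of the
# softmax output is the "entailment" probability, as in pred_prob_entail.
import numpy as np


def pick_best_threshold(entail_probs, gold_label_ids, thresholds=None):
    """Return (threshold, accuracy) maximizing 'predict entailment iff prob >= threshold'."""
    if thresholds is None:
        thresholds = np.arange(0.99, 0.0, -0.01)
    best_acc, best_threshold = 0.0, -10.0
    for threshold in thresholds:
        pred_label_ids = [0 if p >= threshold else 1 for p in entail_probs]
        hit_co = sum(int(p == g) for p, g in zip(pred_label_ids, gold_label_ids))
        acc = hit_co / len(gold_label_ids)
        if acc > best_acc:
            best_acc, best_threshold = acc, threshold
    return best_threshold, best_acc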
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--DomainName",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_data_aug",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--meta_epochs',
                        type=int,
                        default=10,
                        help="Number of meta-training epochs.")
    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="Number of training examples per class (k-shot setting).")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()


    processors = {
        "rte": RteProcessor
    }

    output_modes = {
        "rte": "classification"
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")


    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    # label_list = processor.get_labels() #["entailment", "neutral", "contradiction"]
    # label_list = ['How_do_I_create_a_profile_v4', 'Profile_Switch_v4', 'Deactivate_Active_Devices_v4', 'Ads_on_Hulu_v4', 'Watching_Hulu_with_Live_TV_v4', 'Hulu_Costs_and_Commitments_v4', 'offline_downloads_v4', 'womens_world_cup_v5', 'forgot_username_v4', 'confirm_account_cancellation_v4', 'Devices_to_Watch_HBO_on_v4', 'remove_add_on_v4', 'Internet_Speed_for_HD_and_4K_v4', 'roku_related_questions_v4', 'amazon_related_questions_v4', 'Clear_Browser_Cache_v4', 'ads_on_ad_free_plan_v4', 'inappropriate_ads_v4', 'itunes_related_questions_v4', 'Internet_Speed_Recommendations_v4', 'NBA_Basketball_v5', 'unexpected_charges_v4', 'change_billing_date_v4', 'NFL_on_Hulu_v5', 'How_to_delete_a_profile_v4', 'Devices_to_Watch_Hulu_on_v4', 'Manage_your_Hulu_subscription_v4', 'cancel_hulu_account_v4', 'disney_bundle_v4', 'payment_issues_v4', 'home_network_location_v4', 'Main_Menu_v4', 'Resetting_Hulu_Password_v4', 'Update_Payment_v4', 'I_need_general_troubleshooting_help_v4', 'What_is_Hulu_v4', 'sprint_related_questions_v4', 'Log_into_TV_with_activation_code_v4', 'Game_of_Thrones_v4', 'video_playback_issues_v4', 'How_to_edit_a_profile_v4', 'Watchlist_Remove_Video_v4', 'spotify_related_questions_v4', 'Deactivate_Login_Sessions_v4', 'Transfer_to_Agent_v4', 'Use_Hulu_Internationally_v4']

    meta_train_examples, meta_dev_examples, meta_test_examples, meta_label_list = load_CLINC150_without_specific_domain(args.DomainName)
    train_examples, dev_examples, eval_examples, finetune_label_list = load_CLINC150_with_specific_domain_sequence(args.DomainName, args.kshot, augment=args.do_data_aug)
    # oos_dev_examples, oos_test_examples = load_OOS()
    # dev_examples+=oos_dev_examples
    # eval_examples+=oos_test_examples

    eval_label_list = finetune_label_list#+['oos']
    label_list=finetune_label_list+meta_label_list#+['oos']
    assert len(label_list) ==  15*10
    num_labels = len(label_list)
    assert num_labels == 15*10
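    # CLINC150 groups its 150 intents into 10 domains of 15 intents each, so the
    # held-out domain contributes 15 fine-tuning labels and the remaining 9 domains
    # contribute 135 meta-training labels (15 * 10 = 150 in total).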


    model = RobertaForSequenceClassification(num_labels)


    tokenizer = RobertaTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)
    # tokenizer = BertTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)

    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    optimizer = AdamW(optimizer_grouped_parameters,
                             lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        meta_train_features = convert_examples_to_features(
            meta_train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)


        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)

        '''load dev set'''
        # dev_examples = processor.get_RTE_as_dev('/export/home/Dataset/glue_data/RTE/dev.tsv')
        # dev_examples = get_data_hulu('dev')
        dev_features = convert_examples_to_features(
            dev_examples, eval_label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)

        dev_all_input_ids = torch.tensor([f.input_ids for f in dev_features], dtype=torch.long)
        dev_all_input_mask = torch.tensor([f.input_mask for f in dev_features], dtype=torch.long)
        dev_all_segment_ids = torch.tensor([f.segment_ids for f in dev_features], dtype=torch.long)
        dev_all_label_ids = torch.tensor([f.label_id for f in dev_features], dtype=torch.long)

        dev_data = TensorDataset(dev_all_input_ids, dev_all_input_mask, dev_all_segment_ids, dev_all_label_ids)
        dev_sampler = SequentialSampler(dev_data)
        dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=args.eval_batch_size)


        '''load test set'''
        # eval_examples = processor.get_RTE_as_test('/export/home/Dataset/RTE/test_RTE_1235.txt')
        # eval_examples = get_data_hulu('test')
        eval_features = convert_examples_to_features(
            eval_examples, eval_label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)

        eval_all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        eval_all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        eval_all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        eval_all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

        eval_data = TensorDataset(eval_all_input_ids, eval_all_input_mask, eval_all_segment_ids, eval_all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in meta_train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in meta_train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in meta_train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in meta_train_features], dtype=torch.long)

        meta_train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        meta_train_sampler = RandomSampler(meta_train_data)
        meta_train_dataloader = DataLoader(meta_train_data, sampler=meta_train_sampler, batch_size=args.train_batch_size*10)
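        # Meta-training batches are 10x the fine-tuning batch size, presumably because
        # the source domains contain far more labeled examples than the k-shot target
        # domain.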


        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
        '''support labeled examples in order, group in kshot size'''
        support_sampler = SequentialSampler(train_data)
        support_dataloader = DataLoader(train_data, sampler=support_sampler, batch_size=args.kshot)
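        # Assumption behind support_dataloader: the k-shot examples are ordered class by
        # class, so a SequentialSampler with batch_size == kshot yields exactly one
        # class' support examples per batch; the class-prototype computation inside the
        # training loop below relies on this ordering.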


        iter_co = 0
        max_dev_test = [0,0]
        fine_max_dev = False
        '''first train on meta_train tasks'''
        for meta_epoch_i in trange(args.meta_epochs, desc="metaEpoch"):
            for step, batch in enumerate(tqdm(meta_train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                logits,_,_ = model(input_ids, input_mask, None, labels=None)
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            '''get class representation after each epoch of pretraining'''
            model.eval()
            last_reps_list = []
            for input_ids, input_mask, segment_ids, label_ids in support_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)
                # gold_label_ids+=list(label_ids.detach().cpu().numpy())

                with torch.no_grad():
                    logits, last_reps, bias = model(input_ids, input_mask, None, labels=None)
                last_reps_list.append(last_reps.mean(dim=0, keepdim=True)) #(1, 1024)
            class_reps_pretraining = torch.cat(last_reps_list, dim=0) #(15, 1024)
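            # Prototype sketch: because support_dataloader walks the k-shot set
            # sequentially in class-sized chunks, each `last_reps.mean(dim=0)` above is
            # roughly the mean of the k support embeddings of one class, and stacking
            # those means gives one row per target-domain class, e.g. a (15, 1024) matrix.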

            '''
            start evaluate on dev set after this epoch
            '''
            for idd, dev_or_test_dataloader in enumerate([dev_dataloader, eval_dataloader]):
                if idd == 0:
                    logger.info("***** Running dev *****")
                    logger.info("  Num examples = %d", len(dev_examples))
                else:
                    logger.info("***** Running test *****")
                    logger.info("  Num examples = %d", len(eval_examples))
                # logger.info("  Batch size = %d", args.eval_batch_size)

                eval_loss = 0
                nb_eval_steps = 0
                preds = []
                gold_label_ids = []
                # print('Evaluating...')
                for input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    segment_ids = segment_ids.to(device)
                    label_ids = label_ids.to(device)
                    gold_label_ids+=list(label_ids.detach().cpu().numpy())

                    with torch.no_grad():
                        logits_LR, reps_batch, _ = model(input_ids, input_mask, None, labels=None)
                    # logits = logits[0]

                    '''pretraining logits'''
                    raw_similarity_scores = torch.mm(reps_batch,torch.transpose(class_reps_pretraining, 0,1)) #(batch, 15)
                    # print('raw_similarity_scores shape:', raw_similarity_scores.shape)
                    # print('bias_finetune:', bias_finetune.shape)
                    biased_similarity_scores = raw_similarity_scores#+bias_finetune.view(-1, raw_similarity_scores.shape[1])
                    logits_pretrain = torch.max(biased_similarity_scores.view(reps_batch.shape[0], -1, len(finetune_label_list)), dim=1)[0] #(batch, #class); use the actual batch size so a smaller final batch does not break the reshape
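                    # Sketch of what this computes: each example's representation is
                    # scored against every class prototype by dot product, reshaped to
                    # (batch, n_prototype_sets, n_classes), and max-pooled over prototype
                    # sets (here n_prototype_sets == 1) to give one score per target class.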
                    '''finetune logits'''
                    # raw_similarity_scores = torch.mm(reps_batch,torch.transpose(class_reps_finetune, 0,1)) #(batch, 15*history)
                    # biased_similarity_scores = raw_similarity_scores+bias_finetune.view(-1, raw_similarity_scores.shape[1])
                    # logits_finetune = torch.max(biased_similarity_scores.view(args.eval_batch_size, -1, len(finetune_label_list)), dim=1)[0] #(batch, #class)

                    logits = logits_pretrain#+logits_finetune
                    # logits = (1-0.9)*logits+0.9*logits_LR

                    if len(preds) == 0:
                        preds.append(logits.detach().cpu().numpy())
                    else:
                        preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)

                # eval_loss = eval_loss / nb_eval_steps
                preds = preds[0]
                pred_probs = softmax(preds,axis=1)
                pred_label_ids = list(np.argmax(pred_probs, axis=1))
                assert len(pred_label_ids) == len(gold_label_ids)
                hit_co = 0

                for k in range(len(pred_label_ids)):
                    if pred_label_ids[k] == gold_label_ids[k]:
                        hit_co +=1
                test_acc = hit_co/len(gold_label_ids)

                if idd == 0: # this is dev
                    if test_acc > max_dev_acc:
                        max_dev_acc = test_acc
                        print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')
                        fine_max_dev=True
                        max_dev_test[0] = round(max_dev_acc*100, 2)
                    else:
                        print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')
                        break
                else: # this is test
                    if test_acc > max_test_acc:
                        max_test_acc = test_acc
                    if fine_max_dev:
                        max_dev_test[1] = round(test_acc*100,2)
                        fine_max_dev = False
                    print('\ttest acc:', test_acc, ' max_test_acc:', max_test_acc, '\n')


        print('final:', str(max_dev_test[0])+'/'+str(max_dev_test[1]), '\n')
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="random seed for initialization")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--neighbor_size_limit',
                        type=int,
                        default=500,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    train_examples = processor.get_RTE_as_train_k_shot(
        '/export/home/Dataset/glue_data/RTE/train.tsv',
        args.kshot)  #train_pu_half_v1.txt
    train_examples_MNLI = processor.get_MNLI_train(
        '/export/home/Dataset/glue_data/MNLI/train.tsv')

    source_example_2_gramset = {}
    for mnli_ex in train_examples_MNLI:
        source_example_2_gramset[mnli_ex] = gram_set(mnli_ex)
    print('MNLI gramset build over')
    # neighbor_size_limit = 500
    train_examples_neighbors = retrieve_neighbors_source_given_kshot_target(
        train_examples, source_example_2_gramset, args.neighbor_size_limit)
    print('neighbor size:', len(train_examples_neighbors))
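    # The retrieval above presumably scores MNLI training examples by n-gram overlap
    # (via gram_set, defined elsewhere in this script) with the k-shot RTE examples and
    # keeps at most neighbor_size_limit of the closest ones as auxiliary training data.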
    # train_examples_neighbors_2way = []
    # for neighbor_ex in train_examples_neighbors:
    #     if neighbor_ex.label !='entailment':
    #         neighbor_ex.label = 'not_entailment'
    #     train_examples_neighbors_2way.append(neighbor_ex)

    dev_examples = processor.get_RTE_as_dev(
        '/export/home/Dataset/glue_data/RTE/dev.tsv')
    test_examples = processor.get_RTE_as_test(
        '/export/home/Dataset/RTE/test_RTE_1235.txt')
    label_list = ["entailment", "not_entailment"]
    mnli_label_list = ["entailment", "neutral", "contradiction"]
    # train_examples, dev_examples, test_examples, label_list = load_CLINC150_with_specific_domain_sequence(args.DomainName, args.kshot, augment=False)
    num_labels = len(label_list)
    print('num_labels:', num_labels, 'training size:', len(train_examples),
          'neighbor size:', len(train_examples_neighbors), 'dev size:',
          len(dev_examples), 'test size:', len(test_examples))

    num_train_optimization_steps = None
    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.load_state_dict(
        torch.load(
            '/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'
        ))
    # model.load_state_dict(torch.load('/export/home/Dataset/BERT_pretrained_mine/MNLI_biased_pretrained/RTE.10shot.seed.42.pt'))
    model.to(device)

    target_model = RobertaForSequenceClassification_TargetClassifier(
        args.kshot * num_labels, 3)
    target_model.to(device)
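    # The target classifier (defined elsewhere in this script) is constructed with the
    # support-set size (kshot * num_labels) and the number of source MNLI classes (3);
    # it consumes the frozen source model's hidden states rather than raw text.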

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=5e-7)
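    # Two optimizers are used: the source RoBERTa model gets a very small, hard-coded
    # learning rate (5e-7), while optimizer_target below uses --learning_rate for the
    # target classifier.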

    param_optimizer_target = list(target_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters_target = [{
        'params': [
            p for n, p in param_optimizer_target
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.01
    }, {
        'params': [
            p for n, p in param_optimizer_target
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    optimizer_target = AdamW(optimizer_grouped_parameters_target,
                             lr=args.learning_rate)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        train_dataloader = examples_to_features(train_examples,
                                                label_list,
                                                args,
                                                tokenizer,
                                                args.train_batch_size,
                                                "classification",
                                                dataloader_mode='random')
        train_neighbors_dataloader = examples_to_features(
            train_examples_neighbors,
            mnli_label_list,
            args,
            tokenizer,
            5,
            "classification",
            dataloader_mode='random')
        dev_dataloader = examples_to_features(dev_examples,
                                              label_list,
                                              args,
                                              tokenizer,
                                              args.eval_batch_size,
                                              "classification",
                                              dataloader_mode='sequential')
        test_dataloader = examples_to_features(test_examples,
                                               label_list,
                                               args,
                                               tokenizer,
                                               args.eval_batch_size,
                                               "classification",
                                               dataloader_mode='sequential')
        train_mnli_dataloader = examples_to_features(train_examples_MNLI,
                                                     mnli_label_list,
                                                     args,
                                                     tokenizer,
                                                     32,
                                                     "classification",
                                                     dataloader_mode='random')
        '''first pretrain on neighbors'''
        iter_co = 0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_neighbors_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                logits, _ = model(input_ids, input_mask)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(logits.view(-1, len(mnli_label_list)),
                                label_ids.view(-1))
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                iter_co += 1
            '''
            start evaluate on dev set after this epoch
            '''
            model.eval()

            eval_loss = 0
            nb_eval_steps = 0
            preds = []
            gold_label_ids = []
            # print('Evaluating...')
            for input_ids, input_mask, segment_ids, label_ids in dev_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)
                gold_label_ids += list(label_ids.detach().cpu().numpy())

                with torch.no_grad():
                    logits, _ = model(input_ids, input_mask)
                if len(preds) == 0:
                    preds.append(logits.detach().cpu().numpy())
                else:
                    preds[0] = np.append(preds[0],
                                         logits.detach().cpu().numpy(),
                                         axis=0)

            preds = preds[0]

            pred_probs = softmax(preds, axis=1)
            pred_label_ids_3way = list(np.argmax(pred_probs, axis=1))
            '''change from 3-way to 2-way'''
            pred_label_ids = []
            for pred_id in pred_label_ids_3way:
                if pred_id != 0:
                    pred_label_ids.append(1)
                else:
                    pred_label_ids.append(0)
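            # Index 0 is "entailment" in both label lists, so any non-zero 3-way
            # prediction (neutral/contradiction) is mapped to RTE's "not_entailment".
            # Equivalent one-liner (sketch):
            # pred_label_ids = [0 if p == 0 else 1 for p in pred_label_ids_3way]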

            assert len(pred_label_ids) == len(gold_label_ids)
            hit_co = 0
            for k in range(len(pred_label_ids)):
                if pred_label_ids[k] == gold_label_ids[k]:
                    hit_co += 1
            test_acc = hit_co / len(gold_label_ids)

            if test_acc > max_dev_acc:
                max_dev_acc = test_acc
                print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc,
                      '\n')
                '''store the model, because we can test after a max_dev acc reached'''
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                store_transformers_models(
                    model_to_save, tokenizer,
                    '/export/home/Dataset/BERT_pretrained_mine/MNLI_biased_pretrained',
                    'dev_v2_seed_' + str(args.seed) + '_acc_' +
                    str(max_dev_acc) + '.pt')
            else:
                print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc,
                      '\n')
        '''use MNLI to pretrain the target classifier'''
        model.load_state_dict(
            torch.load(
                '/export/home/Dataset/BERT_pretrained_mine/MNLI_biased_pretrained/'
                + 'dev_v2_seed_' + str(args.seed) + '_acc_' +
                str(max_dev_acc) + '.pt'))
        for _ in trange(3, desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_mnli_dataloader, desc="Iteration")):

                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                '''first get the rep'''
                model.eval()
                with torch.no_grad():
                    logits, last_hidden = model(input_ids, input_mask)
                prob_of_entail = F.log_softmax(logits.view(-1, 3),
                                               dim=1)[:, :1]  #(batch, 1)
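                # The frozen source model provides two inputs to the target classifier:
                # its last hidden states and the log-probability assigned to the
                # "entailment" class (column 0), apparently used as an extra scalar
                # feature per example.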

                target_model.train()
                target_logits = target_model(last_hidden, prob_of_entail)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(target_logits.view(-1, len(mnli_label_list)),
                                label_ids.view(-1))
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer_target.step()
                optimizer_target.zero_grad()
        '''fine-tune on kshot'''

        # model.load_state_dict(torch.load('/export/home/Dataset/BERT_pretrained_mine/MNLI_biased_pretrained/'+'dev_seed_'+str(args.seed)+'_acc_'+str(max_dev_acc)+'.pt'))
        iter_co = 0
        max_dev_acc = 0.0
        final_test_performance = 0.0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                # model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                '''first get the rep'''
                model.eval()
                with torch.no_grad():
                    logits, last_hidden = model(input_ids, input_mask)
                prob_of_entail = F.log_softmax(logits.view(-1, 3),
                                               dim=1)[:, :1]  #(batch, 1)

                target_model.train()
                target_logits = target_model(last_hidden, prob_of_entail)

                prob_matrix = F.log_softmax(target_logits.view(-1, 3), dim=1)
                '''this step *1.0 is very important, otherwise bug'''
                new_prob_matrix = prob_matrix * 1.0
                '''change the entail prob to p or 1-p'''
                changed_places = torch.nonzero(label_ids.view(-1),
                                               as_tuple=False)
                new_prob_matrix[changed_places,
                                0] = 1.0 - prob_matrix[changed_places, 0]

                loss = F.nll_loss(
                    new_prob_matrix,
                    torch.zeros_like(label_ids).to(device).view(-1))
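                # For "entailment" targets (label id 0) the NLL toward class 0 is the
                # usual -log p_entail; for "not_entailment" targets column 0 is first
                # replaced by 1 - log p_entail, so minimizing the same NLL pushes the
                # entailment probability down instead. The `* 1.0` above most likely
                # clones prob_matrix so the in-place assignment does not modify a
                # tensor autograd still needs.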

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer_target.step()
                optimizer_target.zero_grad()
                global_step += 1
                iter_co += 1
                # if iter_co %20==0:
                if iter_co % len(train_dataloader) == 0:
                    '''
                    start evaluate on dev set after this epoch
                    '''
                    model.eval()
                    target_model.eval()

                    for idd, dev_or_test_dataloader in enumerate(
                        [dev_dataloader, test_dataloader]):

                        if idd == 0:
                            logger.info("***** Running dev *****")
                            logger.info("  Num examples = %d",
                                        len(dev_examples))
                        else:
                            logger.info("***** Running test *****")
                            logger.info("  Num examples = %d",
                                        len(test_examples))
                        # logger.info("  Batch size = %d", args.eval_batch_size)

                        eval_loss = 0
                        nb_eval_steps = 0
                        preds = []
                        gold_label_ids = []
                        # print('Evaluating...')
                        for input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                            input_ids = input_ids.to(device)
                            input_mask = input_mask.to(device)
                            segment_ids = segment_ids.to(device)
                            label_ids = label_ids.to(device)
                            gold_label_ids += list(
                                label_ids.detach().cpu().numpy())

                            with torch.no_grad():
                                source_logits, last_hidden = model(
                                    input_ids, input_mask)
                            prob_of_entail = F.log_softmax(
                                source_logits.view(-1, 3),
                                dim=1)[:, :1]  #(batch, 1)
                            with torch.no_grad():
                                logits = target_model(last_hidden,
                                                      prob_of_entail)

                            if len(preds) == 0:
                                preds.append(logits.detach().cpu().numpy())
                            else:
                                preds[0] = np.append(
                                    preds[0],
                                    logits.detach().cpu().numpy(),
                                    axis=0)

                        preds = preds[0]

                        pred_probs = softmax(preds, axis=1)
                        pred_label_ids_3way = list(
                            np.argmax(pred_probs, axis=1))
                        '''change from 3-way to 2-way'''
                        pred_label_ids = []
                        for pred_id in pred_label_ids_3way:
                            if pred_id != 0:
                                pred_label_ids.append(1)
                            else:
                                pred_label_ids.append(0)

                        assert len(pred_label_ids) == len(gold_label_ids)
                        hit_co = 0
                        for k in range(len(pred_label_ids)):
                            if pred_label_ids[k] == gold_label_ids[k]:
                                hit_co += 1
                        test_acc = hit_co / len(gold_label_ids)

                        if idd == 0:  # this is dev
                            if test_acc > max_dev_acc:
                                max_dev_acc = test_acc
                                print('\ndev acc:', test_acc, ' max_dev_acc:',
                                      max_dev_acc, '\n')

                            else:
                                print('\ndev acc:', test_acc, ' max_dev_acc:',
                                      max_dev_acc, '\n')
                                break
                        else:  # this is test
                            if test_acc > max_test_acc:
                                max_test_acc = test_acc

                            final_test_performance = test_acc
                            print('\ntest acc:', test_acc, ' max_test_acc:',
                                  max_test_acc, '\n')
        print('final_test_performance:', final_test_performance)
Exemplo n.º 27
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="Number of labeled examples per class (k-shot).")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--use_mixup",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--beta_sampling_times',
                        type=int,
                        default=10,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    train_examples, dev_examples, test_examples, label_list = processor.load_FewRel_data(
        args.kshot)

    num_labels = len(label_list)
    print('num_labels:', num_labels, 'training size:', len(train_examples),
          'dev size:', len(dev_examples), 'test size:', len(test_examples))

    model = RobertaForSequenceClassification(num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)
        '''load dev set'''
        dev_features = convert_examples_to_features(
            dev_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        dev_all_input_ids = torch.tensor([f.input_ids for f in dev_features],
                                         dtype=torch.long)
        dev_all_input_mask = torch.tensor([f.input_mask for f in dev_features],
                                          dtype=torch.long)
        dev_all_segment_ids = torch.tensor(
            [f.segment_ids for f in dev_features], dtype=torch.long)
        dev_all_span_a_mask = torch.tensor(
            [f.span_a_mask for f in dev_features], dtype=torch.float)
        dev_all_span_b_mask = torch.tensor(
            [f.span_b_mask for f in dev_features], dtype=torch.float)

        dev_all_label_ids = torch.tensor([f.label_id for f in dev_features],
                                         dtype=torch.long)

        dev_data = TensorDataset(dev_all_input_ids, dev_all_input_mask,
                                 dev_all_segment_ids, dev_all_span_a_mask,
                                 dev_all_span_b_mask, dev_all_label_ids)
        dev_sampler = SequentialSampler(dev_data)
        dev_dataloader = DataLoader(dev_data,
                                    sampler=dev_sampler,
                                    batch_size=args.eval_batch_size)
        '''load test set'''
        test_features = convert_examples_to_features(
            test_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        eval_all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                          dtype=torch.long)
        eval_all_input_mask = torch.tensor(
            [f.input_mask for f in test_features], dtype=torch.long)
        eval_all_segment_ids = torch.tensor(
            [f.segment_ids for f in test_features], dtype=torch.long)
        eval_all_span_a_mask = torch.tensor(
            [f.span_a_mask for f in test_features], dtype=torch.float)
        eval_all_span_b_mask = torch.tensor(
            [f.span_b_mask for f in test_features], dtype=torch.float)
        # eval_all_pair_ids = [f.pair_id for f in test_features]
        eval_all_label_ids = torch.tensor([f.label_id for f in test_features],
                                          dtype=torch.long)

        eval_data = TensorDataset(eval_all_input_ids, eval_all_input_mask,
                                  eval_all_segment_ids, eval_all_span_a_mask,
                                  eval_all_span_b_mask, eval_all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        test_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        # logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_span_a_mask = torch.tensor([f.span_a_mask for f in train_features],
                                       dtype=torch.float)
        all_span_b_mask = torch.tensor([f.span_b_mask for f in train_features],
                                       dtype=torch.float)

        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_span_a_mask,
                                   all_span_b_mask, all_label_ids)
        train_sampler = RandomSampler(train_data)

        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        iter_co = 0
        final_test_performance = 0.0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, span_a_mask, span_b_mask, label_ids = batch

                #input_ids, input_mask, span_a_mask, span_b_mask
                logits = model(input_ids, input_mask, span_a_mask, span_b_mask)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(logits.view(-1, num_labels),
                                label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                iter_co += 1
                # if iter_co %20==0:
                if iter_co % len(train_dataloader) == 0:
                    # if iter_co % (len(train_dataloader)//2)==0:
                    '''
                    start evaluating on the dev and test sets after this epoch
                    '''
                    model.eval()

                    for idd, dev_or_test_dataloader in enumerate(
                        [dev_dataloader, test_dataloader]):

                        if idd == 0:
                            logger.info("***** Running dev *****")
                            logger.info("  Num examples = %d",
                                        len(dev_features))
                        else:
                            logger.info("***** Running test *****")
                            logger.info("  Num examples = %d",
                                        len(test_features))
                        # logger.info("  Batch size = %d", args.eval_batch_size)

                        eval_loss = 0
                        nb_eval_steps = 0
                        preds = []
                        gold_label_ids = []
                        # print('Evaluating...')
                        for input_ids, input_mask, segment_ids, span_a_mask, span_b_mask, label_ids in dev_or_test_dataloader:
                            input_ids = input_ids.to(device)
                            input_mask = input_mask.to(device)
                            segment_ids = segment_ids.to(device)
                            span_a_mask = span_a_mask.to(device)
                            span_b_mask = span_b_mask.to(device)
                            label_ids = label_ids.to(device)
                            gold_label_ids += list(
                                label_ids.detach().cpu().numpy())

                            with torch.no_grad():
                                logits = model(input_ids, input_mask,
                                               span_a_mask, span_b_mask)
                            if len(preds) == 0:
                                preds.append(logits.detach().cpu().numpy())
                            else:
                                preds[0] = np.append(
                                    preds[0],
                                    logits.detach().cpu().numpy(),
                                    axis=0)

                        preds = preds[0]

                        pred_probs = softmax(preds, axis=1)
                        pred_label_ids = list(np.argmax(pred_probs, axis=1))

                        assert len(pred_label_ids) == len(gold_label_ids)
                        hit_co = 0
                        for k in range(len(pred_label_ids)):
                            if pred_label_ids[k] == gold_label_ids[k]:
                                hit_co += 1
                        test_acc = hit_co / len(gold_label_ids)
                        f1 = test_acc  # plain accuracy, reported below under the name "f1"

                        if idd == 0:  # this is dev
                            if f1 > max_dev_acc:
                                max_dev_acc = f1
                                print('\ndev acc :', f1, ' max_dev_acc:',
                                      max_dev_acc, '\n')
                                # '''store the model, because we can test after a max_dev acc reached'''
                                # model_to_save = (
                                #     model.module if hasattr(model, "module") else model
                                # )  # Take care of distributed/parallel training
                                # store_transformers_models(model_to_save, tokenizer, '/export/home/Dataset/BERT_pretrained_mine/event_2_nli', 'mnli_mypretrained_f1_'+str(max_dev_acc)+'.pt')

                            else:
                                print('\ndev acc :', f1, ' max_dev_acc:',
                                      max_dev_acc, '\n')
                                break  # dev accuracy did not improve, so skip the test evaluation
                        else:  # this is test
                            if f1 > max_test_acc:
                                max_test_acc = f1
                            final_test_performance = f1
                            print('\ntest acc:', f1, ' max_test_acc:',
                                  max_test_acc, '\n')
        print('final_test_f1:', final_test_performance)
def LoadDatasets(args, task_cfg, ids, split="trainval"):

    if "roberta" in args.bert_model:
        tokenizer = RobertaTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)
    else:
        tokenizer = BertTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)

    task_feature_reader1 = {}
    task_feature_reader2 = {}
    for i, task_id in enumerate(ids):
        task = "TASK" + task_id
        if task_cfg[task]["features_h5path1"] not in task_feature_reader1:
            task_feature_reader1[task_cfg[task]["features_h5path1"]] = None
        if task_cfg[task]["features_h5path2"] not in task_feature_reader2:
            task_feature_reader2[task_cfg[task]["features_h5path2"]] = None

    # initialize the feature readers
    for features_h5path in task_feature_reader1.keys():
        if features_h5path != "":
            task_feature_reader1[features_h5path] = ImageFeaturesH5Reader(
                features_h5path, args.in_memory)
    for features_h5path in task_feature_reader2.keys():
        if features_h5path != "":
            task_feature_reader2[features_h5path] = ImageFeaturesH5Reader(
                features_h5path, args.in_memory)

    task_datasets_train = {}
    task_datasets_val = {}
    task_dataloader_train = {}
    task_dataloader_val = {}
    task_ids = []
    task_batch_size = {}
    task_num_iters = {}

    for i, task_id in enumerate(ids):
        task = "TASK" + task_id
        task_name = task_cfg[task]["name"]
        task_ids.append(task)
        batch_size = task_cfg[task][
            "batch_size"] // args.gradient_accumulation_steps
        num_workers = args.num_workers
        if args.local_rank != -1:
            batch_size = int(batch_size / dist.get_world_size())
            num_workers = int(num_workers / dist.get_world_size())

        # num_workers = int(num_workers / len(ids))
        logger.info("Loading %s Dataset with batch size %d" %
                    (task_cfg[task]["name"], batch_size))

        task_datasets_train[task] = None
        if "train" in split:
            task_datasets_train[task] = DatasetMapTrain[task_name](
                task=task_cfg[task]["name"],
                dataroot=task_cfg[task]["dataroot"],
                annotations_jsonpath=task_cfg[task]
                ["train_annotations_jsonpath"],
                split=task_cfg[task]["train_split"],
                image_features_reader=task_feature_reader1[
                    task_cfg[task]["features_h5path1"]],
                gt_image_features_reader=task_feature_reader2[
                    task_cfg[task]["features_h5path2"]],
                tokenizer=tokenizer,
                bert_model=args.bert_model,
                clean_datasets=args.clean_train_sets,
                padding_index=0,
                max_seq_length=task_cfg[task]["max_seq_length"],
                max_region_num=task_cfg[task]["max_region_num"],
            )

        task_datasets_val[task] = None
        if "val" in split:
            task_datasets_val[task] = DatasetMapTrain[task_name](
                task=task_cfg[task]["name"],
                dataroot=task_cfg[task]["dataroot"],
                annotations_jsonpath=task_cfg[task]
                ["val_annotations_jsonpath"],
                split=task_cfg[task]["val_split"],
                image_features_reader=task_feature_reader1[
                    task_cfg[task]["features_h5path1"]],
                gt_image_features_reader=task_feature_reader2[
                    task_cfg[task]["features_h5path2"]],
                tokenizer=tokenizer,
                bert_model=args.bert_model,
                clean_datasets=args.clean_train_sets,
                padding_index=0,
                max_seq_length=task_cfg[task]["max_seq_length"],
                max_region_num=task_cfg[task]["max_region_num"],
            )

        task_num_iters[task] = 0
        task_batch_size[task] = 0
        if "train" in split:
            if args.local_rank == -1:
                train_sampler = RandomSampler(task_datasets_train[task])
            else:
                # TODO: check if this works with current data generator from disk that relies on next(file)
                # (it doesn't return item back by index)
                train_sampler = DistributedSampler(task_datasets_train[task])

            task_dataloader_train[task] = DataLoader(
                task_datasets_train[task],
                sampler=train_sampler,
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory=True,
            )

            task_num_iters[task] = len(task_dataloader_train[task])
            task_batch_size[task] = batch_size

        if "val" in split:
            task_dataloader_val[task] = DataLoader(
                task_datasets_val[task],
                shuffle=False,
                batch_size=batch_size,
                num_workers=2,
                pin_memory=True,
            )

    return (
        task_batch_size,
        task_num_iters,
        task_ids,
        task_datasets_train,
        task_datasets_val,
        task_dataloader_train,
        task_dataloader_val,
    )
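# --- Usage sketch (added for illustration, not part of the original example) ---
# Assuming `args`, `task_cfg` and a list of task ids come from the surrounding
# training script (e.g. ids = args.tasks.split("-")), the seven returned values
# are typically unpacked in this order:
#
#   (task_batch_size, task_num_iters, task_ids,
#    task_datasets_train, task_datasets_val,
#    task_dataloader_train, task_dataloader_val) = LoadDatasets(
#       args, task_cfg, ids, split="trainval")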
Exemplo n.º 29
0
    def load(cls,
             pretrained_model_name_or_path,
             tokenizer_class=None,
             use_fast=False,
             **kwargs):
        """
        Enables loading of different Tokenizer classes with a uniform interface. Either infer the class from
        `pretrained_model_name_or_path` or define it manually via `tokenizer_class`.

        :param pretrained_model_name_or_path:  The path of the saved pretrained model or its name (e.g. `bert-base-uncased`)
        :type pretrained_model_name_or_path: str
        :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
        :type tokenizer_class: str
        :param use_fast: (Optional, False by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
            use the Python one (False).
            Only DistilBERT, BERT, Electra and DPR fast tokenizers are supported.
        :type use_fast: bool
        :param kwargs:
        :return: Tokenizer
        """

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        # guess tokenizer type from name
        if tokenizer_class is None:
            if "albert" in pretrained_model_name_or_path.lower():
                tokenizer_class = "AlbertTokenizer"
            elif "xlm-roberta" in pretrained_model_name_or_path.lower():
                tokenizer_class = "XLMRobertaTokenizer"
            elif "roberta" in pretrained_model_name_or_path.lower():
                tokenizer_class = "RobertaTokenizer"
            elif 'codebert' in pretrained_model_name_or_path.lower():
                if "mlm" in pretrained_model_name_or_path.lower():
                    raise NotImplementedError(
                        "MLM part of codebert is currently not supported in FARM"
                    )
                else:
                    tokenizer_class = "RobertaTokenizer"
            elif "camembert" in pretrained_model_name_or_path.lower(
            ) or "umberto" in pretrained_model_name_or_path:
                tokenizer_class = "CamembertTokenizer"
            elif "distilbert" in pretrained_model_name_or_path.lower():
                tokenizer_class = "DistilBertTokenizer"
            elif "bert" in pretrained_model_name_or_path.lower():
                tokenizer_class = "BertTokenizer"
            elif "xlnet" in pretrained_model_name_or_path.lower():
                tokenizer_class = "XLNetTokenizer"
            elif "electra" in pretrained_model_name_or_path.lower():
                tokenizer_class = "ElectraTokenizer"
            elif "word2vec" in pretrained_model_name_or_path.lower() or \
                    "glove" in pretrained_model_name_or_path.lower() or \
                    "fasttext" in pretrained_model_name_or_path.lower():
                tokenizer_class = "EmbeddingTokenizer"
            elif "minilm" in pretrained_model_name_or_path.lower():
                tokenizer_class = "BertTokenizer"
            elif "dpr-question_encoder" in pretrained_model_name_or_path.lower(
            ):
                tokenizer_class = "DPRQuestionEncoderTokenizer"
            elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower():
                tokenizer_class = "DPRContextEncoderTokenizer"
            else:
                raise ValueError(
                    f"Could not infer tokenizer_class from name '{pretrained_model_name_or_path}'. Set "
                    f"arg `tokenizer_class` in Tokenizer.load() to one of: AlbertTokenizer, "
                    f"XLMRobertaTokenizer, RobertaTokenizer, DistilBertTokenizer, BertTokenizer, "
                    f"XLNetTokenizer, ElectraTokenizer, EmbeddingTokenizer, CamembertTokenizer, "
                    f"DPRQuestionEncoderTokenizer or DPRContextEncoderTokenizer.")
            logger.info(f"Loading tokenizer of type '{tokenizer_class}'")
        # return appropriate tokenizer object
        ret = None
        if tokenizer_class == "AlbertTokenizer":
            if use_fast:
                logger.error(
                    'AlbertTokenizerFast is not supported! Using AlbertTokenizer instead.'
                )
                ret = AlbertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
            else:
                ret = AlbertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
        elif tokenizer_class == "XLMRobertaTokenizer":
            if use_fast:
                logger.error(
                    'XLMRobertaTokenizerFast is not supported! Using XLMRobertaTokenizer instead.'
                )
                ret = XLMRobertaTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = XLMRobertaTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "RobertaTokenizer" in tokenizer_class:  # because it also might be fast tokekenizer we use "in"
            if use_fast:
                logger.error(
                    'RobertaTokenizerFast is not supported! Using RobertaTokenizer instead.'
                )
                ret = RobertaTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = RobertaTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "DistilBertTokenizer" in tokenizer_class:  # because it also might be fast tokekenizer we use "in"
            if use_fast:
                ret = DistilBertTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DistilBertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "BertTokenizer" in tokenizer_class:  # because it also might be fast tokekenizer we use "in"
            if use_fast:
                ret = BertTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = BertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "XLNetTokenizer":
            if use_fast:
                logger.error(
                    'XLNetTokenizerFast is not supported! Using XLNetTokenizer instead.'
                )
                ret = XLNetTokenizer.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
            else:
                ret = XLNetTokenizer.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
        elif "ElectraTokenizer" in tokenizer_class:  # because it also might be fast tokekenizer we use "in"
            if use_fast:
                ret = ElectraTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = ElectraTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "EmbeddingTokenizer":
            if use_fast:
                logger.error(
                    'EmbeddingTokenizerFast is not supported! Using EmbeddingTokenizer instead.'
                )
                ret = EmbeddingTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = EmbeddingTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "CamembertTokenizer":
            if use_fast:
                logger.error(
                    'CamembertTokenizerFast is not supported! Using CamembertTokenizer instead.'
                )
                ret = CamembertTokenizer._from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = CamembertTokenizer._from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "DPRQuestionEncoderTokenizer" or tokenizer_class == "DPRQuestionEncoderTokenizerFast":
            if use_fast or tokenizer_class == "DPRQuestionEncoderTokenizerFast":
                ret = DPRQuestionEncoderTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DPRQuestionEncoderTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "DPRContextEncoderTokenizer" or tokenizer_class == "DPRContextEncoderTokenizerFast":
            if use_fast or tokenizer_class == "DPRContextEncoderTokenizerFast":
                ret = DPRContextEncoderTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DPRContextEncoderTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        if ret is None:
            raise Exception("Unable to load tokenizer")
        else:
            return ret
Exemplo n.º 30
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument('--kshot',
                        type=float,
                        default=5,
                        help="k-shot size: number of labeled training examples to sample")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--use_mixup",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--beta_sampling_times',
                        type=int,
                        default=10,
                        help="number of mixup coefficients sampled from the Beta distribution per batch")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    device = torch.device("cuda")
    n_gpu = torch.cuda.device_count()

    args.train_batch_size = args.train_batch_size * max(1, n_gpu)
    args.eval_batch_size = args.eval_batch_size * max(1, n_gpu)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    train_examples = processor.get_RTE_as_train_k_shot(
        '/export/home/Dataset/glue_data/RTE/train.tsv',
        args.kshot)  #train_pu_half_v1.txt
    dev_examples = processor.get_RTE_as_dev(
        '/export/home/Dataset/glue_data/RTE/dev.tsv')
    test_examples = processor.get_RTE_as_test(
        '/export/home/Dataset/RTE/test_RTE_1235.txt')
    label_list = ["entailment", "not_entailment"]
    num_labels = len(label_list)
    print('num_labels:', num_labels, 'training size:', len(train_examples),
          'dev size:', len(dev_examples), 'test size:', len(test_examples))

    num_train_optimization_steps = None
    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    model = RobertaForSequenceClassification(num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.load_state_dict(
        torch.load(
            '/export/home/Dataset/BERT_pretrained_mine/mixup_wenpeng/batchMixup_pretrain_kshot_'
            + str(args.kshot) + '_dev_acc_seed_' + str(args.seed) + '.pt'))
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    model.to(device)

    # Group parameters for AdamW: biases and LayerNorm weights are excluded from
    # weight decay (the standard BERT/RoBERTa fine-tuning recipe).
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)
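        # Illustration (added, not from the original code): with sep_token_extra=True
        # a RoBERTa-style premise/hypothesis pair is laid out as
        #     <s> premise </s> </s> hypothesis </s>
        # whereas cls_token_at_end and pad_on_left stay False because the model
        # here is RoBERTa, not XLNet.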
        '''load dev set'''
        dev_features = convert_examples_to_features(
            dev_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        dev_all_input_ids = torch.tensor([f.input_ids for f in dev_features],
                                         dtype=torch.long)
        dev_all_input_mask = torch.tensor([f.input_mask for f in dev_features],
                                          dtype=torch.long)
        dev_all_segment_ids = torch.tensor(
            [f.segment_ids for f in dev_features], dtype=torch.long)
        dev_all_label_ids = torch.tensor([f.label_id for f in dev_features],
                                         dtype=torch.long)

        dev_data = TensorDataset(dev_all_input_ids, dev_all_input_mask,
                                 dev_all_segment_ids, dev_all_label_ids)
        dev_sampler = SequentialSampler(dev_data)
        dev_dataloader = DataLoader(dev_data,
                                    sampler=dev_sampler,
                                    batch_size=args.eval_batch_size)
        '''load test set'''
        test_features = convert_examples_to_features(
            test_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        eval_all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                          dtype=torch.long)
        eval_all_input_mask = torch.tensor(
            [f.input_mask for f in test_features], dtype=torch.long)
        eval_all_segment_ids = torch.tensor(
            [f.segment_ids for f in test_features], dtype=torch.long)
        eval_all_label_ids = torch.tensor([f.label_id for f in test_features],
                                          dtype=torch.long)

        eval_data = TensorDataset(eval_all_input_ids, eval_all_input_mask,
                                  eval_all_segment_ids, eval_all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        test_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        train_sampler = RandomSampler(train_data)

        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        iter_co = 0
        final_test_performance = 0.0
        for epoch_i in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                for sample_i in range(args.beta_sampling_times):
                    # mixup interpolation coefficient, drawn from Beta(0.4, 0.4)
                    lambda_vec = beta.rvs(0.4, 0.4, size=1)[0]
                    loss = model(input_ids,
                                 input_mask,
                                 label_ids,
                                 lambda_vec,
                                 is_train=True,
                                 use_mixup=args.use_mixup)
                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    loss.backward()

                    tr_loss += loss.item()
                    nb_tr_examples += input_ids.size(0)
                    nb_tr_steps += 1

                    optimizer.step()
                    optimizer.zero_grad()
                global_step += 1
                iter_co += 1
                # if iter_co %20==0:
                if iter_co % len(train_dataloader) == 0:
                    '''
                    start evaluating on the dev and test sets after this epoch
                    '''
                    model.eval()

                    for idd, dev_or_test_dataloader in enumerate(
                        [dev_dataloader, test_dataloader]):

                        if idd == 0:
                            logger.info("***** Running dev *****")
                            logger.info("  Num examples = %d",
                                        len(dev_examples))
                        else:
                            logger.info("***** Running test *****")
                            logger.info("  Num examples = %d",
                                        len(test_examples))
                        # logger.info("  Batch size = %d", args.eval_batch_size)

                        eval_loss = 0
                        nb_eval_steps = 0
                        preds = []
                        gold_label_ids = []
                        # print('Evaluating...')
                        for input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                            input_ids = input_ids.to(device)
                            input_mask = input_mask.to(device)
                            segment_ids = segment_ids.to(device)
                            label_ids = label_ids.to(device)
                            gold_label_ids += list(
                                label_ids.detach().cpu().numpy())

                            with torch.no_grad():
                                logits = model(input_ids,
                                               input_mask,
                                               None,
                                               None,
                                               is_train=False,
                                               use_mixup=False)
                            if len(preds) == 0:
                                preds.append(logits.detach().cpu().numpy())
                            else:
                                preds[0] = np.append(
                                    preds[0],
                                    logits.detach().cpu().numpy(),
                                    axis=0)

                        preds = preds[0]

                        pred_probs = softmax(preds, axis=1)
                        pred_label_ids = list(np.argmax(pred_probs, axis=1))

                        assert len(pred_label_ids) == len(gold_label_ids)
                        hit_co = 0
                        for k in range(len(pred_label_ids)):
                            if pred_label_ids[k] == gold_label_ids[k]:
                                hit_co += 1
                        test_acc = hit_co / len(gold_label_ids)

                        if idd == 0:  # this is dev
                            if test_acc > max_dev_acc:
                                max_dev_acc = test_acc
                                print('\ndev acc:', test_acc, ' max_dev_acc:',
                                      max_dev_acc, '\n')

                            else:
                                print('\ndev acc:', test_acc, ' max_dev_acc:',
                                      max_dev_acc, '\n')
                                break  # dev accuracy did not improve, so skip the test evaluation
                        else:  # this is test
                            if test_acc > max_test_acc:
                                max_test_acc = test_acc

                            final_test_performance = test_acc
                            print('\ntest acc:', test_acc, ' max_test_acc:',
                                  max_test_acc, '\n')
        print('final_test_performance:', final_test_performance)