Example #1
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where to store the pre-trained models downloaded from s3.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--do_balance",
        action="store_true",
        help="Set this flag if you want to use the balanced choose function")
    parser.add_argument(
        "--push_message",
        action="store_true",
        help="set this flag if you want to push message to your phone")
    parser.add_argument(
        "--top_k",
        default=500,
        type=int,
        help="Number of top-k pseudo labels the Teacher chooses for the Student to learn from."
    )
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument(
        "--num_student_train_epochs",
        type=float,
        default=3.0,
        help="Total number of student model training epochs to perform.")
    parser.add_argument("--threshold",
                        type=float,
                        default=0.05,
                        help="threshold for improvenesss of model")
    parser.add_argument("--selection_function",
                        type=str,
                        default="random",
                        help="choose the selectionfunction")
    parser.add_argument("--alpha",
                        type=float,
                        default=0.33,
                        help="the weights of the TL model in the final model")
    parser.add_argument("--ft_true",
                        action="store_true",
                        help="fine tune the student model with true data")
    parser.add_argument("--ft_pseudo",
                        action="store_true",
                        help="fine tune the student model with pseudo data")
    parser.add_argument(
        "--ft_both",
        action="store_true",
        help="fine-tune the student model with both true and pseudo data")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "aus": OOCLAUSProcessor,
        "dbpedia": DBpediaProcessor,
        "trec": TrecProcessor,
        "yelp": YelpProcessor,
    }

    num_labels_task = {
        "aus": 33,
        "dbpedia": len(DBpediaProcessor().get_labels()),
        "trec": len(TrecProcessor().get_labels()),
        "yelp": len(YelpProcessor().get_labels()),
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    logger.info("***** Build tri model *****")
    # Prepare model
    cache_dir = args.cache_dir

    model = create_tri_model(args, cache_dir, num_labels, device)

    if args.do_train:
        # step 0: load train examples
        logger.info("Cook training and dev data for teacher model")
        train_examples = processor.get_train_examples(args.data_dir)
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info(" Num Training Examples = %d", len(train_examples))
        logger.info(" Train Batch Size = %d", args.train_batch_size)

        input_ids_train = np.array([f.input_ids for f in train_features])
        input_mask_train = np.array([f.input_mask for f in train_features])
        segment_ids_train = np.array([f.segment_ids for f in train_features])
        label_ids_train = np.array([f.label_id for f in train_features])
        train_data_loader = load_train_data(args, input_ids_train,
                                            input_mask_train,
                                            segment_ids_train, label_ids_train)

        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info(" Num Eval Examples = %d", len(eval_examples))
        logger.info(" Eval Batch Size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_data_loader = DataLoader(eval_data,
                                      sampler=eval_sampler,
                                      batch_size=args.eval_batch_size)

        # step 1: train the Tri model with labeled data
        logger.info("***** Running train TL model with labeled data *****")
        model = init_tri_model(model, args, n_gpu, train_data_loader, device,
                               args.num_train_epochs, logger)
        acc1 = evaluate_model(model, device, eval_data_loader, logger)

        # step 3: Tri-training

        model = TriTraining(model, args, device, n_gpu, 3, processor,
                            label_list, tokenizer)

        # step 4: evaluate model
        acc = evaluate_model(model, device, eval_data_loader, logger)

        print(acc1, acc)
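
The examples on this page all call a `load_train_data` helper whose source is not included in the snippets. A minimal sketch of what such a helper plausibly does (the names and details below are assumptions for illustration, not the repository's actual code):

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

def load_train_data(args, input_ids, input_mask, segment_ids, label_ids):
    # Hypothetical sketch: wrap the numpy feature arrays in a shuffled DataLoader.
    dataset = TensorDataset(
        torch.tensor(input_ids, dtype=torch.long),
        torch.tensor(input_mask, dtype=torch.long),
        torch.tensor(segment_ids, dtype=torch.long),
        torch.tensor(label_ids, dtype=torch.long),
    )
    sampler = RandomSampler(dataset)
    return DataLoader(dataset, sampler=sampler, batch_size=args.train_batch_size)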
Example #2
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--do_balance",
                        action="store_true",
                        help="Set this flag if you want to use the balanced choose function")
    parser.add_argument("--push_message",
                        action = "store_true",
                        help="set this flag if you want to push message to your phone")
    parser.add_argument("--top_k",
                        default=500,
                        type=int,
                        help="Set the num of top k pseudo labels Teacher will choose for Student to learn")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--num_student_train_epochs",
                        type=float,
                        default=3.0,
                        help="Total number of student model training epochs to perform.")
    parser.add_argument("--threshold",
                        type=float,
                        default=0.05,
                        help="threshold for improvenesss of model")
    parser.add_argument("--selection_function",
                        type=str,
                        default= "random",
                        help = "choose the selectionfunction")
    parser.add_argument("--alpha",
                        type=float,
                        default=0.33,
                        help = "the weights of the TL model in the final model")
    parser.add_argument("--ft_true",
                        action="store_true",
                        help="fine tune the student model with true data")
    parser.add_argument("--ft_pseudo",
                        action="store_true",
                        help="fine tune the student model with pseudo data")
    parser.add_argument("--ft_both",
                        action="store_true",
                        help="fine-tune the student model with both true and pseudo data")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "aus": OOCLAUSProcessor,
        "dbpedia": DBpediaProcessor,
        "trec": TrecProcessor,
        "yelp": YelpProcessor,
    }

    num_labels_task = {
        "aus": 33,
        "dbpedia": len(DBpediaProcessor().get_labels()),
        "trec": len(TrecProcessor().get_labels()),
        "yelp": len(YelpProcessor().get_labels()),
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    logger.info("***** Build teacher(label) model *****")
    # Prepare model
    cache_dir = args.cache_dir
    model_TL  = create_model(args, cache_dir, num_labels, device)

    logger.info("***** Build teacher(unlabel) model *****")
    cache_dir = args.cache_dir
    model_TU  = create_model(args, cache_dir, num_labels, device)
    logger.info("***** Build student model *****")

    model_student = create_model(args, cache_dir, num_labels, device)
    
    logger.info("***** Finish TL, TU and Student model building *****")


    if args.do_train:
        # step 0: load train examples
        logger.info("Cook training and dev data for teacher model")
        train_examples = processor.get_train_examples(args.data_dir)
        train_features = convert_examples_to_features(train_examples, label_list, args.max_seq_length, tokenizer)
        logger.info(" Num Training Examples = %d", len(train_examples))
        logger.info(" Train Batch Size = %d", args.train_batch_size)

        input_ids_train = np.array([f.input_ids for f in train_features])
        input_mask_train = np.array([f.input_mask for f in train_features])
        segment_ids_train = np.array([f.segment_ids for f in train_features])
        label_ids_train = np.array([f.label_id for f in train_features])
        train_data_loader = load_train_data(args, input_ids_train, input_mask_train, segment_ids_train, label_ids_train)
        
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer
        )
        logger.info(" Num Eval Examples = %d", len(eval_examples))
        logger.info(" Eval Batch Size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_data_loader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
        print(len(all_input_ids))

        
        # step 1: train the TL model with labeled data
        logger.info("***** Running train TL model with labeled data *****")
        model_TL = train(model_TL, args, n_gpu, train_data_loader, device, args.num_train_epochs, logger)
        model_TL.to(device)

        logger.info("***** Evaluate TL model *****")
        model_TL_accuracy = evaluate_model(model_TL, device, eval_data_loader, logger)

        # Step 2: predict the val_set
        logger.info("***** Product pseudo label from TL model *****")
        probas_val = predict_model(model_TL, args, eval_data_loader, device)
        print(len(probas_val))
    

        # Step 3: choose top-k data_val and reset train_data
        if args.do_balance:
            selection = BalanceTopkSelectionFunction()
        else:
            selection = TopkSelectionFunction()
        permutation = selection.select(probas_val, args.top_k)
        print("permutation", len(permutation))
        input_ids_TU, input_mask_TU, segment_ids_TU, label_ids_TU = sample_data(all_input_ids, all_input_mask, all_segment_ids, probas_val, permutation)
        print("input_ids_TU size:", len(input_ids_TU))
        logger.info("Pseudo label distribution = %s", collections.Counter(label_ids_TU))

        #print("labels_TU examples", label_ids_TU)
        # step 4: train TU model with Pseudo labels
        logger.info("***** Running train TU model with pseudo data *****")
        train_data_loader_TU = load_train_data(args, input_ids_TU, input_mask_TU, segment_ids_TU, label_ids_TU)
    
        model_TU = train(model_TU, args, n_gpu, train_data_loader_TU, device, args.num_train_epochs, logger)
        model_TU.to(device)

        logger.info("***** Evaluate TU model  *****")
        model_TU_accuracy = evaluate_model(model_TU, device, eval_data_loader, logger)

        # step 5: init student model with mix weights from TL and TU model
        logger.info("***** Init student model with weights from TL and TU model *****")
        # model_student = init_student_weights(model_TL, model_TU, model_student, args.alpha)
        model_student = create_student_model(args, cache_dir, num_labels, device, model_TL.state_dict(), model_TU.state_dict())
        model_student.to(device)
        print(model_student.state_dict())

        # mix train data and pseudo data to create the fine-tuning dataset
        train_examples = processor.get_train_examples(args.data_dir)
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer)
        input_ids_train = np.array([f.input_ids for f in train_features])
        input_mask_train = np.array([f.input_mask for f in train_features])
        segment_ids_train = np.array([f.segment_ids for f in train_features])
        label_ids_train = np.array([f.label_id for f in train_features])

        input_ids_ft = np.concatenate((input_ids_train, np.array(input_ids_TU)), axis=0)
        input_mask_ft = np.concatenate((input_mask_train, np.array(input_mask_TU)), axis=0)
        segment_ids_ft = np.concatenate((segment_ids_train, np.array(segment_ids_TU)), axis=0)
        label_ids_ft = np.concatenate((label_ids_train, np.array(label_ids_TU)), axis=0)

        p = np.random.permutation(len(input_ids_ft))
        input_ids_ft = input_ids_ft[p]
        input_mask_ft = input_mask_ft[p]
        segment_ids_ft = segment_ids_ft[p]
        label_ids_ft = label_ids_ft[p]

        fine_tune_dataloader = load_train_data(args, input_ids_ft, input_mask_ft, segment_ids_ft, label_ids_ft)

        if args.ft_true:
            # step 6: train student model with train data
            logger.info("***** Running train student model with train data *****")
            model_student = train(model_student, args, n_gpu, train_data_loader, device, args.num_student_train_epochs, logger)
            model_student.to(device)

        if args.ft_pseudo:
            # step 7: train student model with Pseudo labels
            logger.info("***** Running train student model with Pseudo data *****")
            model_student = train(model_student, args, n_gpu, train_data_loader_TU, device, args.num_student_train_epochs, logger)
            model_student.to(device)
        
        if args.ft_both:
            # step 8: train student model with both train and Pseudo data
            logger.info("***** Running train student model with both train and Pseudo data *****")
            model_student = train(model_student, args, n_gpu, fine_tune_dataloader, device, args.num_student_train_epochs, logger)
            model_student.to(device)
            
        
        logger.info("***** Evaluate student model  *****")
        model_student_accuracy = evaluate_model(model_student, device, eval_data_loader, logger)
    

        results = [model_TL_accuracy, model_TU_accuracy, model_student_accuracy]
        print(results)

        if args.push_message:
            api = "https://sc.ftqq.com/SCU47715T1085ec82936ebfe2723aaa3095bb53505ca315d2865a0.send"
            title = args.task_name
            if args.ft_true:
                title += " ft_true "
            if args.ft_pseudo:
                title += "ft_pseudo"
            content = ""
            content += "Params: alpha:{} \n".format(str(args.alpha)) 
            content += "model_TL: " + str(model_TL_accuracy) + "\n"
            content += "model_TU: " + str(model_TU_accuracy) + "\n"
            content += "model_student: " + str(model_student_accuracy) + "\n"
            data = {
                "text":title,
                "desp":content
            }
            requests.post(api, data=data)
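
Step 3 in Example #2 keeps the top-k most confident pseudo-labeled examples through `TopkSelectionFunction` (or its balanced variant when `--do_balance` is set). Those classes are not shown in these snippets; a plausible sketch of the plain top-k variant, assuming `probas_val` is a 2-D array of per-class probabilities, could be:

import numpy as np

class TopkSelectionFunction:
    # Hypothetical sketch, not the repository's actual implementation.
    @staticmethod
    def select(probas, top_k):
        probas = np.asarray(probas)
        confidence = probas.max(axis=1)   # probability of each row's predicted class
        order = np.argsort(-confidence)   # most confident rows first
        return order[:top_k]

`sample_data` then presumably gathers the selected rows and turns `probas_val.argmax(axis=1)` into their pseudo labels, and `BalanceTopkSelectionFunction` presumably also caps how many examples each predicted class may contribute, matching the `--do_balance` help text.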
Example #3
def TriTraining(model, args, device, n_gpu, epochs, processor, label_list,
                tokenizer):
    '''
        Tri-Training Process
    '''
    train_size = 300
    # initialize the labeled train set L
    train_examples = processor.get_train_examples(args.data_dir)
    train_features = convert_examples_to_features(train_examples, label_list,
                                                  args.max_seq_length,
                                                  tokenizer)

    input_ids_train = np.array([f.input_ids for f in train_features])
    input_mask_train = np.array([f.input_mask for f in train_features])
    segment_ids_train = np.array([f.segment_ids for f in train_features])
    label_ids_train = np.array([f.label_id for f in train_features])
    train_data_loader = load_train_data(args, input_ids_train,
                                        input_mask_train, segment_ids_train,
                                        label_ids_train)

    # initialize the unlabeled set U
    eval_examples = processor.get_dev_examples(args.data_dir)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length,
                                                 tokenizer)
    if train_size > len(eval_features):
        train_size = len(eval_features)
    unlabeled_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                       dtype=torch.long)[:train_size]
    unlabeled_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                        dtype=torch.long)[:train_size]
    unlabeled_segment_ids = torch.tensor(
        [f.segment_ids for f in eval_features], dtype=torch.long)[:train_size]
    unlabeled_label_ids = torch.tensor([f.label_id for f in eval_features],
                                       dtype=torch.long)[:train_size]
    eval_data = TensorDataset(unlabeled_input_ids, unlabeled_input_mask,
                              unlabeled_segment_ids, unlabeled_label_ids)
    eval_sampler = SequentialSampler(eval_data)
    eval_data_loader = DataLoader(eval_data,
                                  sampler=eval_sampler,
                                  batch_size=args.eval_batch_size)

    logger.info("***** Tri Training *****")
    cnt = 0
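    # cnt counts training rounds; each round, cnt % 3 decides which classifier
    # head is retrained, using pseudo labels on which the other two heads agree.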
    for cnt in range(1, 3 * epochs + 1):
        trainset_index = []
        model.eval()
        predict_results_j = []
        predict_results_k = []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_data_loader):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                logits1, logits2, logits3 = model(input_ids, segment_ids,
                                                  input_mask)
                if cnt % 3 == 1:
                    logits_j = logits2
                    logits_k = logits3
                elif cnt % 3 == 2:
                    logits_j = logits1
                    logits_k = logits3
                elif cnt % 3 == 0:
                    logits_j = logits1
                    logits_k = logits2
            logits_j = logits_j.detach().cpu().numpy()
            logits_k = logits_k.detach().cpu().numpy()
            predict_results_j.extend(np.argmax(logits_j, axis=1))
            predict_results_k.extend(np.argmax(logits_k, axis=1))

        # keep examples where the two peer classifiers agree: p_j(x) == p_k(x)
        # print("predict_result_j", predict_results_j)
        # print("predict_result_k", predict_results_k)
        for i in range(len(predict_results_j)):
            if predict_results_j[i] == predict_results_k[i]:
                trainset_index.append(i)

        # build the index array of agreed-upon examples
        permutation = np.array(trainset_index)
        print("permutation size ", len(permutation))
        if len(permutation) == 0:
            train_data_loader = load_train_data(args, input_ids_train,
                                                input_mask_train,
                                                segment_ids_train,
                                                label_ids_train)
        else:
            if cnt % 3 == 0:
                input_ids_train_new = unlabeled_input_ids.numpy()[permutation]
                input_mask_train_new = unlabeled_input_mask.numpy()[permutation]
                segment_ids_train_new = unlabeled_segment_ids.numpy()[permutation]
                label_ids_train_new = np.array(predict_results_j)[permutation]
                train_data_loader = load_train_data(args, input_ids_train_new,
                                                    input_mask_train_new,
                                                    segment_ids_train_new,
                                                    label_ids_train_new)
                model = train_model(model, args, n_gpu, train_data_loader,
                                    device, args.num_student_train_epochs,
                                    logger, 3)
            else:
                # print("input_ids_train shape:", input_ids_train.shape)
                # print("unlabeled_input_ids shape", unlabeled_input_ids[permutation].shape)
                input_ids_train_new = np.concatenate(
                    (input_ids_train, unlabeled_input_ids[permutation]),
                    axis=0)
                input_mask_train_new = np.concatenate(
                    (input_mask_train, unlabeled_input_mask[permutation]),
                    axis=0)
                segment_ids_train_new = np.concatenate(
                    (segment_ids_train, unlabeled_segment_ids[permutation]),
                    axis=0)
                label_ids_train_new = np.concatenate(
                    (label_ids_train,
                     np.array(predict_results_j)[permutation]),
                    axis=0)
                train_data_loader = load_train_data(args, input_ids_train_new,
                                                    input_mask_train_new,
                                                    segment_ids_train_new,
                                                    label_ids_train_new)
                model = train_model(model, args, n_gpu, train_data_loader,
                                    device, args.num_student_train_epochs,
                                    logger, cnt % 3)

    return model
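
All three examples delegate optimization to a `train`/`train_model` helper that is not included in the snippets. A minimal sketch with the signature used in Example #2 (the real code likely uses `BertAdam` with `--warmup_proportion`; a plain `Adam` optimizer is substituted here, and everything below is an assumption for illustration):

import torch

def train(model, args, n_gpu, train_data_loader, device, num_epochs, logger):
    # Hypothetical sketch of the fine-tuning loop; not the repository's actual code.
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    model.train()
    for epoch in range(int(num_epochs)):
        for step, batch in enumerate(train_data_loader):
            input_ids, input_mask, segment_ids, label_ids = (t.to(device) for t in batch)
            # pytorch-pretrained-bert classifiers return the loss when labels are passed
            loss = model(input_ids, segment_ids, input_mask, label_ids)
            if n_gpu > 1:
                loss = loss.mean()
            (loss / args.gradient_accumulation_steps).backward()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
        logger.info("epoch %d finished, last loss = %.4f", epoch + 1, loss.item())
    return model

Note that `train_model` in Example #3 takes an extra head index (`cnt % 3`) so that only the selected classifier head of the tri-model is updated; that variant is not sketched here.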