def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model",
                        default=None,
                        type=str,
                        required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    # Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                        "Sequences longer than this will be truncated, and sequences shorter \n"
                        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--do_balance",
                        action="store_true",
                        help="Set this flag if you want to use the balanced choose function")
    parser.add_argument("--push_message",
                        action="store_true",
                        help="Set this flag if you want to push a message to your phone")
    parser.add_argument("--top_k",
                        default=500,
                        type=int,
                        help="Number of top-k pseudo labels the Teacher will choose for the Student to learn")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float,
                        default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                        "0 (default value): dynamic loss scaling.\n"
                        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--num_student_train_epochs",
                        type=float,
                        default=3.0,
                        help="Total number of student model training epochs to perform.")
    parser.add_argument("--threshold",
                        type=float,
                        default=0.05,
                        help="Threshold for the improvement of the model")
    parser.add_argument("--selection_function",
                        type=str,
                        default="random",
                        help="Choose the selection function")
    parser.add_argument("--alpha",
                        type=float,
                        default=0.33,
                        help="The weight of the TL model in the final model")
    parser.add_argument("--ft_true",
                        action="store_true",
                        help="Fine-tune the student model with true data")
    parser.add_argument("--ft_pseudo",
                        action="store_true",
                        help="Fine-tune the student model with pseudo data")
    parser.add_argument("--ft_both",
                        action="store_true",
                        help="Fine-tune the student model with both true and pseudo data")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "aus": OOCLAUSProcessor,
        "dbpedia": DBpediaProcessor,
        "trec": TrecProcessor,
        "yelp": YelpProcessor,
    }

    num_labels_task = {
        "aus": 33,
        "dbpedia": len(DBpediaProcessor().get_labels()),
        "trec": len(TrecProcessor().get_labels()),
        "yelp": len(YelpProcessor().get_labels()),
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    logger.info("***** Build tri model *****")
    # Prepare model
    # (an illustrative sketch of a three-head model appears after this function)
    cache_dir = args.cache_dir
    model = create_tri_model(args, cache_dir, num_labels, device)

    if args.do_train:
        # step 0: load train examples
        logger.info("Cook training and dev data for teacher model")
        train_examples = processor.get_train_examples(args.data_dir)
        train_features = convert_examples_to_features(train_examples, label_list,
                                                      args.max_seq_length, tokenizer)
        logger.info("  Num Training Examples = %d", len(train_examples))
        logger.info("  Train Batch Size = %d", args.train_batch_size)
        input_ids_train = np.array([f.input_ids for f in train_features])
        input_mask_train = np.array([f.input_mask for f in train_features])
        segment_ids_train = np.array([f.segment_ids for f in train_features])
        label_ids_train = np.array([f.label_id for f in train_features])

        train_data_loader = load_train_data(args, input_ids_train, input_mask_train,
                                            segment_ids_train, label_ids_train)

        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length, tokenizer)
        logger.info("  Num Eval Examples = %d", len(eval_examples))
        logger.info("  Eval Batch Size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_data_loader = DataLoader(eval_data, sampler=eval_sampler,
                                      batch_size=args.eval_batch_size)

        # step 1: train the tri model with labeled data
        logger.info("***** Running train tri model with labeled data *****")
        model = init_tri_model(model, args, n_gpu, train_data_loader, device,
                               args.num_train_epochs, logger)
        acc1 = evaluate_model(model, device, eval_data_loader, logger)

        # step 2: tri-training
        model = TriTraining(model, args, device, n_gpu, 3, processor, label_list, tokenizer)

        # step 3: evaluate model
        acc = evaluate_model(model, device, eval_data_loader, logger)
        print(acc1, acc)
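
# create_tri_model / init_tri_model / evaluate_model / TriTraining are helpers
# defined elsewhere in this repo. As a rough illustration only (not the repo's
# actual implementation), a model that returns three sets of logits per forward
# pass, as TriTraining below expects, could look like the sketch here. The class
# name TriBertSketch, the pytorch_pretrained_bert import, and the hidden size of
# 768 are assumptions.
from pytorch_pretrained_bert import BertModel


class TriBertSketch(torch.nn.Module):
    """Shared BERT encoder with three independent classification heads."""

    def __init__(self, bert_model_name, num_labels, hidden_size=768):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        # three independent heads over the shared pooled [CLS] representation
        self.classifiers = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_size, num_labels) for _ in range(3)])

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        _, pooled = self.bert(input_ids, token_type_ids, attention_mask,
                              output_all_encoded_layers=False)
        logits1, logits2, logits3 = [clf(pooled) for clf in self.classifiers]
        return logits1, logits2, logits3

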
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name", default=None, type=str, required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    # Other parameters
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--do_balance", action="store_true",
                        help="Set this flag if you want to use the balanced choose function")
    parser.add_argument("--push_message", action="store_true",
                        help="Set this flag if you want to push a message to your phone")
    parser.add_argument("--top_k", default=500, type=int,
                        help="Number of top-k pseudo labels the Teacher will choose for the Student to learn")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--num_student_train_epochs", type=float, default=3.0,
                        help="Total number of student model training epochs to perform.")
    parser.add_argument("--threshold", type=float, default=0.05,
                        help="Threshold for the improvement of the model")
    parser.add_argument("--selection_function", type=str, default="random",
                        help="Choose the selection function")
    parser.add_argument("--alpha", type=float, default=0.33,
                        help="The weight of the TL model in the final model")
    parser.add_argument("--ft_true", action="store_true",
                        help="Fine-tune the student model with true data")
    parser.add_argument("--ft_pseudo", action="store_true",
                        help="Fine-tune the student model with pseudo data")
    parser.add_argument("--ft_both", action="store_true",
                        help="Fine-tune the student model with both true and pseudo data")
    parser.add_argument('--server_ip', type=str, default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "aus": OOCLAUSProcessor,
        "dbpedia": DBpediaProcessor,
        "trec": TrecProcessor,
        "yelp": YelpProcessor,
    }

    num_labels_task = {
        "aus": 33,
        "dbpedia": len(DBpediaProcessor().get_labels()),
        "trec": len(TrecProcessor().get_labels()),
        "yelp": len(YelpProcessor().get_labels()),
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    logger.info("***** Build teacher(label) model *****")
    # Prepare model
    cache_dir = args.cache_dir
    model_TL = create_model(args, cache_dir, num_labels, device)

    logger.info("***** Build teacher(unlabel) model *****")
    cache_dir = args.cache_dir
    model_TU = create_model(args, cache_dir, num_labels, device)

    logger.info("***** Build student model *****")
    model_student = create_model(args, cache_dir, num_labels, device)
    logger.info("***** Finish TL, TU and Student model building *****")

    if args.do_train:
        # step 0: load train examples
        logger.info("Cook training and dev data for teacher model")
        train_examples = processor.get_train_examples(args.data_dir)
        train_features = convert_examples_to_features(train_examples, label_list,
                                                      args.max_seq_length, tokenizer)
        logger.info("  Num Training Examples = %d", len(train_examples))
        logger.info("  Train Batch Size = %d", args.train_batch_size)
        input_ids_train = np.array([f.input_ids for f in train_features])
        input_mask_train = np.array([f.input_mask for f in train_features])
        segment_ids_train = np.array([f.segment_ids for f in train_features])
        label_ids_train = np.array([f.label_id for f in train_features])

        train_data_loader = load_train_data(args, input_ids_train, input_mask_train,
                                            segment_ids_train, label_ids_train)

        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length, tokenizer)
        logger.info("  Num Eval Examples = %d", len(eval_examples))
        logger.info("  Eval Batch Size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_data_loader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
        print(len(all_input_ids))

        # step 1: train the TL model with labeled data
        logger.info("***** Running train TL model with labeled data *****")
        model_TL = train(model_TL, args, n_gpu, train_data_loader, device,
                         args.num_train_epochs, logger)
        model_TL.to(device)

        logger.info("***** Evaluate TL model *****")
        model_TL_accuracy = evaluate_model(model_TL, device, eval_data_loader, logger)

        # step 2: predict the val set
        logger.info("***** Produce pseudo labels from TL model *****")
        probas_val = predict_model(model_TL, args, eval_data_loader, device)
        print(len(probas_val))

        # step 3: choose top-k data_val and reset train_data
        if args.do_balance:
            selection = BalanceTopkSelectionFunction()
        else:
            selection = TopkSelectionFunction()
        permutation = selection.select(probas_val, args.top_k)
        print("permutation", len(permutation))
        input_ids_TU, input_mask_TU, segment_ids_TU, label_ids_TU = sample_data(
            all_input_ids, all_input_mask, all_segment_ids, probas_val, permutation)
        print("input_ids_TU size:", len(input_ids_TU))
        logger.info("Pseudo label distribution = %s", collections.Counter(label_ids_TU))
        # print("labels_TU examples", label_ids_TU)

        # step 4: train TU model with pseudo labels
        logger.info("***** Running train TU model with pseudo data *****")
        train_data_loader_TU = load_train_data(args, input_ids_TU, input_mask_TU,
                                               segment_ids_TU, label_ids_TU)
        model_TU = train(model_TU, args, n_gpu, train_data_loader_TU, device,
                         args.num_train_epochs, logger)
        model_TU.to(device)

        logger.info("***** Evaluate TU model *****")
        model_TU_accuracy = evaluate_model(model_TU, device, eval_data_loader, logger)

        # step 5: init student model with mixed weights from the TL and TU models
        # (an illustrative weight-mixing sketch appears after this function)
        logger.info("***** Init student model with weights from TL and TU model *****")
        # model_student = init_student_weights(model_TL, model_TU, model_student, args.alpha)
        model_student = create_student_model(args, cache_dir, num_labels, device,
                                             model_TL.state_dict(), model_TU.state_dict())
        model_student.to(device)
        print(model_student.state_dict())

        # Mix train data and pseudo data to create the fine-tune dataset
        train_examples = processor.get_train_examples(args.data_dir)
        train_features = convert_examples_to_features(train_examples, label_list,
                                                      args.max_seq_length, tokenizer)
        input_ids_train = np.array([f.input_ids for f in train_features])
        input_mask_train = np.array([f.input_mask for f in train_features])
        segment_ids_train = np.array([f.segment_ids for f in train_features])
        label_ids_train = np.array([f.label_id for f in train_features])

        input_ids_ft = np.concatenate((input_ids_train, np.array(input_ids_TU)), axis=0)
        input_mask_ft = np.concatenate((input_mask_train, np.array(input_mask_TU)), axis=0)
        segment_ids_ft = np.concatenate((segment_ids_train, np.array(segment_ids_TU)), axis=0)
        label_ids_ft = np.concatenate((label_ids_train, np.array(label_ids_TU)), axis=0)

        p = np.random.permutation(len(input_ids_ft))
        input_ids_ft = input_ids_ft[p]
        input_mask_ft = input_mask_ft[p]
        segment_ids_ft = segment_ids_ft[p]
        label_ids_ft = label_ids_ft[p]

        fine_tune_dataloader = load_train_data(args, input_ids_ft, input_mask_ft,
                                               segment_ids_ft, label_ids_ft)

        if args.ft_true:
            # step 6: train student model with train data
            logger.info("***** Running train student model with train data *****")
            model_student = train(model_student, args, n_gpu, train_data_loader, device,
                                  args.num_student_train_epochs, logger)
            model_student.to(device)

        if args.ft_pseudo:
            # step 7: train student model with pseudo labels
            logger.info("***** Running train student model with pseudo data *****")
            model_student = train(model_student, args, n_gpu, train_data_loader_TU, device,
                                  args.num_student_train_epochs, logger)
            model_student.to(device)

        if args.ft_both:
            # step 8: train student model with both train and pseudo data
            logger.info("***** Running train student model with both train and pseudo data *****")
            model_student = train(model_student, args, n_gpu, fine_tune_dataloader, device,
                                  args.num_student_train_epochs, logger)
            model_student.to(device)

        logger.info("***** Evaluate student model *****")
        model_student_accuracy = evaluate_model(model_student, device, eval_data_loader, logger)

        results = [model_TL_accuracy, model_TU_accuracy, model_student_accuracy]
        print(results)

        if args.push_message:
            api = "https://sc.ftqq.com/SCU47715T1085ec82936ebfe2723aaa3095bb53505ca315d2865a0.send"
            title = args.task_name
            if args.ft_true:
                title += " ft_true "
            if args.ft_pseudo:
                title += "ft_pseudo"
            content = ""
            content += "Params: alpha:{} \n".format(str(args.alpha))
            content += "model_TL: " + str(model_TL_accuracy) + "\n"
            content += "model_TU: " + str(model_TU_accuracy) + "\n"
            content += "model_student: " + str(model_student_accuracy) + "\n"
            data = {"text": title, "desp": content}
            requests.post(api, data=data)
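
# create_student_model is defined elsewhere in this repo; judging from --alpha
# ("the weight of the TL model in the final model"), the student is expected to
# start from a parameter-wise mix of the two teachers. The helper below is only
# a minimal sketch of that idea under this assumption; mix_state_dicts is not
# part of the repo.
def mix_state_dicts(tl_state, tu_state, alpha):
    """Interpolate two compatible state dicts: alpha * TL + (1 - alpha) * TU."""
    mixed = {}
    for key, tl_param in tl_state.items():
        if torch.is_floating_point(tl_param):
            mixed[key] = alpha * tl_param + (1.0 - alpha) * tu_state[key]
        else:
            # non-float buffers (e.g. integer indices) are copied from the TL model
            mixed[key] = tl_param.clone()
    return mixed

# Hypothetical usage:
#   model_student.load_state_dict(
#       mix_state_dicts(model_TL.state_dict(), model_TU.state_dict(), args.alpha))

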
def TriTraining(model, args, device, n_gpu, epochs, processor, label_list, tokenizer):
    '''Tri-Training process.'''
    train_size = 300

    # initialize the labeled train set L
    train_examples = processor.get_train_examples(args.data_dir)
    train_features = convert_examples_to_features(train_examples, label_list,
                                                  args.max_seq_length, tokenizer)
    input_ids_train = np.array([f.input_ids for f in train_features])
    input_mask_train = np.array([f.input_mask for f in train_features])
    segment_ids_train = np.array([f.segment_ids for f in train_features])
    label_ids_train = np.array([f.label_id for f in train_features])
    train_data_loader = load_train_data(args, input_ids_train, input_mask_train,
                                        segment_ids_train, label_ids_train)

    # initialize the unlabeled set U
    eval_examples = processor.get_dev_examples(args.data_dir)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length, tokenizer)
    if train_size > len(eval_features):
        train_size = len(eval_features)
    unlabeled_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                       dtype=torch.long)[:train_size]
    unlabeled_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                        dtype=torch.long)[:train_size]
    unlabeled_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                         dtype=torch.long)[:train_size]
    unlabeled_label_ids = torch.tensor([f.label_id for f in eval_features],
                                       dtype=torch.long)[:train_size]
    eval_data = TensorDataset(unlabeled_input_ids, unlabeled_input_mask,
                              unlabeled_segment_ids, unlabeled_label_ids)
    eval_sampler = SequentialSampler(eval_data)
    eval_data_loader = DataLoader(eval_data, sampler=eval_sampler,
                                  batch_size=args.eval_batch_size)

    logger.info("***** Tri Training *****")
    for cnt in range(1, 3 * epochs + 1):
        trainset_index = []
        model.eval()
        predict_results_j = []
        predict_results_k = []
        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_data_loader):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)
            with torch.no_grad():
                logits1, logits2, logits3 = model(input_ids, segment_ids, input_mask)
            # Round-robin: in each round the two classifiers j and k label data
            # for the remaining classifier i = cnt % 3.
            if cnt % 3 == 1:
                logits_j = logits2
                logits_k = logits3
            elif cnt % 3 == 2:
                logits_j = logits1
                logits_k = logits3
            else:  # cnt % 3 == 0
                logits_j = logits1
                logits_k = logits2
            logits_j = logits_j.detach().cpu().numpy()
            logits_k = logits_k.detach().cpu().numpy()
            predict_results_j.extend(np.argmax(logits_j, axis=1))
            predict_results_k.extend(np.argmax(logits_k, axis=1))

        # choose examples where p_j(x) == p_k(x)
        # print("predict_result_j", predict_results_j)
        # print("predict_result_k", predict_results_k)
        for i in range(len(predict_results_j)):
            if predict_results_j[i] == predict_results_k[i]:
                trainset_index.append(i)

        # build the new training set from the agreed-on indices
        permutation = np.array(trainset_index)
        print("permutation size ", len(permutation))
        if len(permutation) == 0:
            train_data_loader = load_train_data(args, input_ids_train, input_mask_train,
                                                segment_ids_train, label_ids_train)
        else:
            if cnt % 3 == 0:
                input_ids_train_new = unlabeled_input_ids.numpy()[permutation]
                input_mask_train_new = unlabeled_input_mask.numpy()[permutation]
                segment_ids_train_new = unlabeled_segment_ids.numpy()[permutation]
                label_ids_train_new = np.array(predict_results_j)[permutation]
                train_data_loader = load_train_data(args, input_ids_train_new,
                                                    input_mask_train_new,
                                                    segment_ids_train_new,
                                                    label_ids_train_new)
                model = train_model(model, args, n_gpu, train_data_loader, device,
                                    args.num_student_train_epochs, logger, 3)
            else:
                # print("input_ids_train shape:", input_ids_train.shape)
                # print("unlabeled_input_ids shape", unlabeled_input_ids[permutation].shape)
                input_ids_train_new = np.concatenate(
                    (input_ids_train, unlabeled_input_ids[permutation]), axis=0)
                input_mask_train_new = np.concatenate(
                    (input_mask_train, unlabeled_input_mask[permutation]), axis=0)
                segment_ids_train_new = np.concatenate(
                    (segment_ids_train, unlabeled_segment_ids[permutation]), axis=0)
                label_ids_train_new = np.concatenate(
                    (label_ids_train, np.array(predict_results_j)[permutation]), axis=0)
                train_data_loader = load_train_data(args, input_ids_train_new,
                                                    input_mask_train_new,
                                                    segment_ids_train_new,
                                                    label_ids_train_new)
                model = train_model(model, args, n_gpu, train_data_loader, device,
                                    args.num_student_train_epochs, logger, cnt % 3)
    return model