def main():
    parser = get_argument_parser()
    args = parser.parse_args()

    check_early_exit_warning(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory {} already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    # Prepare Summary writer
    if torch.distributed.get_rank() == 0 and args.job_name is not None:
        args.summary_writer = get_summary_writer(name=args.job_name,
                                                 base=args.output_dir)
    else:
        args.summary_writer = None

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = read_squad_examples(input_file=args.train_file,
                                             is_training=True)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    # model = BertForQuestionAnswering.from_pretrained(args.bert_model,
    #     cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))

    # Support for word embedding padding checkpoints
    # Prepare model
    bert_model_config = {
        "vocab_size_or_config_json_file": 119547,
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02
    }
    bert_config = BertConfig(**bert_model_config)
    bert_config.vocab_size = len(tokenizer.vocab)
    # Padding for divisibility by 8
    if bert_config.vocab_size % 8 != 0:
        bert_config.vocab_size += 8 - (bert_config.vocab_size % 8)

    model = BertForQuestionAnswering(bert_config, args)
    print("VOCAB SIZE:", bert_config.vocab_size)
    if args.model_file != "0":
        logger.info(f"Loading Pretrained Bert Encoder from: {args.model_file}")
        checkpoint_state_dict = torch.load(args.model_file,
                                           map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint_state_dict['model_state_dict'],
                              strict=False)
        # bert_state_dict = torch.load(args.model_file)
        # model.bert.load_state_dict(bert_state_dict, strict=False)
        logger.info(f"Pretrained Bert Encoder Loaded from: {args.model_file}")

    model.to(device)
    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # Hack to remove the pooler, which is not used and would otherwise
    # produce None gradients that break apex.
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False)
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              keep_batchnorm_fp32=False,
                                              loss_scale="dynamic")
        else:
            raise NotImplementedError(
                "Only dynamic loss scaling is supported in the baseline; please set loss_scale=0."
            )
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    global_step = 0
    if args.do_train:
        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}'.format(
            list(filter(None, args.bert_model.split('/'))).pop(),
            str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length))
        train_features = None
        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except:
            train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s",
                            cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_start_positions = torch.tensor(
            [f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor(
            [f.end_position for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_start_positions,
                                   all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        gradClipper = GradientClipper(max_grad_norm=1.0)

        model.train()
        ema_loss = 0.
        sample_count = 0
        num_epoch = 0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            num_epoch += 1
            epoch_step = 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration", smoothing=0)):
                if n_gpu == 1:
                    batch = tuple(
                        t.to(device)
                        for t in batch)  # multi-gpu does scattering itself
                input_ids, input_mask, segment_ids, start_positions, end_positions = batch
                loss = model(input_ids, segment_ids, input_mask,
                             start_positions, end_positions)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                ema_loss = args.loss_plot_alpha * ema_loss + (
                    1 - args.loss_plot_alpha) * loss.item()

                if args.local_rank != -1:
                    # Only all-reduce gradients on the micro-step that applies
                    # the optimizer update.
                    model.disable_allreduce()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        model.enable_allreduce()

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # gradient clipping
                gradClipper.step(amp.master_params(optimizer))

                sample_count += (args.train_batch_size *
                                 torch.distributed.get_world_size())

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                    epoch_step += 1

                    if torch.distributed.get_rank() == 0 and args.summary_writer:
                        summary_events = [
                            (f'Train/Steps/lr', lr_this_step, global_step),
                            (f'Train/Samples/train_loss', loss, sample_count),
                            (f'Train/Samples/lr', lr_this_step, sample_count),
                            (f'Train/Samples/train_ema_loss', ema_loss,
                             sample_count)
                        ]

                        if args.fp16 and hasattr(optimizer, 'cur_scale'):
                            summary_events.append(
                                (f'Train/Samples/scale', optimizer.cur_scale,
                                 sample_count))
                        write_summary_events(args.summary_writer,
                                             summary_events)

                    if torch.distributed.get_rank() == 0 and (
                            step + 1) % args.print_steps == 0:
                        logger.info(
                            f"bert_squad_progress: step={global_step} lr={lr_this_step} loss={ema_loss}"
                        )

                if is_time_to_exit(args=args,
                                   epoch_steps=epoch_step,
                                   global_steps=global_step):
                    logger.info(
                        f'Warning: Early epoch termination due to max steps limit, epoch step = {epoch_step}, global step = {global_step}, epoch = {num_epoch}'
                    )
                    break

    # Save a trained model
    # model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
    # output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    # if args.do_train:
    #     torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned
    # model_state_dict = torch.load(output_model_file)
    # model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
    # model.to(device)

    if args.do_predict and (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
        eval_examples = read_squad_examples(input_file=args.predict_file,
                                            is_training=False)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0),
                                         dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_example_index)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(
                eval_dataloader, desc="Evaluating"):
            if len(all_results) % 1000 == 0:
                logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(
                    input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(
                    RawResult(unique_id=unique_id,
                              start_logits=start_logits,
                              end_logits=end_logits))

        output_prediction_file = os.path.join(args.output_dir,
                                              "predictions.json")
        output_nbest_file = os.path.join(args.output_dir,
                                         "nbest_predictions.json")
        write_predictions(eval_examples, eval_features, all_results,
                          args.n_best_size, args.max_answer_length,
                          args.do_lower_case, output_prediction_file,
                          output_nbest_file, args.verbose_logging)
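
# Note: the baseline training loop above applies gradient clipping through a
# GradientClipper helper that is imported from elsewhere in the repo. The
# function below (hypothetical name _gradient_clip_sketch) is only an assumed,
# illustrative stand-in for the kind of global-norm clipping it performs, not
# the repo's exact implementation.
def _gradient_clip_sketch(parameters, max_grad_norm=1.0):
    # Clip the global gradient norm of the given parameters in place.
    torch.nn.utils.clip_grad_norm_(parameters, max_grad_norm)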

def main():
    parser = get_argument_parser()

    deepspeed.init_distributed(dist_backend='nccl')

    # Include DeepSpeed configuration arguments
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args()
    args.local_rank = int(os.environ['LOCAL_RANK'])

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = read_squad_examples(input_file=args.train_file,
                                             is_training=True)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    # model = BertForQuestionAnswering.from_pretrained(args.bert_model,
    #     cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))

    # Support for word embedding padding checkpoints
    # Prepare model
    bert_model_config = {
        "vocab_size_or_config_json_file": 119547,
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "hidden_act": "gelu",
        "hidden_dropout_prob": args.dropout,
        "attention_probs_dropout_prob": args.dropout,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02
    }

    if args.ckpt_type == "DS":
        if args.preln:
            bert_config = BertConfigPreLN(**bert_model_config)
        else:
            bert_config = BertConfig(**bert_model_config)
    else:
        # Models from TensorFlow and HuggingFace are post-LN.
        if args.preln:
            raise ValueError(
                "Should NOT use --preln if the loading checkpoint doesn't use pre-layer-norm."
            )

        # Use the original BERT config when loading from a non-DeepSpeed checkpoint.
        if args.origin_bert_config_file is None:
            raise ValueError(
                "--origin_bert_config_file is required for loading non-DeepSpeed checkpoint."
            )
        bert_config = BertConfig.from_json_file(args.origin_bert_config_file)
        if bert_config.vocab_size != len(tokenizer.vocab):
            raise ValueError(
                "vocab size of the original checkpoint does not match the tokenizer.")

    bert_config.vocab_size = len(tokenizer.vocab)
    # Padding for divisibility by 8
    if bert_config.vocab_size % 8 != 0:
        vocab_diff = 8 - (bert_config.vocab_size % 8)
        bert_config.vocab_size += vocab_diff

    if args.preln:
        model = BertForQuestionAnsweringPreLN(bert_config, args)
    else:
        model = BertForQuestionAnswering(bert_config, args)
    print("VOCAB SIZE:", bert_config.vocab_size)
    if args.model_file != "0":
        logger.info(f"Loading Pretrained Bert Encoder from: {args.model_file}")

        if args.ckpt_type == "DS":
            checkpoint_state_dict = torch.load(args.model_file,
                                               map_location=torch.device("cpu"))
            if 'module' in checkpoint_state_dict:
                logger.info('Loading DeepSpeed v2.0 style checkpoint')
                model.load_state_dict(checkpoint_state_dict['module'],
                                      strict=False)
            elif 'model_state_dict' in checkpoint_state_dict:
                model.load_state_dict(checkpoint_state_dict['model_state_dict'],
                                      strict=False)
            else:
                raise ValueError("Unable to find model state in checkpoint")
        else:
            from convert_bert_ckpt_to_deepspeed import convert_ckpt_to_deepspeed
            convert_ckpt_to_deepspeed(model, args.ckpt_type, args.model_file,
                                      vocab_diff,
                                      args.deepspeed_transformer_kernel)

        logger.info(f"Pretrained Bert Encoder Loaded from: {args.model_file}")

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # Hack to remove the pooler, which is not used and would otherwise
    # produce None gradients that break apex.
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    if args.deepspeed_transformer_kernel:
        no_decay = no_decay + [
            'attn_nw', 'attn_nb', 'norm_w', 'norm_b', 'attn_qkvb', 'attn_ob',
            'inter_b', 'output_b'
        ]

    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    model, optimizer, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=optimizer_grouped_parameters,
        dist_init_required=True)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        # torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))

    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        os.makedirs(args.output_dir, exist_ok=True)

    # Prepare Summary writer
    if torch.distributed.get_rank() == 0 and args.job_name is not None:
        args.summary_writer = get_summary_writer(name=args.job_name,
                                                 base=args.output_dir)
    else:
        args.summary_writer = None

    logger.info("propagate deepspeed-config settings to client settings")
    args.train_batch_size = model.train_micro_batch_size_per_gpu()
    args.gradient_accumulation_steps = model.gradient_accumulation_steps()
    args.fp16 = model.fp16_enabled()
    args.print_steps = model.steps_per_print()
    args.learning_rate = model.get_lr()[0]
    args.wall_clock_breakdown = model.wall_clock_breakdown()

    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()

    global_step = 0
    if args.do_train:
        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}'.format(
            list(filter(None, args.bert_model.split('/'))).pop(),
            str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length))
        train_features = None
        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except:
            train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s",
                            cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_start_positions = torch.tensor(
            [f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor(
            [f.end_position for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_start_positions,
                                   all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        ema_loss = 0.
        sample_count = 0
        num_epoch = 0
        global all_step_time
        ave_rounds = 20
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            num_epoch += 1
            epoch_step = 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration", smoothing=0)):
                start_time = time.time()
                bs_size = batch[0].size()[0]
                if n_gpu == 1:
                    batch = tuple(
                        t.to(device)
                        for t in batch)  # multi-gpu does scattering itself
                input_ids, input_mask, segment_ids, start_positions, end_positions = batch

                loss = model(input_ids, segment_ids, input_mask,
                             start_positions, end_positions)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                ema_loss = args.loss_plot_alpha * ema_loss + (
                    1 - args.loss_plot_alpha) * loss.item()

                model.backward(loss)

                loss_item = loss.item() * args.gradient_accumulation_steps
                loss = None

                sample_count += (args.train_batch_size *
                                 torch.distributed.get_world_size())

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step

                    model.step()
                    global_step += 1
                    epoch_step += 1

                    if torch.distributed.get_rank() == 0 and args.summary_writer:
                        summary_events = [
                            (f'Train/Steps/lr', lr_this_step, global_step),
                            (f'Train/Samples/train_loss', loss_item,
                             sample_count),
                            (f'Train/Samples/lr', lr_this_step, sample_count),
                            (f'Train/Samples/train_ema_loss', ema_loss,
                             sample_count)
                        ]

                        if args.fp16 and hasattr(optimizer, 'cur_scale'):
                            summary_events.append(
                                (f'Train/Samples/scale', optimizer.cur_scale,
                                 sample_count))
                        write_summary_events(args.summary_writer,
                                             summary_events)
                        args.summary_writer.flush()

                    if torch.distributed.get_rank() == 0 and (
                            step + 1) % args.print_steps == 0:
                        logger.info(
                            f"bert_squad_progress: step={global_step} lr={lr_this_step} loss={ema_loss}"
                        )
                else:
                    # Accumulation micro-step: let the engine count it without
                    # applying an optimizer update.
                    model.step()

                if is_time_to_exit(args=args,
                                   epoch_steps=epoch_step,
                                   global_steps=global_step):
                    logger.info(
                        f'Warning: Early epoch termination due to max steps limit, epoch step = {epoch_step}, global step = {global_step}, epoch = {num_epoch}'
                    )
                    break

                one_step_time = time.time() - start_time
                all_step_time += one_step_time
                if (step + 1) % ave_rounds == 0 and torch.distributed.get_rank() == 0:
                    print(
                        ' At step {}, averaged throughput for {} rounds is: {} Samples/s'
                        .format(
                            step, ave_rounds,
                            bs_size * ave_rounds *
                            torch.distributed.get_world_size() /
                            all_step_time),
                        flush=True)
                    all_step_time = 0.0

    # Save a trained model
    # model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
    # output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    # if args.do_train:
    #     torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned
    # model_state_dict = torch.load(output_model_file)
    # model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
    # model.to(device)

    if args.do_predict and (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
        eval_examples = read_squad_examples(input_file=args.predict_file,
                                            is_training=False)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0),
                                         dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_example_index)

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(
                eval_dataloader, desc="Evaluating"):
            if len(all_results) % 1000 == 0:
                logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(
                    input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(
                    RawResult(unique_id=unique_id,
                              start_logits=start_logits,
                              end_logits=end_logits))

        output_prediction_file = os.path.join(args.output_dir,
                                              "predictions.json")
        output_nbest_file = os.path.join(args.output_dir,
                                         "nbest_predictions.json")
        write_predictions(eval_examples, eval_features, all_results,
                          args.n_best_size, args.max_answer_length,
                          args.do_lower_case, output_prediction_file,
                          output_nbest_file, args.verbose_logging)
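
# Note: both training loops rescale the base learning rate with warmup_linear,
# which is imported from elsewhere in the repo. The function below
# (hypothetical name _warmup_linear_sketch) is a minimal sketch of the classic
# BERT schedule it is assumed to follow: linear warmup to the peak rate, then
# linear decay toward zero over t_total steps.
def _warmup_linear_sketch(x, warmup=0.002):
    # x is the fraction of total optimizer steps completed (global_step / t_total).
    if x < warmup:
        return x / warmup
    return max(1.0 - x, 0.0)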