def main(): parser = argparse.ArgumentParser() ## Required parameters parser.add_argument('--mode', type=str, default='train') parser.add_argument('--pause', type=int, default=0) parser.add_argument('--iteration', type=str, default='1') parser.add_argument('--fs', type=str, default='local', help='must be `local`. Do not change.') # Data paths parser.add_argument('--data_dir', default='data/', type=str) parser.add_argument("--train_file", default='train-v1.1.json', type=str, help="SQuAD json for training. E.g., train-v1.1.json") parser.add_argument("--predict_file", default='dev-v1.1.json', type=str, help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") parser.add_argument('--gt_file', default='dev-v1.1.json', type=str, help='ground truth file needed for evaluation.') # Metadata paths parser.add_argument('--metadata_dir', default='metadata/', type=str) parser.add_argument("--vocab_file", default='vocab.txt', type=str, help="The vocabulary file that the BERT model was trained on.") parser.add_argument("--bert_model_option", default='large_uncased', type=str, help="model architecture option. [large_uncased] or [base_uncased]") parser.add_argument("--bert_config_file", default='bert_config.json', type=str, help="The config json file corresponding to the pre-trained BERT model. " "This specifies the model architecture.") parser.add_argument("--init_checkpoint", default='pytorch_model.bin', type=str, help="Initial checkpoint (usually from a pre-trained BERT model).") # Output and load paths parser.add_argument("--output_dir", default='out/', type=str, help="The output directory where the model checkpoints will be written.") parser.add_argument("--index_file", default='index.hdf5', type=str, help="index output file.") parser.add_argument("--question_emb_file", default='question.hdf5', type=str, help="question output file.") parser.add_argument('--load_dir', default='out/', type=str) # Local paths (if we want to run cmd) parser.add_argument('--eval_script', default='evaluate-v1.1.py', type=str) # Do's parser.add_argument("--do_load", default=False, action='store_true', help='Do load. If eval, do load automatically') parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.") parser.add_argument("--do_train_filter", default=False, action='store_true', help='Train filter or not.') parser.add_argument("--do_train_sparse", default=False, action='store_true', help='Train sparse or not.') parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.") parser.add_argument('--do_eval', default=False, action='store_true') parser.add_argument('--do_embed_question', default=False, action='store_true') parser.add_argument('--do_index', default=False, action='store_true') parser.add_argument('--do_serve', default=False, action='store_true') # Model options: if you change these, you need to train again parser.add_argument("--do_case", default=False, action='store_true', help="Whether to lower case the input text. Should be True for uncased " "models and False for cased models.") parser.add_argument('--phrase_size', default=961, type=int) parser.add_argument('--metric', default='ip', type=str, help='ip | l2') parser.add_argument("--use_sparse", default=False, action='store_true') # GPU and memory related options parser.add_argument("--max_seq_length", default=384, type=int, help="The maximum total input sequence length after WordPiece tokenization. Sequences " "longer than this will be truncated, and sequences shorter than this will be padded.") parser.add_argument("--doc_stride", default=128, type=int, help="When splitting up a long document into chunks, how much stride to take between chunks.") parser.add_argument("--max_query_length", default=64, type=int, help="The maximum number of tokens for the question. Questions longer than this will " "be truncated to this length.") parser.add_argument("--train_batch_size", default=12, type=int, help="Total batch size for training.") parser.add_argument("--predict_batch_size", default=16, type=int, help="Total batch size for predictions.") parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") parser.add_argument('--optimize_on_cpu', default=False, action='store_true', help="Whether to perform optimization and keep the optimizer averages on CPU") parser.add_argument("--no_cuda", default=False, action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--fp16', default=False, action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") # Training options: only effective during training parser.add_argument("--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") parser.add_argument("--num_train_filter_epochs", default=1.0, type=float, help="Total number of training epochs for filter to perform.") parser.add_argument("--num_train_sparse_epochs", default=3.0, type=float, help="Total number of training epochs for sparse to perform.") parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% " "of training.") parser.add_argument("--save_checkpoints_steps", default=1000, type=int, help="How often to save the model checkpoint.") parser.add_argument("--iterations_per_loop", default=1000, type=int, help="How many steps to make in each estimator call.") # Prediction options: only effective during prediction parser.add_argument("--n_best_size", default=20, type=int, help="The total number of n-best predictions to generate in the nbest_predictions.json " "output file.") parser.add_argument("--max_answer_length", default=30, type=int, help="The maximum length of an answer that can be generated. This is needed because the start " "and end predictions are not conditioned on one another.") # Index Options parser.add_argument('--dtype', default='float32', type=str) parser.add_argument('--filter_threshold', default=-1e9, type=float) parser.add_argument('--compression_offset', default=-2, type=float) parser.add_argument('--compression_scale', default=20, type=float) parser.add_argument('--split_by_para', default=False, action='store_true') # Serve Options parser.add_argument('--port', default=9009, type=int) # Others parser.add_argument('--parallel', default=False, action='store_true') parser.add_argument("--verbose_logging", default=False, action='store_true', help="If true, all of the warnings related to data processing will be printed. " "A number of warnings are expected for a normal SQuAD evaluation.") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument('--draft', default=False, action='store_true') parser.add_argument('--draft_num_examples', type=int, default=12) args = parser.parse_args() # Filesystem routines if args.fs == 'local': class Processor(object): def __init__(self, path): self._save = None self._load = None self._path = path def bind(self, save, load): self._save = save self._load = load def save(self, checkpoint=None, save_fn=None, **kwargs): path = os.path.join(self._path, str(checkpoint)) if save_fn is None: self._save(path, **kwargs) else: save_fn(path, **kwargs) def load(self, checkpoint, load_fn=None, session=None, **kwargs): assert self._path == session path = os.path.join(self._path, str(checkpoint), 'model.pt') if load_fn is None: self._load(path, **kwargs) else: load_fn(path, **kwargs) processor = Processor(args.load_dir) else: raise ValueError(args.fs) if not args.do_train: args.do_load = True # Configure paths args.train_file = os.path.join(args.data_dir, args.train_file) args.predict_file = os.path.join(args.data_dir, args.predict_file) args.gt_file = os.path.join(args.data_dir, args.gt_file) args.bert_config_file = os.path.join(args.metadata_dir, args.bert_config_file.replace(".json", "") + "_" + args.bert_model_option + ".json") args.init_checkpoint = os.path.join(args.metadata_dir, args.init_checkpoint.replace(".bin", "") + "_" + args.bert_model_option + ".bin") args.vocab_file = os.path.join(args.metadata_dir, args.vocab_file) args.index_file = os.path.join(args.output_dir, args.index_file) # Multi-GPU stuff if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) if args.gradient_accumulation_steps < 1: raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) # Seed for reproducibility random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) bert_config = BertConfig.from_json_file(args.bert_config_file) if args.max_seq_length > bert_config.max_position_embeddings: raise ValueError( "Cannot use sequence length %d because the BERT model " "was only trained up to sequence length %d" % (args.max_seq_length, bert_config.max_position_embeddings)) if os.path.exists(args.output_dir) and os.listdir(args.output_dir): # raise ValueError("Output directory () already exists and is not empty.") pass else: os.makedirs(args.output_dir, exist_ok=True) tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=not args.do_case) model = BertPhraseModel( bert_config, phrase_size=args.phrase_size, metric=args.metric, use_sparse=args.use_sparse ) print('Number of model parameters:', sum(p.numel() for p in model.parameters())) if not args.do_load and args.init_checkpoint is not None: state_dict = torch.load(args.init_checkpoint, map_location='cpu') # If below: for Korean BERT compatibility if next(iter(state_dict)).startswith('bert.'): state_dict = {key[len('bert.'):]: val for key, val in state_dict.items()} state_dict = {key: val for key, val in state_dict.items() if key in model.encoder.bert_model.state_dict()} model.encoder.bert.load_state_dict(state_dict) if args.fp16: model.half() if not args.optimize_on_cpu: model.to(device) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) elif args.parallel or n_gpu > 1: model = torch.nn.DataParallel(model) if args.do_load: bind_model(processor, model) processor.load(args.iteration, session=args.load_dir) if args.do_train: train_examples = read_squad_examples( input_file=args.train_file, is_training=True, draft=args.draft, draft_num_examples=args.draft_num_examples) num_train_steps = int( len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) no_decay = ['bias', 'gamma', 'beta'] optimizer_parameters = [ {'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01}, {'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0} ] optimizer = BERTAdam(optimizer_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=num_train_steps) bind_model(processor, model, optimizer) global_step = 0 train_features, train_features_ = convert_examples_to_features( examples=train_examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, doc_stride=args.doc_stride, max_query_length=args.max_query_length, is_training=True) train_features = inject_noise_to_features_list(train_features, clamp=True, replace=True, shuffle=True) logger.info("***** Running training *****") logger.info(" Num orig examples = %d", len(train_examples)) logger.info(" Num split examples = %d", len(train_features)) logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", num_train_steps) all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) all_input_ids_ = torch.tensor([f.input_ids for f in train_features_], dtype=torch.long) all_input_mask_ = torch.tensor([f.input_mask for f in train_features_], dtype=torch.long) if args.fp16: (all_input_ids, all_input_mask, all_start_positions, all_end_positions) = tuple(t.half() for t in (all_input_ids, all_input_mask, all_start_positions, all_end_positions)) all_input_ids_, all_input_mask_ = tuple(t.half() for t in (all_input_ids_, all_input_mask_)) train_data = TensorDataset(all_input_ids, all_input_mask, all_input_ids_, all_input_mask_, all_start_positions, all_end_positions) if args.local_rank == -1: train_sampler = RandomSampler(train_data) else: train_sampler = DistributedSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) model.train() for epoch in range(int(args.num_train_epochs)): for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch %d" % (epoch + 1))): batch = tuple(t.to(device) for t in batch) (input_ids, input_mask, input_ids_, input_mask_, start_positions, end_positions) = batch loss, _ = model(input_ids, input_mask, input_ids_, input_mask_, start_positions, end_positions) if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: if args.optimize_on_cpu: model.to('cpu') optimizer.step() # We have accumulated enought gradients model.zero_grad() if args.optimize_on_cpu: model.to(device) global_step += 1 processor.save(epoch + 1) if args.do_train_filter: train_examples = read_squad_examples( input_file=args.train_file, is_training=True, draft=args.draft, draft_num_examples=args.draft_num_examples) num_train_steps = int( len( train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_filter_epochs) if args.parallel or n_gpu > 1: optimizer = Adam(model.module.filter.parameters()) else: optimizer = Adam(model.filter.parameters()) bind_model(processor, model, optimizer) global_step = 0 train_features, train_features_ = convert_examples_to_features( examples=train_examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, doc_stride=args.doc_stride, max_query_length=args.max_query_length, is_training=True) logger.info("***** Running filter training *****") logger.info(" Num orig examples = %d", len(train_examples)) logger.info(" Num split examples = %d", len(train_features)) logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", num_train_steps) all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) all_input_ids_ = torch.tensor([f.input_ids for f in train_features_], dtype=torch.long) all_input_mask_ = torch.tensor([f.input_mask for f in train_features_], dtype=torch.long) if args.fp16: (all_input_ids, all_input_mask, all_start_positions, all_end_positions) = tuple(t.half() for t in (all_input_ids, all_input_mask, all_start_positions, all_end_positions)) all_input_ids_, all_input_mask_ = tuple(t.half() for t in (all_input_ids_, all_input_mask_)) train_data = TensorDataset(all_input_ids, all_input_mask, all_input_ids_, all_input_mask_, all_start_positions, all_end_positions) if args.local_rank == -1: train_sampler = RandomSampler(train_data) else: train_sampler = DistributedSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) model.train() for epoch in range(int(args.num_train_filter_epochs)): for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch %d" % (epoch + 1))): batch = tuple(t.to(device) for t in batch) (input_ids, input_mask, input_ids_, input_mask_, start_positions, end_positions) = batch _, loss = model(input_ids, input_mask, input_ids_, input_mask_, start_positions, end_positions) if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: if args.optimize_on_cpu: model.to('cpu') optimizer.step() # We have accumulated enought gradients model.zero_grad() if args.optimize_on_cpu: model.to(device) global_step += 1 processor.save(epoch + 1) if args.do_train_sparse: train_examples = read_squad_examples( input_file=args.train_file, is_training=True, draft=args.draft, draft_num_examples=args.draft_num_examples) num_train_steps = int( len( train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_sparse_epochs) ''' if args.parallel or n_gpu > 1: optimizer = Adam(model.module.sparse_layer.parameters()) else: optimizer = Adam(model.sparse_layer.parameters()) ''' no_decay = ['bias', 'gamma', 'beta'] optimizer_parameters = [ {'params': [p for n, p in model.named_parameters() if (n not in no_decay) and ('filter' not in n)], 'weight_decay_rate': 0.01}, {'params': [p for n, p in model.named_parameters() if (n in no_decay) and ('filter' not in n)], 'weight_decay_rate': 0.0} ] optimizer = BERTAdam(optimizer_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=num_train_steps) bind_model(processor, model, optimizer) global_step = 0 train_features, train_features_ = convert_examples_to_features( examples=train_examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, doc_stride=args.doc_stride, max_query_length=args.max_query_length, is_training=True) logger.info("***** Running sparse training *****") logger.info(" Num orig examples = %d", len(train_examples)) logger.info(" Num split examples = %d", len(train_features)) logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", num_train_steps) all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) all_input_ids_ = torch.tensor([f.input_ids for f in train_features_], dtype=torch.long) all_input_mask_ = torch.tensor([f.input_mask for f in train_features_], dtype=torch.long) if args.fp16: (all_input_ids, all_input_mask, all_start_positions, all_end_positions) = tuple(t.half() for t in (all_input_ids, all_input_mask, all_start_positions, all_end_positions)) all_input_ids_, all_input_mask_ = tuple(t.half() for t in (all_input_ids_, all_input_mask_)) train_data = TensorDataset(all_input_ids, all_input_mask, all_input_ids_, all_input_mask_, all_start_positions, all_end_positions) if args.local_rank == -1: train_sampler = RandomSampler(train_data) else: train_sampler = DistributedSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) model.train() for epoch in range(int(args.num_train_sparse_epochs)): for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch %d" % (epoch + 1))): batch = tuple(t.to(device) for t in batch) (input_ids, input_mask, input_ids_, input_mask_, start_positions, end_positions) = batch loss, _ = model(input_ids, input_mask, input_ids_, input_mask_, start_positions, end_positions) if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: if args.optimize_on_cpu: model.to('cpu') optimizer.step() # We have accumulated enought gradients model.zero_grad() if args.optimize_on_cpu: model.to(device) global_step += 1 processor.save(epoch + 1) if args.do_predict: eval_examples = read_squad_examples( input_file=args.predict_file, is_training=False, draft=args.draft, draft_num_examples=args.draft_num_examples) eval_features, query_eval_features = convert_examples_to_features( examples=eval_examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, doc_stride=args.doc_stride, max_query_length=args.max_query_length, is_training=False) logger.info("***** Running predictions *****") logger.info(" Num orig examples = %d", len(eval_examples)) logger.info(" Num split examples = %d", len(eval_features)) logger.info(" Batch size = %d", args.predict_batch_size) all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) all_input_ids_ = torch.tensor([f.input_ids for f in query_eval_features], dtype=torch.long) all_input_mask_ = torch.tensor([f.input_mask for f in query_eval_features], dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) if args.fp16: (all_input_ids, all_input_mask, all_example_index) = tuple(t.half() for t in (all_input_ids, all_input_mask, all_example_index)) all_input_ids_, all_input_mask_ = tuple(t.half() for t in (all_input_ids_, all_input_mask_)) eval_data = TensorDataset(all_input_ids, all_input_mask, all_input_ids_, all_input_mask_, all_example_index) if args.local_rank == -1: eval_sampler = SequentialSampler(eval_data) else: eval_sampler = DistributedSampler(eval_data) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size) model.eval() logger.info("Start evaluating") def get_results(): for (input_ids, input_mask, input_ids_, input_mask_, example_indices) in eval_dataloader: input_ids = input_ids.to(device) input_mask = input_mask.to(device) input_ids_ = input_ids_.to(device) input_mask_ = input_mask_.to(device) with torch.no_grad(): batch_all_logits, bs, be = model(input_ids, input_mask, input_ids_, input_mask_) for i, example_index in enumerate(example_indices): all_logits = batch_all_logits[i].detach().cpu().numpy() filter_start_logits = bs[i].detach().cpu().numpy() filter_end_logits = be[i].detach().cpu().numpy() eval_feature = eval_features[example_index.item()] unique_id = int(eval_feature.unique_id) yield RawResult(unique_id=unique_id, all_logits=all_logits, filter_start_logits=filter_start_logits, filter_end_logits=filter_end_logits) output_prediction_file = os.path.join(args.output_dir, "predictions.json") write_predictions(eval_examples, eval_features, get_results(), args.max_answer_length, not args.do_case, output_prediction_file, args.verbose_logging, args.filter_threshold) if args.do_eval: command = "python %s %s %s" % (args.eval_script, args.gt_file, output_prediction_file) import subprocess process = subprocess.Popen(command.split(), stdout=subprocess.PIPE) output, error = process.communicate() if args.do_embed_question: question_examples = read_squad_examples( question_only=True, input_file=args.predict_file, is_training=False, draft=args.draft, draft_num_examples=args.draft_num_examples) query_eval_features = convert_questions_to_features( examples=question_examples, tokenizer=tokenizer, max_query_length=args.max_query_length) question_dataloader = convert_question_features_to_dataloader(query_eval_features, args.fp16, args.local_rank, args.predict_batch_size) model.eval() logger.info("Start embedding") question_results = get_question_results_(question_examples, query_eval_features, question_dataloader, device, model) path = os.path.join(args.output_dir, args.question_emb_file) print('Writing %s' % path) write_question_results(question_results, query_eval_features, path) if args.do_index: if ':' not in args.predict_file: predict_files = [args.predict_file] offsets = [0] else: dirname = os.path.dirname(args.predict_file) basename = os.path.basename(args.predict_file) start, end = list(map(int, basename.split(':'))) # skip files if possible if os.path.exists(args.index_file): with h5py.File(args.index_file, 'r') as f: dids = list(map(int, f.keys())) start = int(max(dids) / 1000) print('%s exists; starting from %d' % (args.index_file, start)) names = [str(i).zfill(4) for i in range(start, end)] predict_files = [os.path.join(dirname, name) for name in names] offsets = [int(each) * 1000 for each in names] for offset, predict_file in zip(offsets, predict_files): try: context_examples = read_squad_examples( context_only=True, input_file=predict_file, is_training=False, draft=args.draft, draft_num_examples=args.draft_num_examples) for example in context_examples: example.doc_idx += offset context_features = convert_documents_to_features( examples=context_examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, doc_stride=args.doc_stride) logger.info("***** Running indexing on %s *****" % predict_file) logger.info(" Num orig examples = %d", len(context_examples)) logger.info(" Num split examples = %d", len(context_features)) logger.info(" Batch size = %d", args.predict_batch_size) all_input_ids = torch.tensor([f.input_ids for f in context_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in context_features], dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) if args.fp16: all_input_ids, all_input_mask, all_example_index = tuple( t.half() for t in (all_input_ids, all_input_mask, all_example_index)) context_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) if args.local_rank == -1: context_sampler = SequentialSampler(context_data) else: context_sampler = DistributedSampler(context_data) context_dataloader = DataLoader(context_data, sampler=context_sampler, batch_size=args.predict_batch_size) model.eval() logger.info("Start indexing") def get_context_results(): for (input_ids, input_mask, example_indices) in context_dataloader: input_ids = input_ids.to(device) input_mask = input_mask.to(device) with torch.no_grad(): batch_start, batch_end, batch_span_logits, bs, be, batch_sparse = model(input_ids, input_mask) for i, example_index in enumerate(example_indices): start = batch_start[i].detach().cpu().numpy().astype(args.dtype) end = batch_end[i].detach().cpu().numpy().astype(args.dtype) sparse = None if batch_sparse is not None: sparse = batch_sparse[i].detach().cpu().numpy().astype(args.dtype) span_logits = batch_span_logits[i].detach().cpu().numpy().astype(args.dtype) filter_start_logits = bs[i].detach().cpu().numpy().astype(args.dtype) filter_end_logits = be[i].detach().cpu().numpy().astype(args.dtype) context_feature = context_features[example_index.item()] unique_id = int(context_feature.unique_id) yield ContextResult(unique_id=unique_id, start=start, end=end, span_logits=span_logits, filter_start_logits=filter_start_logits, filter_end_logits=filter_end_logits, sparse=sparse) t0 = time() write_hdf5(context_examples, context_features, get_context_results(), args.max_answer_length, not args.do_case, args.index_file, args.filter_threshold, args.verbose_logging, offset=args.compression_offset, scale=args.compression_scale, split_by_para=args.split_by_para, use_sparse=args.use_sparse) print('%s: %.1f mins' % (predict_file, (time() - t0) / 60)) except Exception as e: with open(os.path.join(args.output_dir, 'error_files.txt'), 'a') as fp: fp.write('error file: %s\n' % predict_file) fp.write('error message: %s\n' % str(e)) if args.do_serve: def get(text): question_examples = [SquadExample(qas_id='serve', question_text=text)] query_eval_features = convert_questions_to_features( examples=question_examples, tokenizer=tokenizer, max_query_length=16) question_dataloader = convert_question_features_to_dataloader(query_eval_features, args.fp16, args.local_rank, args.predict_batch_size) model.eval() question_results = get_question_results_(question_examples, query_eval_features, question_dataloader, device, model) question_result = next(iter(question_results)) out = question_result.start.tolist(), question_result.end.tolist(), question_result.span_logit.tolist() return out serve(get, args.port)
def main(): parser = argparse.ArgumentParser() # Data paths parser.add_argument('--data_dir', default='data/', type=str) parser.add_argument("--predict_file", default='dev-v1.1.json', type=str, help="json for prediction.") # Metadata paths parser.add_argument('--metadata_dir', default='models/bert', type=str, help="Dir for pre-trained models.") parser.add_argument("--vocab_file", default='vocab.txt', type=str, help="Vocabulary file of pre-trained model.") parser.add_argument( "--bert_model_option", default='large_uncased', type=str, help="model architecture option. [large_uncased] or [base_uncased].") parser.add_argument( "--bert_config_file", default='bert_config.json', type=str, help="The config json file corresponding to the pre-trained BERT model." ) parser.add_argument( "--init_checkpoint", default='pytorch_model.bin', type=str, help="Initial checkpoint (usually from a pre-trained BERT model).") # Output and load paths parser.add_argument("--output_dir", default='out/', type=str, help="storing models and predictions") parser.add_argument("--dump_dir", default='test/', type=str) parser.add_argument("--dump_file", default='phrase.hdf5', type=str, help="dump phrases of file.") parser.add_argument('--load_dir', default='out/', type=str, help="Dir for checkpoints of models to load.") parser.add_argument('--load_epoch', type=str, default='1', help="Epoch of model to load.") # Do's parser.add_argument("--do_load", default=False, action='store_true', help='Do load. If eval, do load automatically') parser.add_argument('--do_dump', default=False, action='store_true') # Model options: if you change these, you need to train again parser.add_argument("--do_case", default=False, action='store_true', help="Whether to keep upper casing") parser.add_argument("--use_sparse", default=False, action='store_true') parser.add_argument("--sparse_ngrams", default='1,2', type=str) parser.add_argument("--skip_no_answer", default=False, action='store_true') parser.add_argument('--freeze_word_emb', default=False, action='store_true') parser.add_argument('--append_title', default=False, action='store_true') # GPU and memory related options parser.add_argument( "--max_seq_length", default=384, type=int, help= "The maximum total input sequence length after WordPiece tokenization. Sequences " "longer than this will be truncated, and sequences shorter than this will be padded." ) parser.add_argument( "--doc_stride", default=128, type=int, help= "When splitting up a long document into chunks, how much stride to take between chunks." ) parser.add_argument("--predict_batch_size", default=64, type=int, help="Total batch size for predictions.") parser.add_argument("--no_cuda", default=False, action='store_true', help="Whether not to use CUDA when available") parser.add_argument('--parallel', default=False, action='store_true') parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") # Prediction options: only effective during prediction parser.add_argument( "--n_best_size", default=20, type=int, help= "The total number of n-best predictions to generate in the nbest_predictions.json " "output file.") parser.add_argument( "--max_answer_length", default=30, type=int, help= "The maximum length of an answer that can be generated. This is needed because the start " "and end predictions are not conditioned on one another.") # Index Options parser.add_argument('--dtype', default='float32', type=str) parser.add_argument('--filter_threshold', default=-1e9, type=float) parser.add_argument('--dense_offset', default=-2, type=float) # Original parser.add_argument('--dense_scale', default=20, type=float) parser.add_argument('--sparse_offset', default=1.6, type=float) parser.add_argument('--sparse_scale', default=80, type=float) # Others parser.add_argument( "--verbose_logging", default=False, action='store_true', help= "If true, all of the warnings related to data processing will be printed. " "A number of warnings are expected for a normal SQuAD evaluation.") parser.add_argument('--seed', type=int, default=45, help="random seed for initialization") parser.add_argument('--draft', default=False, action='store_true') parser.add_argument('--draft_num_examples', type=int, default=12) args = parser.parse_args() # Filesystem routines class Processor(object): def __init__(self, save_path, load_path): self._save = None self._load = None self._save_path = save_path self._load_path = load_path def bind(self, save, load): self._save = save self._load = load def save(self, checkpoint=None, save_fn=None, **kwargs): path = os.path.join(self._save_path, str(checkpoint)) if save_fn is None: self._save(path, **kwargs) else: save_fn(path, **kwargs) def load(self, checkpoint, load_fn=None, session=None, **kwargs): assert self._load_path == session path = os.path.join(self._load_path, str(checkpoint), 'model.pt') if load_fn is None: self._load(path, **kwargs) else: load_fn(path, **kwargs) processor = Processor(args.output_dir, args.load_dir) if args.do_load is False: logger.info("Setting do_load to true for dumping") args.do_load = True # Configure file paths args.predict_file = os.path.join(args.data_dir, args.predict_file) args.vocab_file = os.path.join(args.metadata_dir, args.vocab_file) args.bert_config_file = os.path.join( args.metadata_dir, args.bert_config_file.replace(".json", "") + "_" + args.bert_model_option + ".json") args.init_checkpoint = os.path.join( args.metadata_dir, args.init_checkpoint.replace(".bin", "") + "_" + args.bert_model_option + ".bin") args.dump_file = os.path.join(args.dump_dir, args.dump_file) # CUDA Check logger.info('cuda availability: {}'.format(torch.cuda.is_available())) # Multi-GPU stuff if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) # Seed for reproducibility random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) # bert_config = BertConfig.from_json_file(args.bert_config_file) bert_config = AutoConfig.from_pretrained( 'bert-base-uncased' if not (args.bert_model_option == 'large_uncased') else 'bert-large-uncased', cache_dir='cache', ) if args.max_seq_length > bert_config.max_position_embeddings: raise ValueError( "Cannot use sequence length %d because the BERT model " "was only trained up to sequence length %d" % (args.max_seq_length, bert_config.max_position_embeddings)) if os.path.exists(args.output_dir) and os.listdir(args.output_dir): logger.info("Overwriting outputs in %s" % args.output_dir) else: os.makedirs(args.output_dir, exist_ok=True) if os.path.exists(args.dump_dir) and os.listdir(args.dump_dir): logger.info("Overwriting dump in %s" % args.dump_dir) else: os.makedirs(args.dump_dir, exist_ok=True) model = DenSPI( bert_config, sparse_ngrams=args.sparse_ngrams.split(','), use_sparse=args.use_sparse, ) logger.info('Number of model parameters: {:,}'.format( sum(p.numel() for p in model.parameters()))) tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=not args.do_case) # Initialize BERT if not loading and has init_checkpoint if not args.do_load and args.init_checkpoint is not None: if args.draft: logger.info('[Draft] Randomly initialized model') else: state_dict = torch.load(args.init_checkpoint, map_location='cpu') if next(iter(state_dict)).startswith('bert.'): state_dict = { key[len('bert.'):]: val for key, val in state_dict.items() } state_dict = { key: val for key, val in state_dict.items() if key in model.bert.state_dict() } check_diff(model.bert.state_dict(), state_dict) model.bert.load_state_dict(state_dict) logger.info('Model initialized from the pre-trained BERT weight!') ''' if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) elif args.parallel or n_gpu > 1: model = torch.nn.DataParallel(model) logger.info("Data parallel!") ''' if args.do_load: bind_model(processor, model) # processor.load(args.load_epoch, session=args.load_dir) model = DenSPI.from_pretrained( args.load_dir, config=bert_config, sparse_ngrams=args.sparse_ngrams.split(','), use_sparse=args.use_sparse, ) model.to(device) def is_freeze_param(name): if args.freeze_word_emb: if name.endswith("bert.embeddings.word_embeddings.weight"): logger.info(f'freezeing {name}') return False return True # Dump phrases if args.do_dump: if ':' not in args.predict_file: predict_files = [args.predict_file] offsets = [0] else: dirname = os.path.dirname(args.predict_file) basename = os.path.basename(args.predict_file) start, end = list(map(int, basename.split(':'))) # skip files if possible if os.path.exists(args.dump_file): with h5py.File(args.dump_file, 'r') as f: dids = list(map(int, f.keys())) start = int(max(dids) / 1000) logger.info('%s exists; starting from %d' % (args.dump_file, start)) names = [str(i).zfill(4) for i in range(start, end)] predict_files = [os.path.join(dirname, name) for name in names] offsets = [int(each) * 1000 for each in names] for offset, predict_file in zip(offsets, predict_files): context_examples = read_squad_examples( context_only=True, input_file=predict_file, return_answers=False, draft=args.draft, draft_num_examples=args.draft_num_examples, append_title=args.append_title) for example in context_examples: example.doc_idx += offset context_features = convert_documents_to_features( examples=context_examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, doc_stride=args.doc_stride) logger.info("***** Running dumping on %s *****" % predict_file) logger.info(" Num orig examples = %d", len(context_examples)) logger.info(" Num split examples = %d", len(context_features)) logger.info(" Batch size = %d", args.predict_batch_size) all_input_ids = torch.tensor( [f.input_ids for f in context_features], dtype=torch.long) all_input_mask = torch.tensor( [f.input_mask for f in context_features], dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) context_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) if args.local_rank == -1: context_sampler = SequentialSampler(context_data) else: context_sampler = DistributedSampler(context_data) context_dataloader = DataLoader(context_data, sampler=context_sampler, batch_size=args.predict_batch_size) model.eval() logger.info("Start dumping") def get_context_results(): for (input_ids, input_mask, example_indices) in context_dataloader: input_ids = input_ids.to(device) input_mask = input_mask.to(device) with torch.no_grad(): batch_start, batch_end, batch_span_logits, batch_filter_start, batch_filter_end, sp_s, sp_e = model( input_ids=input_ids, input_mask=input_mask) for i, example_index in enumerate(example_indices): start = batch_start[i].detach().cpu().numpy().astype( args.dtype) end = batch_end[i].detach().cpu().numpy().astype( args.dtype) sparse = None if len(sp_s) > 0: b_ssp = { ng: bb_ssp[i].detach().cpu().numpy().astype( args.dtype) for ng, bb_ssp in sp_s.items() } b_esp = { ng: bb_esp[i].detach().cpu().numpy().astype( args.dtype) for ng, bb_esp in sp_e.items() } span_logits = batch_span_logits[i].detach().cpu( ).numpy().astype(args.dtype) filter_start_logits = batch_filter_start[i].detach( ).cpu().numpy().astype(args.dtype) filter_end_logits = batch_filter_end[i].detach().cpu( ).numpy().astype(args.dtype) context_feature = context_features[ example_index.item()] unique_id = int(context_feature.unique_id) yield ContextResult( unique_id=unique_id, start=start, end=end, span_logits=span_logits, filter_start_logits=filter_start_logits, filter_end_logits=filter_end_logits, start_sp=b_ssp, end_sp=b_esp) t0 = time() write_hdf5(context_examples, context_features, get_context_results(), args.max_answer_length, not args.do_case, args.dump_file, args.filter_threshold, args.verbose_logging, dense_offset=args.dense_offset, dense_scale=args.dense_scale, sparse_offset=args.sparse_offset, sparse_scale=args.sparse_scale, use_sparse=args.use_sparse) logger.info('%s: %.1f mins' % (predict_file, (time() - t0) / 60))