def test(args, test_dataset, model, tokenizer, predict_file):
    args.test_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    test_sampler = SequentialSampler(test_dataset)
    cache_file = predict_file
    test_dataloader = DataLoader(test_dataset, sampler=test_sampler,
        batch_size=args.test_batch_size, num_workers=args.num_workers)

    cls_token_id, sep_token_id, pad_token_id, mask_token_id, period_token_id = \
        tokenizer.convert_tokens_to_ids([tokenizer.cls_token, tokenizer.sep_token,
            tokenizer.pad_token, tokenizer.mask_token, '.'])
    model.eval()

    def gen_rows():
        time_meter = 0

        # restore existing results for long running inference tasks
        exist_key2pred = {}
        tmp_file = cache_file + '.tmp.copy'
        if op.isfile(tmp_file):
            with open(tmp_file, 'r') as fp:
                for line in fp:
                    parts = line.strip().split('\t')
                    if len(parts) == 2:
                        exist_key2pred[parts[0]] = parts[1]

        with torch.no_grad():
            for step, (img_keys, batch) in tqdm(enumerate(test_dataloader)):
                is_exist = True
                for k in img_keys:
                    if k not in exist_key2pred:
                        is_exist = False
                        break
                if is_exist:
                    for k in img_keys:
                        yield k, exist_key2pred[k]
                    continue

                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    'is_decode': True,
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2],
                    'img_feats': batch[3],
                    'masked_pos': batch[4],
                    'do_sample': False,
                    'bos_token_id': cls_token_id,
                    'pad_token_id': pad_token_id,
                    'eos_token_ids': [sep_token_id, pad_token_id],
                    'mask_token_id': mask_token_id,
                    # for adding od labels
                    'add_od_labels': args.add_od_labels,
                    'od_labels_start_posid': args.max_seq_a_length,
                    # for disabling image features
                    # 'disable_img_features': args.disable_img_features,
                    # for selecting top object tags at inference
                    # 'keep_top_percentage_tag_conf_threshold': args.keep_top_percentage_tag_conf_threshold,
                    # 'keep_top_percentage_tag': args.keep_top_percentage_tag,
                    # hyperparameters of beam search
                    'max_length': args.max_gen_length,
                    'num_beams': args.num_beams,
                    'temperature': args.temperature,
                    'top_k': args.top_k,
                    'top_p': args.top_p,
                    'repetition_penalty': args.repetition_penalty,
                    'length_penalty': args.length_penalty,
                    'num_return_sequences': args.num_return_sequences,
                    'num_keep_best': args.num_keep_best,
                }
                if args.use_cbs:
                    inputs.update({
                        'use_cbs': True,
                        'fsm': batch[5],
                        'num_constraints': batch[6],
                        'min_constraints_to_satisfy': args.min_constraints_to_satisfy,
                    })

                tic = time.time()
                # captions, logprobs
                outputs = model(**inputs)
                time_meter += time.time() - tic

                all_caps = outputs[0]  # batch_size * num_keep_best * max_len
                all_confs = torch.exp(outputs[1])

                for img_key, caps, confs in zip(img_keys, all_caps, all_confs):
                    res = []
                    for cap, conf in zip(caps, confs):
                        cap = tokenizer.decode(cap.tolist(), skip_special_tokens=True)
                        res.append({'caption': cap, 'conf': conf.item()})
                    if isinstance(img_key, torch.Tensor):
                        img_key = img_key.item()
                    yield img_key, json.dumps(res)

            logger.info("Inference model computing time: {} seconds per batch".format(
                time_meter / (step + 1)))

    tsv_writer(gen_rows(), cache_file)
    return predict_file
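
# --- Illustration only: a minimal sketch of the tsv_writer contract assumed above. ---
# The real tsv_writer is provided elsewhere in this repo; this hypothetical
# `_tsv_writer_sketch` only shows the behavior the resume logic in gen_rows() relies
# on: (key, value) rows are streamed to a tab-separated side file as they are produced,
# so a partially finished run can be copied to '<cache_file>.tmp.copy' and re-read to
# skip batches that already have predictions.
import os

def _tsv_writer_sketch(rows, output_file):
    tmp_file = output_file + '.tmp'
    with open(tmp_file, 'w') as fp:
        for key, value in rows:
            fp.write('{}\t{}\n'.format(key, value))
            fp.flush()  # keep partial results on disk for long-running jobs
    os.rename(tmp_file, output_file)  # publish the finished file only once complete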
# Variant of test() instrumented with experimental ONNX export.
# NOTE: in addition to the modules used above, this variant requires `os`,
# `onnx`, and `onnxruntime` to be imported at module level.
def test(args, test_dataset, model, tokenizer, predict_file):
    args.test_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    test_sampler = SequentialSampler(test_dataset)
    cache_file = predict_file
    test_dataloader = DataLoader(test_dataset, sampler=test_sampler,
        batch_size=args.test_batch_size, num_workers=args.num_workers)

    cls_token_id, sep_token_id, pad_token_id, mask_token_id, period_token_id = \
        tokenizer.convert_tokens_to_ids([tokenizer.cls_token, tokenizer.sep_token,
            tokenizer.pad_token, tokenizer.mask_token, '.'])
    model.eval()

    def gen_rows():
        time_meter = 0

        # restore existing results for long running inference tasks
        exist_key2pred = {}
        tmp_file = cache_file + '.tmp.copy'
        if op.isfile(tmp_file):
            with open(tmp_file, 'r') as fp:
                for line in fp:
                    parts = line.strip().split('\t')
                    if len(parts) == 2:
                        exist_key2pred[parts[0]] = parts[1]

        with torch.no_grad():
            for step, (img_keys, batch) in tqdm(enumerate(test_dataloader)):
                is_exist = True
                for k in img_keys:
                    if k not in exist_key2pred:
                        is_exist = False
                        break
                if is_exist:
                    for k in img_keys:
                        yield k, exist_key2pred[k]
                    continue

                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    'is_decode': True,
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2],
                    'img_feats': batch[3],
                    'masked_pos': batch[4],
                    'do_sample': False,
                    'bos_token_id': cls_token_id,
                    'pad_token_id': pad_token_id,
                    'eos_token_ids': [sep_token_id, pad_token_id],
                    'mask_token_id': mask_token_id,
                    # for adding od labels
                    'add_od_labels': args.add_od_labels,
                    'od_labels_start_posid': args.max_seq_a_length,
                    # hyperparameters of beam search
                    'max_length': args.max_gen_length,
                    'num_beams': args.num_beams,
                    'temperature': args.temperature,
                    'top_k': args.top_k,
                    'top_p': args.top_p,
                    'repetition_penalty': args.repetition_penalty,
                    'length_penalty': args.length_penalty,
                    'num_return_sequences': args.num_return_sequences,
                    'num_keep_best': args.num_keep_best,
                }
                if args.use_cbs:
                    inputs.update({
                        'use_cbs': True,
                        'fsm': batch[5],
                        'num_constraints': batch[6],
                        'min_constraints_to_satisfy': args.min_constraints_to_satisfy,
                    })

                tic = time.time()
                print('INPUTS WITHOUT ONNX')
                # eager PyTorch forward pass, kept as a sanity check against the ONNX path
                outputs = model(**inputs)

                # collect the inputs in a deterministic order and make sure every entry
                # is a tensor, since torch.onnx.export only traces tensor arguments
                onnx_input_names = sorted(list(inputs.keys()))

                def convert_onnx_input(onnx_in):
                    if type(onnx_in) != torch.Tensor:
                        return torch.tensor(onnx_in)
                    return onnx_in

                onnx_inputs = [convert_onnx_input(inputs[k]) for k in onnx_input_names]

                # convert to onnx if needed
                if True:  # not os.path.exists('oscar.onnx'):
                    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
                    print('INPUTS_WITH_ONNX')
                    torch.onnx.export(
                        model,                       # model being run
                        args=tuple(onnx_inputs),     # model inputs (a tuple for multiple inputs)
                        f='oscar.onnx',              # where to save the model (file or file-like object)
                        opset_version=10,            # the ONNX opset version to export to
                        input_names=onnx_input_names,
                        output_names=['output'],     # the model's output names
                        dynamic_axes={               # variable-length axes
                            'input_ids': symbolic_names,
                            'attention_mask': symbolic_names,
                            'token_type_ids': symbolic_names,
                            'img_feats': symbolic_names,
                            'masked_pos': symbolic_names,
                            'is_decode': symbolic_names,
                            'output': symbolic_names,
                        })
                    print("Model exported at ", 'oscar.onnx')

                    def remove_initializer_from_input(model_path):
                        model = onnx.load(model_path)
                        if model.ir_version < 4:
                            print('Model with ir_version below 4 requires to include initializer in graph input')
                            return
                        inputs = model.graph.input
                        name_to_input = {}
                        for input in inputs:
                            name_to_input[input.name] = input
                        for initializer in model.graph.initializer:
                            if initializer.name in name_to_input:
                                inputs.remove(name_to_input[initializer.name])
                        onnx.save(model, model_path)
                        print("Input initializer removed")

                    remove_initializer_from_input('oscar.onnx')

                # Load the ONNX model
                onnx_model = onnx.load("oscar.onnx")
                # Check that the IR is well formed
                onnx.checker.check_model(onnx_model)
                # Print a human-readable representation of the graph
                print(onnx.helper.printable_graph(onnx_model.graph))

                # outputs = model(**inputs)
                sess_options = onnxruntime.SessionOptions()
                session = onnxruntime.InferenceSession('oscar.onnx', sess_options)
                # move tensors back to CPU before converting to numpy for onnxruntime
                onnx_inputs = {
                    'input_ids': batch[0].cpu().numpy(),
                    'attention_mask': batch[1].float().cpu().numpy(),
                    # 'token_type_ids': batch[2].cpu().numpy(),
                    'img_feats': batch[3].cpu().numpy(),
                    # 'masked_pos': batch[4].cpu().numpy(),
                }
                print(onnx_inputs['attention_mask'])
                onnx_outputs = session.run(None, onnx_inputs)
                print(onnx_outputs[0])
                # use the ONNX result downstream; this assumes the exported graph
                # returns the same (captions, logprobs) pair as the eager model
                outputs = onnx_outputs[0]
                time_meter += time.time() - tic

                all_caps = outputs[0]  # batch_size * num_keep_best * max_len
                all_confs = torch.exp(outputs[1])

                for img_key, caps, confs in zip(img_keys, all_caps, all_confs):
                    res = []
                    for cap, conf in zip(caps, confs):
                        cap = tokenizer.decode(cap.tolist(), skip_special_tokens=True)
                        res.append({'caption': cap, 'conf': conf.item()})
                    if isinstance(img_key, torch.Tensor):
                        img_key = img_key.item()
                    yield img_key, json.dumps(res)
                break  # stop after the first batch while experimenting with the ONNX path

            logger.info("Inference model computing time: {} seconds per batch".format(
                time_meter / (step + 1)))

    tsv_writer(gen_rows(), cache_file)
    return predict_file
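
# --- Illustration only: one way to compare the eager and ONNX outputs above. ---
# The loop keeps the eager `outputs = model(**inputs)` around as a sanity check but
# never compares it to `onnx_outputs`. This hypothetical `_compare_eager_and_onnx_sketch`
# shows how such a check could look, assuming both paths return a tensor/array of the
# same shape; tolerances are placeholders, not values taken from this repo.
import numpy as np

def _compare_eager_and_onnx_sketch(eager_output, onnx_output, rtol=1e-3, atol=1e-5):
    # accept either a torch.Tensor or anything numpy can wrap
    if hasattr(eager_output, 'detach'):
        eager = eager_output.detach().cpu().numpy()
    else:
        eager = np.asarray(eager_output)
    onnx_arr = np.asarray(onnx_output)
    np.testing.assert_allclose(eager, onnx_arr, rtol=rtol, atol=atol)
    print('eager and ONNX outputs agree within rtol={}, atol={}'.format(rtol, atol))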