import json
import os.path as op
import time

import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm

# `tsv_writer` and `logger` are assumed to be provided by the surrounding
# project (e.g. its captioning utilities); they are not defined in this snippet.


def test(args, test_dataset, model, tokenizer, predict_file):
    args.test_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    test_sampler = SequentialSampler(test_dataset)
    cache_file = predict_file

    test_dataloader = DataLoader(test_dataset,
                                 sampler=test_sampler,
                                 batch_size=args.test_batch_size,
                                 num_workers=args.num_workers)

    cls_token_id, sep_token_id, pad_token_id, mask_token_id, period_token_id = \
        tokenizer.convert_tokens_to_ids([tokenizer.cls_token, tokenizer.sep_token,
                                         tokenizer.pad_token, tokenizer.mask_token, '.'])
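    # These ids are reused below as the bos/eos/pad/mask tokens for generation;
    # period_token_id is retrieved alongside them but is not used in this function.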
    model.eval()

    def gen_rows():
        time_meter = 0
        # restore existing results for long running inference tasks
        exist_key2pred = {}
        tmp_file = cache_file + '.tmp.copy'
        if op.isfile(tmp_file):
            with open(tmp_file, 'r') as fp:
                for line in fp:
                    parts = line.strip().split('\t')
                    if len(parts) == 2:
                        exist_key2pred[parts[0]] = parts[1]

        with torch.no_grad():
            for step, (img_keys, batch) in tqdm(enumerate(test_dataloader)):
                is_exist = True
                for k in img_keys:
                    if k not in exist_key2pred:
                        is_exist = False
                        break
                if is_exist:
                    for k in img_keys:
                        yield k, exist_key2pred[k]
                    continue
                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    'is_decode': True,
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2],
                    'img_feats': batch[3],
                    'masked_pos': batch[4],
                    'do_sample': False,
                    'bos_token_id': cls_token_id,
                    'pad_token_id': pad_token_id,
                    'eos_token_ids': [sep_token_id, pad_token_id],
                    'mask_token_id': mask_token_id,
                    # for adding od labels
                    'add_od_labels': args.add_od_labels,
                    'od_labels_start_posid': args.max_seq_a_length,
                    # to disable image features
                    #'disable_img_features': args.disable_img_features,
                    # to keep only the top object tags at inference
                    #'keep_top_percentage_tag_conf_threshold': args.keep_top_percentage_tag_conf_threshold,
                    #'keep_top_percentage_tag': args.keep_top_percentage_tag,

                    # hyperparameters of beam search
                    'max_length': args.max_gen_length,
                    'num_beams': args.num_beams,
                    "temperature": args.temperature,
                    "top_k": args.top_k,
                    "top_p": args.top_p,
                    "repetition_penalty": args.repetition_penalty,
                    "length_penalty": args.length_penalty,
                    "num_return_sequences": args.num_return_sequences,
                    "num_keep_best": args.num_keep_best,
                }
                if args.use_cbs:
                    # constrained beam search: the FSM and constraint counts are
                    # provided by the dataset when CBS is enabled
                    inputs.update({
                        'use_cbs': True,
                        'fsm': batch[5],
                        'num_constraints': batch[6],
                        'min_constraints_to_satisfy': args.min_constraints_to_satisfy,
                    })
                tic = time.time()
                # captions, logprobs
                outputs = model(**inputs)
                time_meter += time.time() - tic
                all_caps = outputs[0]  # batch_size * num_keep_best * max_len
                all_confs = torch.exp(outputs[1])

                for img_key, caps, confs in zip(img_keys, all_caps, all_confs):
                    res = []
                    for cap, conf in zip(caps, confs):
                        cap = tokenizer.decode(cap.tolist(),
                                               skip_special_tokens=True)
                        res.append({'caption': cap, 'conf': conf.item()})
                    if isinstance(img_key, torch.Tensor):
                        img_key = img_key.item()
                    yield img_key, json.dumps(res)

        logger.info(
            "Inference model computing time: {} seconds per batch".format(
                time_meter / (step + 1)))

    tsv_writer(gen_rows(), cache_file)
    return predict_file
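`tsv_writer` above is project-specific and not shown. Judging from how `gen_rows()` yields `(key, json_string)` pairs and how the `.tmp.copy` resume file is parsed as two tab-separated columns, a minimal sketch of the assumed writer could look like the following; the temporary-file-and-rename behavior is an assumption for illustration, not the project's exact implementation.

import os

def tsv_writer(rows, tsv_file):
    # Minimal sketch: write (key, value) pairs as tab-separated lines.
    # Writing to `<tsv_file>.tmp` first and renaming on success is what makes
    # the `.tmp.copy` resume logic above plausible; this stand-in is an
    # assumption, not the project's actual implementation.
    tmp_file = tsv_file + '.tmp'
    with open(tmp_file, 'w') as fp:
        for key, value in rows:
            fp.write('{}\t{}\n'.format(key, value))
    os.rename(tmp_file, tsv_file)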
Example #2
# In addition to the imports above, this ONNX variant also needs:
import onnx
import onnxruntime


def test(args, test_dataset, model, tokenizer, predict_file):
    args.test_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    test_sampler = SequentialSampler(test_dataset)
    cache_file = predict_file

    test_dataloader = DataLoader(test_dataset, sampler=test_sampler,
            batch_size=args.test_batch_size, num_workers=args.num_workers)

    cls_token_id, sep_token_id, pad_token_id, mask_token_id, period_token_id = \
        tokenizer.convert_tokens_to_ids([tokenizer.cls_token, tokenizer.sep_token,
                                         tokenizer.pad_token, tokenizer.mask_token, '.'])
    model.eval()

    def gen_rows():
        time_meter = 0
        # restore existing results for long running inference tasks
        exist_key2pred = {}
        tmp_file = cache_file + '.tmp.copy'
        if op.isfile(tmp_file):
            with open(tmp_file, 'r') as fp:
                for line in fp:
                    parts = line.strip().split('\t')
                    if len(parts) == 2:
                        exist_key2pred[parts[0]] = parts[1]

        with torch.no_grad():
            for step, (img_keys, batch) in tqdm(enumerate(test_dataloader)):
                is_exist = True
                for k in img_keys:
                    if k not in exist_key2pred:
                        is_exist = False
                        break
                if is_exist:
                    for k in img_keys:
                        yield k, exist_key2pred[k]
                    continue
                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    'is_decode': True,
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2],
                    'img_feats': batch[3],
                    'masked_pos': batch[4],
                    'do_sample': False,
                    'bos_token_id': cls_token_id,
                    'pad_token_id': pad_token_id,
                    'eos_token_ids': [sep_token_id, pad_token_id],
                    'mask_token_id': mask_token_id,
                    # for adding od labels
                    'add_od_labels': args.add_od_labels,
                    'od_labels_start_posid': args.max_seq_a_length,

                    # hyperparameters of beam search
                    'max_length': args.max_gen_length,
                    'num_beams': args.num_beams,
                    "temperature": args.temperature,
                    "top_k": args.top_k,
                    "top_p": args.top_p,
                    "repetition_penalty": args.repetition_penalty,
                    "length_penalty": args.length_penalty,
                    "num_return_sequences": args.num_return_sequences,
                    "num_keep_best": args.num_keep_best,
                }
                if args.use_cbs:
                    inputs.update({'use_cbs': True,
                        'fsm': batch[5],
                        'num_constraints': batch[6],
                        'min_constraints_to_satisfy': args.min_constraints_to_satisfy,
                    })
                tic = time.time()

                # sanity check: run the original PyTorch model first
                print('INPUTS WITHOUT ONNX')
                outputs = model(**inputs)

                onnx_input_names = sorted(inputs.keys())

                def convert_onnx_input(onnx_in):
                    # wrap plain Python values (bools, ints, lists of ids) as
                    # tensors so they can be passed positionally to torch.onnx.export
                    if not isinstance(onnx_in, torch.Tensor):
                        return torch.tensor(onnx_in)
                    return onnx_in

                onnx_inputs = [convert_onnx_input(inputs[k]) for k in onnx_input_names]

                # convert to ONNX if needed
                if True:  # not os.path.exists('oscar.onnx'):
                    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
                    print('INPUTS_WITH_ONNX')
                    torch.onnx.export(
                        model,                         # model being run
                        args=tuple(onnx_inputs),       # model inputs (tuple for multiple inputs)
                        f='oscar.onnx',                # where to save the exported model
                        opset_version=10,              # ONNX opset version to export to
                        input_names=onnx_input_names,
                        output_names=['output'],       # the model's output names
                        dynamic_axes={                 # variable-length axes
                            'input_ids': symbolic_names,
                            'attention_mask': symbolic_names,
                            'token_type_ids': symbolic_names,
                            'img_feats': symbolic_names,
                            'masked_pos': symbolic_names,
                            'is_decode': symbolic_names,
                            'output': symbolic_names,
                        })
                    print("Model exported at ", 'oscar.onnx')

                    def remove_initializer_from_input(model_path):
                        # graph inputs that are also initializers confuse some
                        # onnxruntime versions, so drop them from graph.input
                        onnx_model = onnx.load(model_path)
                        if onnx_model.ir_version < 4:
                            print('Model with ir_version below 4 requires initializers '
                                  'to be included in the graph input')
                            return

                        graph_inputs = onnx_model.graph.input
                        name_to_input = {graph_input.name: graph_input
                                         for graph_input in graph_inputs}

                        for initializer in onnx_model.graph.initializer:
                            if initializer.name in name_to_input:
                                graph_inputs.remove(name_to_input[initializer.name])

                        onnx.save(onnx_model, model_path)
                        print("Input initializers removed")

                    remove_initializer_from_input('oscar.onnx')

                    # Load the ONNX model
                    onnx_model = onnx.load("oscar.onnx")

                    # Check that the IR is well formed
                    onnx.checker.check_model(onnx_model)

                    # Print a human-readable representation of the graph
                    print(onnx.helper.printable_graph(onnx_model.graph))
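                    # For reference: the exported graph's actual input/output tensor
                    # names can be read straight from the ONNX protobuf, which helps
                    # when assembling the onnxruntime feed dict further below.
                    print('graph inputs: ', [i.name for i in onnx_model.graph.input])
                    print('graph outputs:', [o.name for o in onnx_model.graph.output])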

                #outputs = model(**inputs)

                sess_options = onnxruntime.SessionOptions()

                session = onnxruntime.InferenceSession('oscar.onnx', sess_options)

                # Feed dict for onnxruntime; .numpy() assumes the batch tensors
                # live on CPU. Several inputs are still commented out while the
                # export is being debugged.
                onnx_inputs = {
                    'input_ids': batch[0].numpy(),
                    'attention_mask': batch[1].float().numpy(),
                    #'token_type_ids': batch[2].numpy(),
                    'img_feats': batch[3].numpy(),
                    #'masked_pos': batch[4].numpy()
                }
                print(onnx_inputs['attention_mask'])
                onnx_outputs = session.run(None, onnx_inputs)
                print(onnx_outputs[0])
                # The decoding code below expects (captions, logprobs), as returned
                # by model(**inputs); this assumes the exported graph exposes both.
                outputs = onnx_outputs


                time_meter += time.time() - tic
                all_caps = outputs[0]  # batch_size * num_keep_best * max_len
                # outputs[1] comes back as a numpy array from onnxruntime, so
                # convert it before applying torch.exp
                all_confs = torch.exp(torch.as_tensor(outputs[1]))

                for img_key, caps, confs in zip(img_keys, all_caps, all_confs):
                    res = []
                    for cap, conf in zip(caps, confs):
                        cap = tokenizer.decode(cap.tolist(), skip_special_tokens=True)
                        res.append({'caption': cap, 'conf': conf.item()})
                    if isinstance(img_key, torch.Tensor):
                        img_key = img_key.item()
                    yield img_key, json.dumps(res)
                break  # only process the first batch while debugging the ONNX path

        logger.info("Inference model computing time: {} seconds per batch".format(time_meter / (step+1)))

    tsv_writer(gen_rows(), cache_file)
    return predict_file
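Before trusting captions produced through the ONNX path in the second example, it helps to compare the exported graph against the PyTorch model on the same batch. The sketch below is a hypothetical helper, assuming an already-exported 'oscar.onnx', that the graph's declared inputs are a subset of the tensor entries in `inputs`, and that its first output corresponds to the model's first output; the helper name and tolerances are illustrative, not part of the project.

import numpy as np
import onnxruntime
import torch


def compare_onnx_to_pytorch(model, inputs, onnx_path='oscar.onnx', rtol=1e-3, atol=1e-4):
    # Hypothetical helper: run the same inputs through PyTorch and ONNX Runtime
    # and check that the first output agrees numerically.
    model.eval()
    with torch.no_grad():
        torch_out = model(**inputs)[0].cpu().numpy()

    session = onnxruntime.InferenceSession(onnx_path)
    # only feed tensors that match the graph's declared input names
    graph_input_names = {i.name for i in session.get_inputs()}
    feed = {k: v.cpu().numpy() for k, v in inputs.items()
            if isinstance(v, torch.Tensor) and k in graph_input_names}
    onnx_out = session.run(None, feed)[0]

    np.testing.assert_allclose(torch_out, onnx_out, rtol=rtol, atol=atol)
    print('PyTorch and ONNX Runtime outputs agree within tolerance')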