Code example #1
File: pytorch_SUT.py Project: xz10620/inference
import array
import json

import mlperf_loadgen as lg
import torch
from transformers import BertConfig, BertForQuestionAnswering

# Assumed project-local helper module providing the SQuAD query sample library.
from squad_QSL import get_squad_QSL


class BERT_PyTorch_SUT:
    def __init__(self):
        print("Loading BERT configs...")
        with open("bert_config.json") as f:
            config_json = json.load(f)

        config = BertConfig(
            attention_probs_dropout_prob=config_json["attention_probs_dropout_prob"],
            hidden_act=config_json["hidden_act"],
            hidden_dropout_prob=config_json["hidden_dropout_prob"],
            hidden_size=config_json["hidden_size"],
            initializer_range=config_json["initializer_range"],
            intermediate_size=config_json["intermediate_size"],
            max_position_embeddings=config_json["max_position_embeddings"],
            num_attention_heads=config_json["num_attention_heads"],
            num_hidden_layers=config_json["num_hidden_layers"],
            type_vocab_size=config_json["type_vocab_size"],
            vocab_size=config_json["vocab_size"])

        print("Loading PyTorch model...")
        self.model = BertForQuestionAnswering(config)
        self.model.eval()
        self.model.cuda()
        self.model.load_state_dict(torch.load("build/data/bert_tf_v1_1_large_fp32_384_v2/model.pytorch"))

        print("Constructing SUT...")
        self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries, self.process_latencies)
        print("Finished constructing SUT.")

        self.qsl = get_squad_QSL()

    def issue_queries(self, query_samples):
        with torch.no_grad():
            for i in range(len(query_samples)):
                eval_features = self.qsl.get_features(query_samples[i].index)
                # One sample per query, with a batch dimension of 1. The tuple
                # unpacking assumes a transformers version that returns tuples
                # (pre-4.x, or return_dict=False on newer releases).
                start_scores, end_scores = self.model(
                    input_ids=torch.LongTensor(eval_features.input_ids).unsqueeze(0).cuda(),
                    attention_mask=torch.LongTensor(eval_features.input_mask).unsqueeze(0).cuda(),
                    token_type_ids=torch.LongTensor(eval_features.segment_ids).unsqueeze(0).cuda())
                # Stack start/end logits into a single (seq_len, 2) array.
                output = torch.stack([start_scores, end_scores], dim=-1).squeeze(0).cpu().numpy()

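                # LoadGen reads the response as a raw (pointer, size) buffer, so
                # response_array must stay alive until QuerySamplesComplete returns.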
                response_array = array.array("B", output.tobytes())
                bi = response_array.buffer_info()
                response = lg.QuerySampleResponse(query_samples[i].id, bi[0], bi[1])
                lg.QuerySamplesComplete([response])

    def flush_queries(self):
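        # LoadGen flush callback; nothing to flush for this SUT.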
        pass

    def process_latencies(self, latencies_ns):
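        # LoadGen latency-report callback; unused by this benchmark.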
        pass

    def __del__(self):
        print("Finished destroying SUT.")
Code example #2
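# Snippet from a larger run script: model, eval_features, TOTAL_EXAMPLES,
# TORCH_DEVICE, and IO_DUMP_PATH are assumed to be defined earlier in it.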
print("Total examples available: {}".format(TOTAL_EXAMPLES))

## Processing examples one at a time (each "batch" here is a single feature):
#
# Note: int(os.getenv('CK_BATCH_COUNT')) would raise a TypeError when the
# variable is unset, so fall back to a default before converting.
BATCH_COUNT = int(os.getenv('CK_BATCH_COUNT', 0)) or TOTAL_EXAMPLES

encoded_accuracy_log = []
io_dump_structure = {}
with torch.no_grad():
    for i in range(BATCH_COUNT):
        selected_feature = eval_features[i]

        start_scores, end_scores = model(
            input_ids=torch.LongTensor(
                selected_feature.input_ids).unsqueeze(0).to(TORCH_DEVICE),
            attention_mask=torch.LongTensor(
                selected_feature.input_mask).unsqueeze(0).to(TORCH_DEVICE),
            token_type_ids=torch.LongTensor(
                selected_feature.segment_ids).unsqueeze(0).to(TORCH_DEVICE))
        # Stack start/end logits into a single (seq_len, 2) array.
        output = torch.stack([start_scores, end_scores],
                             dim=-1).squeeze(0).cpu().numpy()

        encoded_accuracy_log.append({
            'qsl_idx': i,
            'data': output.tobytes().hex()
        })
        print("Batch #{}/{} done".format(i + 1, BATCH_COUNT))

        if IO_DUMP_PATH:
            # Record the raw inputs for this sample. The snippet was cut off
            # here; the dict is completed from the feature fields used above.
            io_dump_structure[f'sample_{i+1}'] = {
                'input_ids': selected_feature.input_ids,
                'input_mask': selected_feature.input_mask,
                'segment_ids': selected_feature.segment_ids,
            }
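After the loop, the accumulated log can be written out; a minimal sketch, assuming the conventional mlperf_log_accuracy.json file name and that IO_DUMP_PATH points at a writable JSON file:

import json

# The accuracy log is a JSON list of {"qsl_idx": ..., "data": <hex>} records,
# matching the entries appended to encoded_accuracy_log above.
with open('mlperf_log_accuracy.json', 'w') as f:
    json.dump(encoded_accuracy_log, f)

# Hypothetical: persist the captured per-sample inputs for offline inspection.
if IO_DUMP_PATH:
    with open(IO_DUMP_PATH, 'w') as f:
        json.dump(io_dump_structure, f, indent=2)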