def encode_contexts(
    tokenizer:BertJapaneseTokenizer,
    contexts:Dict[str,str],
    max_seq_length:int,
    save_dir:str):
    """
    コンテキストのエンコードを行う。
    """
    os.makedirs(save_dir,exist_ok=True)

    for title,context in tqdm(contexts.items()):
        encoding = tokenizer.encode_plus(
            context,
            return_tensors="pt",
            add_special_tokens=True,
            padding="max_length",
            return_attention_mask=True,
            max_length=max_seq_length,
            truncation=True
        )

        input_ids=encoding["input_ids"].view(-1)

        title_hash=get_md5_hash(title)
        save_filepath=os.path.join(save_dir,title_hash+".pt")
        torch.save(input_ids,save_filepath)
Ejemplo n.º 2
0
def encode_examples(
    tokenizer:BertJapaneseTokenizer,
    examples:List[InputExample],
    contexts:Dict[str,str])->Dict[str,torch.Tensor]:
    """
    問題をエンコードする。
    """
    #最初の問題の選択肢の数を代表値として取得する。
    num_options=len(examples[0].options)

    input_ids=torch.empty(len(examples),num_options,BERT_MAX_SEQ_LENGTH,dtype=torch.long)
    attention_mask=torch.empty(len(examples),num_options,BERT_MAX_SEQ_LENGTH,dtype=torch.long)
    token_type_ids=torch.empty(len(examples),num_options,BERT_MAX_SEQ_LENGTH,dtype=torch.long)
    labels=torch.empty(len(examples),dtype=torch.long)

    for example_index,example in enumerate(tqdm(examples)):
        for option_index,option in enumerate(example.options):
            text_a=example.question+tokenizer.sep_token+option
            text_b=contexts[option]

            encoding = tokenizer.encode_plus(
                text_a,
                text_b,
                return_tensors="pt",
                add_special_tokens=True,
                pad_to_max_length=True,
                return_attention_mask=True,
                max_length=BERT_MAX_SEQ_LENGTH,
                truncation=True,
                truncation_strategy="only_second"   #コンテキストをtruncateする。
            )

            input_ids_tmp=encoding["input_ids"].view(-1)
            token_type_ids_tmp=encoding["token_type_ids"].view(-1)
            attention_mask_tmp=encoding["attention_mask"].view(-1)

            input_ids[example_index,option_index]=input_ids_tmp
            token_type_ids[example_index,option_index]=token_type_ids_tmp
            attention_mask[example_index,option_index]=attention_mask_tmp

            if example_index==0 and option_index<4:
                logger.info("option_index={}".format(option_index))
                logger.info("text_a: {}".format(text_a[:512]))
                logger.info("text_b: {}".format(text_b[:512]))
                logger.info("input_ids: {}".format(input_ids_tmp.detach().cpu().numpy()))
                logger.info("token_type_ids: {}".format(token_type_ids_tmp.detach().cpu().numpy()))
                logger.info("attention_mask: {}".format(attention_mask_tmp.detach().cpu().numpy()))

        labels[example_index]=example.label

    ret={
        "input_ids":input_ids,
        "token_type_ids":token_type_ids,
        "attention_mask":attention_mask,
        "labels":labels
    }

    return ret
def encode_captions(tokenizer: BertJapaneseTokenizer,
                    captions_dict: Dict[str, List[str]], max_seq_length: int,
                    save_dir: str):
    os.makedirs(save_dir, exist_ok=True)

    for filename, captions in tqdm(captions_dict.items()):
        input_ids = torch.empty(0, max_seq_length, dtype=torch.long)
        for caption in captions:
            encoding = tokenizer.encode_plus(caption,
                                             return_tensors="pt",
                                             add_special_tokens=True,
                                             padding="max_length",
                                             return_attention_mask=True,
                                             max_length=max_seq_length,
                                             truncation=True)
            input_ids_tmp = encoding["input_ids"]
            input_ids = torch.cat([input_ids, input_ids_tmp], dim=0)

        save_filename = os.path.splitext(filename)[0] + ".pt"
        save_filepath = os.path.join(save_dir, save_filename)
        torch.save(input_ids, save_filepath)