Example #1
def run_dataloader():
    """test dataloader"""
    parser = get_parser()

    # add model specific args
    parser = BertLabeling.add_model_specific_args(parser)

    # add all the available trainer options to argparse
    # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()
    args.workers = 0
    args.default_root_dir = "/mnt/data/mrc/train_logs/debug"

    model = BertLabeling(args)
    from tokenizers import BertWordPieceTokenizer
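    # build the wordpiece tokenizer directly from the BERT checkpoint's vocab.txt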
    tokenizer = BertWordPieceTokenizer.from_file(
        os.path.join(args.bert_config_dir, "vocab.txt"))

    loader = model.get_dataloader("dev", limit=1000)
    for d in loader:
        input_ids = d[0][0].tolist()
        match_labels = d[-1][0]
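        # match_labels is a (seq_len, seq_len) span matrix: the row/column
        # indices of its nonzero entries are the entity start/end positions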
        start_positions, end_positions = torch.where(match_labels > 0)
        start_positions = start_positions.tolist()
        end_positions = end_positions.tolist()
        if not start_positions:
            continue
        print("=" * 20)
        print(tokenizer.decode(input_ids, skip_special_tokens=False))
        for start, end in zip(start_positions, end_positions):
            print(tokenizer.decode(input_ids[start:end + 1]))
Example #2
 def from_jsonl_files(cls, input_files, vocab_file, **kwargs):
     if isinstance(input_files, str):
         input_files = [input_files]
     # resolve the JSON field names once, before the loop, so they are not
     # consumed from kwargs on the first record only
     sentence_a_key = kwargs.pop("sentence_a_key", "sentence_a")
     sentence_b_key = kwargs.pop("sentence_b_key", "sentence_b")
     label_key = kwargs.pop("label_key", "label")
     examples = []
     for f in input_files:
         with open(f, mode="rt", encoding="utf-8") as fin:
             for line in fin:
                 line = line.strip()
                 if not line:
                     continue
                 instance = json.loads(line)
                 examples.append(
                     ExampleForSpearman(
                         sentence_a=instance[sentence_a_key],
                         sentence_b=instance[sentence_b_key],
                         label=instance[label_key],
                     ))
     tokenizer = BertWordPieceTokenizer.from_file(
         vocab_file, lowercase=kwargs.pop("do_lower_case", True))
     return cls(examples=examples, tokenizer=tokenizer, **kwargs)
Example #3
def run_dataset():
    """test dataset"""
    import os
    from datasets.collate_functions import collate_to_max_length
    from torch.utils.data import DataLoader
    # zh datasets
    # bert_path = "/mnt/mrc/chinese_L-12_H-768_A-12"
    # json_path = "/mnt/mrc/zh_msra/mrc-ner.test"
    # # json_path = "/mnt/mrc/zh_onto4/mrc-ner.train"
    # is_chinese = True

    # en datasets
    bert_path = "data/bert-base-uncased"
    json_path = "data/ace2004/mrc-ner.train"
    # json_path = "/mnt/mrc/genia/mrc-ner.train"
    is_chinese = False

    vocab_file = os.path.join(bert_path, "vocab.txt")
    tokenizer = BertWordPieceTokenizer.from_file(vocab_file)
    dataset = MRCNERDataset(json_path=json_path,
                            tokenizer=tokenizer,
                            is_chinese=is_chinese)

    dataloader = DataLoader(dataset,
                            batch_size=32,
                            collate_fn=collate_to_max_length)

    for batch in dataloader:
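        # zip(*batch) walks the collated batch sample-by-sample, yielding one
        # tensor per field for each example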
        for tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx in zip(
                *batch):
            tokens = tokens.tolist()
            start_positions, end_positions = torch.where(match_labels > 0)
            start_positions = start_positions.tolist()
            end_positions = end_positions.tolist()
            if not start_positions:
                continue
            print("=" * 20)
            print(f"len: {len(tokens)}",
                  tokenizer.decode(tokens, skip_special_tokens=False))
            for start, end in zip(start_positions, end_positions):
                print(
                    str(sample_idx.item()),
                    str(label_idx.item()) + "\t" +
                    tokenizer.decode(tokens[start:end + 1]))

            print('tokens', tokens)
            # print(list(zip(tokenizer.convert_ids_to_tokens(tokens), start_labels.tolist())))
            print([tokenizer.id_to_token(id) for id in tokens])
            print('token_type_ids', token_type_ids)
            print('start_labels  ', start_labels)
            print('end_labels    ', end_labels)
            print('start_label_mask', start_label_mask)
            print('end_label_mask  ', end_label_mask)
            print('match_labels', match_labels)
            print('sample_idx', sample_idx)
            print('label_idx', label_idx)
            break

        break
Example #4
    def test_spearman_for_sentence_embedding(self):
        examples = self._read_examples()
        tokenizer = BertWordPieceTokenizer.from_file(VOCAB_PATH,
                                                     lowercase=True)
        callback = SpearmanForSentenceEmbedding(examples, tokenizer)

        def _compute_spearman(model):
            callback.model = model
            callback.on_epoch_end(epoch=0, logs=None)

        _compute_spearman(HardNegativeSimCSE.from_pretrained(BERT_PATH))
        _compute_spearman(SupervisedSimCSE.from_pretrained(BERT_PATH))
        _compute_spearman(UnsupervisedSimCSE.from_pretrained(BERT_PATH))
Example #5
    def test_basic_encode(self, bert_files):
        tokenizer = BertWordPieceTokenizer.from_file(bert_files["vocab"])

        # Encode with special tokens by default
        output = tokenizer.encode("My name is John", "pair")
        assert output.ids == [101, 2026, 2171, 2003, 2198, 102, 3940, 102]
        assert output.tokens == [
            "[CLS]",
            "my",
            "name",
            "is",
            "john",
            "[SEP]",
            "pair",
            "[SEP]",
        ]
        assert output.offsets == [
            (0, 0),
            (0, 2),
            (3, 7),
            (8, 10),
            (11, 15),
            (0, 0),
            (0, 4),
            (0, 0),
        ]
        assert output.type_ids == [0, 0, 0, 0, 0, 0, 1, 1]

        # Can encode without the special tokens
        output = tokenizer.encode("My name is John",
                                  "pair",
                                  add_special_tokens=False)
        assert output.ids == [2026, 2171, 2003, 2198, 3940]
        assert output.tokens == ["my", "name", "is", "john", "pair"]
        assert output.offsets == [(0, 2), (3, 7), (8, 10), (11, 15), (0, 4)]
        assert output.type_ids == [0, 0, 0, 0, 1]
Example #6
    def get_dataloader(self, prefix="train", limit: int = None) -> DataLoader:
        """get training dataloader"""
        """
        load_mmap_dataset
        """
        json_path = os.path.join(self.data_dir, f"mrc-ner.{prefix}")
        vocab_path = os.path.join(self.bert_dir, "vocab.txt")
        dataset = MRCNERDataset(
            json_path=json_path,
            tokenizer=BertWordPieceTokenizer.from_file(vocab_path),
            max_length=self.args.max_length,
            is_chinese=self.chinese,
            pad_to_maxlen=False)

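        # optionally cap the number of examples, e.g. limit=1000 for quick debugging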
        if limit is not None:
            dataset = TruncateDataset(dataset, limit)

        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.args.batch_size,
                                num_workers=self.args.workers,
                                shuffle=(prefix == "train"),
                                collate_fn=collate_to_max_length)

        return dataloader
Example #7
 def test_multiprocessing_with_parallelism(self, bert_files):
     tokenizer = BertWordPieceTokenizer.from_file(bert_files["vocab"])
     multiprocessing_with_parallelism(tokenizer, False)
     multiprocessing_with_parallelism(tokenizer, True)
Example #8
 def encodings(self, bert_files):
     tokenizer = BertWordPieceTokenizer.from_file(bert_files["vocab"])
     single_encoding = tokenizer.encode("I love HuggingFace")
     pair_encoding = tokenizer.encode("I love HuggingFace", "Do you?")
     return single_encoding, pair_encoding