Code example #1
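A train-mode record builder for what looks like MLM/prompt-style classification: one gold label is sampled, mapped to a label word through self.label2word, and that word's tokens are spliced into the record's mask span as the prediction target.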
    def _feature2records(self, idx, feature: Dict, mode: str) -> List[Dict]:
        record = dict(idx=idx, **feature)
        truncate_record(record=record,
                        max_len=self.max_len,
                        keys=["token_ids", "segment_ids", "tokens"])
        if mode == "train":
            labels = feature.get("labels")
            if labels is None:
                raise ValueError("no labels given in train mode!")
            # Sample one gold label, map it to a label word, and tokenize it.
            # Keep only the inner tokens (dropping the first and last, typically
            # special tokens); the remaining span must exactly fill the mask span.
            label = random.choice(labels)
            tgt_word = random.choice(self.label2word[label])
            tokened = self.tokenizer.do_tokenize(tgt_word)
            tgt_token_span = tokened["tokens"][1:-1]
            tgt_token_span_id = tokened["token_ids"][1:-1]
            mask_start, mask_end = record["mask_span"]
            assert len(tgt_token_span) == mask_end - mask_start

            # Splice the label-word tokens into the mask span; token_output is
            # 0 everywhere except the mask span, where it holds the target ids.
            tgt_tokens = copy.copy(record["tokens"])
            tgt_token_ids = copy.copy(record["token_ids"])
            token_output = [0] * len(tgt_token_ids)
            tgt_tokens[mask_start:mask_end] = tgt_token_span
            tgt_token_ids[mask_start:mask_end] = tgt_token_span_id
            token_output[mask_start:mask_end] = tgt_token_span_id
            record.update(target_tokens=tgt_tokens,
                          tgt_token_ids=tgt_token_ids,
                          token_output=token_output)
        return [record]
Code example #2
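A plain sentence-classification variant: after truncation, train mode converts the record's labels into a classification target with get_classify_output and stores it as classify_output.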
    def _feature2records(self, idx, feature: Dict, mode: str) -> List[Dict]:
        record = dict(idx=idx, **feature)
        truncate_record(record=record,
                        max_len=self.max_len,
                        keys=["token_ids", "segment_ids", "tokens"])
        if mode == "train":
            labels = feature.get("labels")
            if labels is None:
                raise ValueError("no labels given in train mode!")
            classify_output = get_classify_output(labels, self.label2id,
                                                  self.sparse_label)
            record.update(classify_output=classify_output)
        return [record]
Code example #3
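A generation-oriented variant: in gen mode the score is reset and token_ids, tokens, and segment_ids are cut back to the first origin_token_len entries before the record is truncated.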
    def _feature2records(self,
                         idx,
                         feature: Dict,
                         mode: str,
                         only_copy=False) -> List[Dict]:
        record = dict(idx=idx, **feature)
        if mode == "gen":
            # In generation mode, reset the score and keep only the first
            # origin_token_len entries of each sequence field.
            record.update(score=0.)
            origin_token_len = record["origin_token_len"]
            record["token_ids"] = record["token_ids"][:origin_token_len]
            record["tokens"] = record["tokens"][:origin_token_len]
            record["segment_ids"] = record["segment_ids"][:origin_token_len]
        record.update(token_len=len(record["tokens"]))
        truncate_record(record=record,
                        max_len=self.max_len,
                        keys=["token_ids", "segment_ids", "tokens"])
        return [record]
Code example #4
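A masked-language-model variant: masked positions are taken from the feature or sampled from the non-special tokens, and each position is corrupted with the usual BERT recipe (MASK about 80% of the time, a random vocabulary token about 10%, otherwise left unchanged) while the true token ids are kept in token_output.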
    def _feature2records(self, idx, feature: Dict, mode: str) -> List[Dict]:
        record = dict(idx=idx, **feature)
        if mode == "train":
            # Use precomputed masked positions if present; otherwise sample
            # mask_percent of the non-special tokens at random.
            masked_tokens = feature.get("masked_tokens")
            if not masked_tokens:
                token_infos = [
                    e for e in enumerate(feature["tokens"])
                    if e[1] not in self.tokenizer.special_tokens
                ]
                masked_tokens = random.sample(
                    token_infos, int(len(token_infos) * self.mask_percent))
            token_output = [0] * len(feature["tokens"])
            tokens = copy.copy(feature["tokens"])
            token_ids = copy.copy(feature["token_ids"])

            # BERT-style corruption: record the true token id as the target,
            # then, if the position is not already MASK, replace it with MASK
            # (~80%), a random vocabulary token (~10%), or keep it (~10%).
            for token_idx, token in masked_tokens:
                token_id = self.tokenizer.token2id(token)
                token_output[token_idx] = token_id
                if tokens[token_idx] != MASK:
                    r = random.random()
                    if r <= 0.8:
                        t = MASK
                    elif r <= 0.9:
                        t = random.choice(self.tokenizer.vocabs)
                    else:
                        t = token
                    tokens[token_idx] = t
                    token_ids[token_idx] = self.tokenizer.token2id(t)

            record.update(token_output=token_output,
                          masked_tokens=masked_tokens,
                          tokens=tokens,
                          token_ids=token_ids)
        truncate_record(
            record=record,
            max_len=self.max_len,
            keys=["token_ids", "segment_ids", "tokens", "token_output"])
        return [record]
Code example #5
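A span-classification variant: each labeled text span is mapped from character to token offsets via char2token, and a 1 is written into a (label_num, token_len, token_len) matrix at the span's (start, end) cell; that matrix becomes classify_output.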
    def _feature2records(self, idx, feature: Dict, mode: str) -> List[dict]:
        record = dict(idx=idx, **feature)
        if mode == "train":
            text_spans = feature.get("text_spans")
            if text_spans is None:
                raise ValueError("no text_spans key found in train mode!")
            text_spans: TextSpans = [TextSpan(**e) for e in text_spans]
            char2token = record["char2token"]
            token_len = len(record["tokens"])
            # One (token_len, token_len) matrix per label: cell (start, end)
            # is set to 1 when a span with that label covers tokens
            # start through end (end inclusive).
            classify_output = np.zeros(shape=(self.label_num, token_len,
                                              token_len))
            for text_span in text_spans:
                label_id = self.label2id[text_span.label]
                token_start = char2token[text_span.span[0]]
                token_end = char2token[text_span.span[1] - 1]
                classify_output[label_id][token_start][token_end] = 1

            record.update(classify_output=classify_output)
        truncate_record(record=record,
                        max_len=self.max_len,
                        keys=["token_ids", "segment_ids", "tokens"])

        return [record]
Code example #6
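A sequence-labeling variant: text spans are turned into a per-token label sequence (an overlap-aware version is used when multi_label is set) and then into classify_labels, which is truncated together with the token fields.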
    def _feature2records(self, idx, feature: Dict, mode: str) -> List[Dict]:
        record = dict(idx=idx, **feature)
        if mode == "train":
            text_spans = feature.get("text_spans")
            if text_spans is None:
                raise ValueError("no text_spans key found in train mode!")
            text_spans = [TextSpan(**e) for e in text_spans]
            # Build a per-token label sequence (overlap-aware when multi_label
            # is set) and convert it into the classifier's label input.
            token_label_func = (get_overlap_token_label_sequence
                                if self.multi_label else get_token_label_sequence)
            target_token_label_sequence = token_label_func(
                feature["tokens"], text_spans, feature["char2token"],
                self.seq_label_strategy)
            classify_labels = token_label2classify_label_input(
                target_token_label_sequence, self.multi_label, self.label2id)
            record.update(
                target_token_label_sequence=target_token_label_sequence,
                classify_labels=classify_labels)

        truncate_record(
            record=record,
            max_len=self.max_len,
            keys=["token_ids", "segment_ids", "tokens", "classify_labels"])

        return [record]