Example #1
def alignment(dfhyp, dfref, align_type="SID", len_min=1, len_max=256):
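    """Label each hypothesis token as correct (C), substituted (S),
    inserted (I), or deleted (D) against its reference, and return a
    DataFrame with the per-token labels in an "error_label" column.
    """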
    id2ref = {}

    for row in dfref.itertuples():
        id2ref[row.utt_id] = str2ints(row.token_id)
    
    outs = []

    for row in tqdm(dfhyp.itertuples()):
        hyp_token_id = str2ints(row.token_id)
        ref_token_id = id2ref[row.utt_id]

        if len(hyp_token_id) < len_min or len(hyp_token_id) > len_max:
            continue

        _, wer_dict = compute_wer(hyp_token_id, ref_token_id)
        error_list = wer_dict["error_list"]

        align_list = []
        del_flag = False

        if align_type == "SI":
            align_list = [e for e in error_list if e != "D"]
        elif align_type == "SID":
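            # a deletion has no hypothesis token of its own, so move its
            # label onto a neighbouring correct token to keep align_list
            # the same length as the hypothesis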
            for e in error_list:
                if e == "D":
                    # reassign this `D` to the correct token on the left
                    if len(align_list) > 0 and align_list[-1] == "C":
                        align_list[-1] = "D"
                    else: # to right
                        del_flag = True
                else:
                    if del_flag and e == "C":
                        align_list.append("D")
                    else:
                        align_list.append(e)
                    del_flag = False

        assert len(hyp_token_id) == len(align_list)

        outs.append(
            (row.utt_id, row.score_asr, row.token_id, row.text, row.reftext, " ".join(align_list))
        )
    
    df = pd.DataFrame(
        outs, columns=["utt_id", "score_asr", "token_id", "text", "reftext", "error_label"]
    )

    return df
Example #2
    def __getitem__(self, idx):
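        # One item: phone tokens `p`, word-token input `y_in`, and a training
        # target `label` whose form depends on self.lm_type (masked word
        # tokens, shifted word tokens, or the phone sequence for "pctc").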
        utt_id = self.data.loc[idx]["utt_id"]
        token_id = str2ints(self.data.loc[idx]["token_id"])
        phone_token_id = str2ints(self.data.loc[idx]["phone_token_id"])

        if self.add_sos_eos:
            token_id = [eos_id] + token_id + [eos_id]

        y = torch.tensor(token_id, dtype=torch.long)
        p = torch.tensor(phone_token_id, dtype=torch.long)

        if self.textaug is not None:
            p = self.textaug(p)

        if self.phase == "train":
            if self.lm_type in ["pelectra", "pbert"]:
                if self.mask_insert_poisson_lam > 0:
                    y_in, label = create_masked_lm_label_insert(
                        y,
                        mask_id=self.mask_id,
                        num_to_mask=self.num_to_mask,
                        mask_proportion=self.mask_proportion,
                        random_num_to_mask=self.random_num_to_mask,
                        insert_poisson_lam=self.mask_insert_poisson_lam,
                        pad_id=self.pad_id,
                    )
                else:
                    y_in, label = create_masked_lm_label(
                        y,
                        mask_id=self.mask_id,
                        num_to_mask=self.num_to_mask,
                        mask_proportion=self.mask_proportion,
                        random_num_to_mask=self.random_num_to_mask,
                    )
            elif self.lm_type == "ptransformer":
                y_in = y[:-1]
                label = y[1:]
            elif self.lm_type == "pctc":
                y_in = y
                label = p
        else:
            y_in = y
            label = None

        plen = p.size(0)
        ylen = y_in.size(0)

        return utt_id, p, plen, y_in, ylen, label
Example #3
    def __getitem__(self, idx):
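        # One ASR example: acoustic features `x` (optionally SpecAugment-ed
        # and frame-stacked), word targets `y`, optional phone targets `p`,
        # and an optional soft label for knowledge distillation.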
        utt_id = self.data.loc[idx]["utt_id"]
        text = self.data.loc[idx]["text"]

        feat_path = self.data.loc[idx]["feat_path"]
        x = np.load(feat_path)[:, :self.feat_dim]

        if self.specaug is not None:
            x = self.specaug(x)

        x = torch.tensor(x, dtype=torch.float)  # float32

        if self.num_framestacks > 1:
            x = self._stack_frames(x, self.num_framestacks)

        xlen = x.size(0)  # `xlen` is based on length after frame stacking

        token_id = str2ints(self.data.loc[idx]["token_id"])
        y = torch.tensor(token_id, dtype=torch.long)  # int64
        ylen = y.size(0)

        if "phone_token_id" in self.data:
            phone_token_id = str2ints(self.data.loc[idx]["phone_token_id"])
            p = torch.tensor(phone_token_id, dtype=torch.long)
            plen = p.size(0)
        else:
            p, plen = None, None

        # for knowledge distillation
        if self.data_kd is not None:
            utt_id_nosp = get_utt_id_nosp(utt_id)

            if utt_id_nosp in self.data_kd:
                data_kd_utt = self.data_kd[utt_id_nosp]
            else:
                data_kd_utt = []
                logging.warning(f"soft label: {utt_id_nosp} not found")

            soft_label = create_soft_label(data_kd_utt,
                                           ylen,
                                           self.vocab_size,
                                           self.lsm_prob,
                                           add_eos=self.add_eos)
        else:
            soft_label = None

        return utt_id, x, xlen, y, ylen, text, p, plen, soft_label
Example #4
def accuracy(labels, dfref, vocab=None):
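    # `labels` maps utt_id -> a list of per-position candidates, each a list
    # of (token_id, prob) pairs; count top-1 (acc1) and any-of-top-k (acck)
    # matches against the reference tokens in `dfref`.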
    id2ref = {}
    cnt, cntacc1, cntacck = 0, 0, 0

    for row in dfref.itertuples():
        id2ref[row.utt_id] = str2ints(row.token_id)
        # assert row.utt_id in labels.keys()

    for utt_id, label in tqdm(labels.items()):
        ref_token_id = id2ref[utt_id]
        cnt += len(label)

        if vocab is not None:
            print(f"# utt_id: {utt_id}")

            ref_text = vocab.ids2tokens(ref_token_id)
            for i, vps in enumerate(label):
                # mask i-th token
                ref_text_masked = ref_text.copy()
                ref_text_masked[i] = "<mask>"
                print(" ".join(ref_text_masked))

                for v, p in vps:
                    print(f"{vocab.id2token(v)}: {p:.2f}", end=" ")
                print()

        for i, vps in enumerate(label):
            v1, _ = vps[0]
            cntacc1 += int(v1 == ref_token_id[i])

            for v, _ in vps:
                cntacck += int(v == ref_token_id[i])

    acc1 = (cntacc1 / cnt) * 100
    acck = (cntacck / cnt) * 100

    return acc1, acck, cnt
Example #5
    def __getitem__(self, idx):
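        # One LM example: token input `y_in`, a target `label` (masked or
        # shifted depending on self.lm_type), and an optional per-token
        # `error_label` that is 1.0 where the token is labelled as an error
        # (not "C").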
        utt_id = self.data.loc[idx]["utt_id"]
        token_id = str2ints(self.data.loc[idx]["token_id"])
        if self.add_sos_eos:
            token_id = [eos_id] + token_id + [eos_id]

        y = torch.tensor(token_id, dtype=torch.long)
        
        if "error_label" in self.data:
            error_label = self.data.loc[idx]["error_label"].split()
            error_label = torch.tensor([e != "C" for e in error_label], dtype=torch.float)
        else:
            error_label = None

        if self.phase == "train":
            if self.lm_type in ["bert", "electra"]:
                y_in, label = create_masked_lm_label(
                    y,
                    mask_id=self.mask_id,
                    num_to_mask=self.num_to_mask,
                    mask_proportion=self.mask_proportion,
                    random_num_to_mask=self.random_num_to_mask,
                )
            elif self.lm_type in ["transformer", "rnn"]:
                assert len(y) > 1
                y_in = y[:-1]
                label = y[1:]
            elif self.lm_type in ["electra-disc", "pelectra-disc"]:
                y_in = y
                label = None
        else:
            y_in = y
            label = None

        ylen = y_in.size(0)

        return utt_id, y_in, ylen, label, error_label
Example #6
def score_lm(df, model, device, mask_id=None, vocab=None, num_samples=-1):
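    # Score every hypothesis in `df` with the LM in batches of BATCH_SIZE and
    # add the scores as a "score_lm" column; `vocab` enables per-utterance
    # debug logging and `num_samples` > 0 makes the function return early
    # (with None) after that many utterances.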
    ys, ylens, score_lms_all = [], [], []

    utt_id = None
    cnt_utts = 0

    for i, row in enumerate(df.itertuples()):
        if row.utt_id != utt_id:
            cnt_utts += 1
            utt_id = row.utt_id
        if num_samples > 0 and (cnt_utts + 1) > num_samples:
            return

        y = str2ints(row.token_id)
        ys.append(torch.tensor(y))
        ylens.append(len(y))

        if len(ys) < BATCH_SIZE and i != (len(df) - 1):
            continue

        ys_pad = pad_sequence(ys, batch_first=True).to(device)
        ylens = torch.tensor(ylens).to(device)

        score_lms = model.score(ys_pad, ylens, batch_size=BATCH_SIZE)

        if vocab is not None:  # debug mode
            for y, score_lm in zip(ys, score_lms):
                logging.debug(
                    f"{' '.join(vocab.ids2words(tensor2np(y)))}: {score_lm:.3f}"
                )

        score_lms_all.extend(score_lms)
        ys, ylens = [], []

    df["score_lm"] = score_lms_all
    return df
Example #7
def main(args):
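    # Concatenate consecutive utterances of args.tsv_path into longer
    # sequences: paired phone/word sequences for "P2W", fixed-length token
    # blocks for "LM", or one token stream re-chunked at args.rep offsets
    # for "LMall"; the result is written back out as a TSV.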
    data = pd.read_table(args.tsv_path)
    # data = data.dropna()

    print(f"Read tsv ({len(data)} samples)")

    # shuffle
    if args.shuffle:
        data = data.sample(frac=1, random_state=0).reset_index(drop=True)
        print("Data shuffled")
    else:
        print("Data NOT shuffled")

    # concatenate sentences (the concatenated length is NOT always exactly args.max_len)
    if args.task == "P2W":
        utt_id_start, utt_id_end = "", ""
        phone_token_id_concat = [args.phone_eos_id]
        phone_text_concat = "<eos>"
        token_id_concat = [args.eos_id]
        text_concat = "<eos>"

        outs = []  # utt_id, phone_token_id, phone_text, token_id, text

        for row in tqdm(data.itertuples()):
            utt_id = row.utt_id
            phone_token_id = str2ints(row.phone_token_id) + [args.phone_eos_id]
            token_id = str2ints(row.token_id) + [args.eos_id]
            phone_text = f" {row.phone_text} <eos>"
            text = f" {row.text} <eos>"

            if len(phone_token_id) + 1 > args.max_src_len:
                continue
            if len(token_id) + 1 > args.max_len:
                continue

            if utt_id_start == "":
                utt_id_start = row.utt_id
            utt_id_end = row.utt_id

            # NOTE: flush the current concatenation (if long enough) once
            # adding this sentence would exceed the length limits
            if (len(phone_token_id_concat) + len(phone_token_id) >
                    args.max_src_len
                    or len(token_id_concat) + len(token_id) > args.max_len):
                if (len(phone_token_id_concat) >= args.min_src_len
                        and len(token_id_concat) >= args.min_len):
                    outs.append((
                        f"{utt_id_start}-{utt_id_end}",
                        ints2str(phone_token_id_concat),
                        phone_text_concat,
                        ints2str(token_id_concat),
                        text_concat,
                    ))

                utt_id_start, utt_id_end = "", ""
                phone_token_id_concat = [args.phone_eos_id]
                phone_text_concat = "<eos>"
                token_id_concat = [args.eos_id]
                text_concat = "<eos>"

            else:
                phone_token_id_concat.extend(phone_token_id)
                token_id_concat.extend(token_id)
                phone_text_concat += phone_text
                text_concat += text

        if utt_id_start != "":
            if (len(phone_token_id_concat) >= args.min_src_len
                    and len(token_id_concat) >= args.min_len):
                outs.append((
                    f"{utt_id_start}-{utt_id_end}",
                    ints2str(phone_token_id_concat),
                    phone_text_concat,
                    ints2str(token_id_concat),
                    text_concat,
                ))
        data = pd.DataFrame(
            outs,
            columns=[
                "utt_id",
                "phone_token_id",
                "phone_text",
                "token_id",
                "text",
            ],
        )

    # concatenate tokens (the concatenated length is always exactly args.max_len)
    # NOTE: sentences longer than max_len are skipped
    elif args.task == "LM":
        utt_id_start, utt_id_end = "", ""
        token_id_concat = [args.eos_id]

        outs = []  # utt_id, token_id, text

        for row in tqdm(data.itertuples()):
            utt_id = row.utt_id
            token_id = str2ints(row.token_id) + [args.eos_id]

            if utt_id_start == "":
                utt_id_start = row.utt_id
            utt_id_end = row.utt_id

            if len(token_id) > args.max_len:
                continue

            if len(token_id_concat) + len(token_id) < args.max_len:
                token_id_concat += token_id
            else:
                remainder = args.max_len - len(token_id_concat)
                token_id_concat += token_id[:remainder]
                assert len(token_id_concat) == args.max_len
                outs.append((f"{utt_id_start}-{utt_id_end}",
                             ints2str(token_id_concat)))
                utt_id_start, utt_id_end = "", ""
                token_id_concat = token_id[remainder:]

        # NOTE: text cannot be provided
        data = pd.DataFrame(
            outs,
            columns=["utt_id", "token_id"],
        )

    elif args.task == "LMall":
        if args.eos_id >= 0:
            token_id_all = [args.eos_id]
        else:
            token_id_all = []

        # NOTE: First, concat all tokens
        for row in data.itertuples():
            token_id_all.extend(str2ints(row.token_id))
            if args.eos_id >= 0:
                token_id_all.append(args.eos_id)

        # save memory
        del data
        gc.collect()

        start = 0
        utt_id_prefix = os.path.splitext(os.path.basename(args.tsv_path))[0]
        outs = []  # utt_id, token_id

        for i in range(args.rep):
            start = 0 + i * (args.max_len // args.rep)
            while start + args.max_len < len(token_id_all):
                end = start + args.max_len
                outs.append((f"{utt_id_prefix}-{i}-{start}",
                             ints2str(token_id_all[start:end])))
                start = end

        # NOTE: text cannot be provided
        data = pd.DataFrame(
            outs,
            columns=["utt_id", "token_id"],
        )

    if args.out is None:
        data.to_csv(f"{os.path.splitext(args.tsv_path)[0]}_concat.tsv",
                    sep="\t",
                    index=False)
    else:
        data.to_csv(args.out, sep="\t", index=False)
Example #8
def make_lm_label(
    df,
    model,
    device,
    save_path,
    topk=8,
    temp=3.0,
    add_sos_eos=False,
    eos_id=2,
    max_seq_len=256,
):
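    # For each row of `df`, run the LM in batches of BATCH_SIZE and keep the
    # top-`topk` predictions (softmax with temperature `temp`) for every
    # position in [start_pos, end_pos); the per-utterance label lists are
    # pickled to `save_path`.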
    labels = {}

    utt_ids, ys, ylens, start_poss, end_poss = [], [], [], [], []  # batch

    for i, row in enumerate(df.itertuples()):
        ids = str2ints(row.token_id)

        if add_sos_eos:
            if len(ids) <= max_seq_len - 2:
                ids = [eos_id] + ids + [eos_id]
                start_pos = row.start_pos + 1
                end_pos = row.end_pos + 1
            else:
                # reduce context
                ids = [eos_id] + ids[1:-1] + [eos_id]
                start_pos = row.start_pos
                end_pos = row.end_pos
        else:
            start_pos = row.start_pos
            end_pos = row.end_pos

        y = torch.tensor(ids)
        ylen = len(ids)

        utt_ids.append(row.utt_id)
        ys.append(y)
        ylens.append(ylen)
        start_poss.append(start_pos)
        end_poss.append(end_pos)

        # batchify
        if (i + 1) % BATCH_SIZE == 0 or (i + 1) == len(df):
            bs = len(ys)
            ys_pad = pad_sequence(ys, batch_first=True).to(device)
            ylens = torch.tensor(ylens).to(device)

            with torch.no_grad():
                logits = model(ys_pad, ylens)

            for b in range(bs):
                utt_id = utt_ids[b]
                start_pos = start_poss[b]
                end_pos = end_poss[b]
                y = ys[b]

                for pos in range(start_pos, end_pos):
                    if pos == 0:
                        v_topk = np.array([y[pos]])
                        p_topk = np.array([1.0])
                        logging.warning(f"hard label is used: {v_topk}")
                    else:
                        o_sorted, v_sorted = torch.sort(logits[b, pos - 1],
                                                        descending=True)
                        o_topk = o_sorted[:topk]
                        v_topk = tensor2np(v_sorted[:topk])
                        p_topk = tensor2np(
                            torch.softmax((o_topk / temp), dim=0))

                    label = []
                    for v, p in zip(v_topk, p_topk):
                        # NOTE: do not add <eos> to soft labels
                        if add_sos_eos and v == eos_id:
                            continue
                        label.append((v, p))

                    if utt_id not in labels:  # first token in utterance
                        labels[utt_id] = [label]
                    else:
                        labels[utt_id].append(label)

            utt_ids, ys, ylens, start_poss, end_poss = [], [], [], [], []

        if (i + 1) % LOG_STEP == 0:
            logging.info(f"{(i+1):>4} / {len(df):>4}")
        if (i + 1) == SAVE_STEP:
            save_tmp_path = save_path + ".tmp"
            with open(save_tmp_path, "wb") as f:
                pickle.dump(labels, f)
            logging.info(f"pickle is saved to {save_tmp_path}")

    with open(save_path, "wb") as f:
        pickle.dump(labels, f)
    logging.info(f"pickle is saved to {save_path}")
Example #9
def get_ylen(token_id):
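    # Number of tokens encoded in a token-id string.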
    return len(str2ints(token_id))