Code example #1
0
File: levenshtein_utils.py  Project: veralily/fairseq
    def _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx):
        """Build mask-insertion training targets on the CPU.

        Strips padding from both batches, asks ``libnat.suggested_ed2_path``
        for per-sentence edit paths, and turns them into:

        - ``masked_tgt_masks``: bool tensor marking target positions that
          correspond to inserted tokens (zero-padded to ``out_tokens``' width),
        - ``masked_tgt_tokens``: ``out_tokens`` with those positions replaced
          by ``unk_idx``,
        - ``mask_ins_targets``: per-slot insertion counts, zero-padded to
          ``in_tokens.size(1) - 1``.
        """
        src_len = in_tokens.size(1)
        tgt_len = out_tokens.size(1)

        src_lists = [[tok for tok in row if tok != padding_idx]
                     for row in in_tokens.tolist()]
        tgt_lists = [[tok for tok in row if tok != padding_idx]
                     for row in out_tokens.tolist()]

        full_labels = libnat.suggested_ed2_path(src_lists, tgt_lists,
                                                padding_idx)
        # Per source slot: number of suggested insertions (0 for padding slots).
        ins_counts = [
            [0 if group[0] == padding_idx else len(group) for group in path[:-1]]
            for path in full_labels
        ]

        # generate labels
        masked_tgt_masks = []
        for counts in ins_counts:
            row = []
            for n_ins in counts[1:-1]:  # HACK 1:-1
                row.append(0)
                row.extend([1] * n_ins)
            row.extend([0] * (tgt_len - len(row)))
            masked_tgt_masks.append(row)
        mask_ins_targets = [
            counts[1:-1] + [0] * (src_len - 1 - len(counts[1:-1]))
            for counts in ins_counts
        ]

        # transform to tensor
        masked_tgt_masks = torch.tensor(masked_tgt_masks,
                                        device=out_tokens.device).bool()
        mask_ins_targets = torch.tensor(mask_ins_targets,
                                        device=in_tokens.device)
        masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
        return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
Code example #2
0
def _get_del_targets(in_tokens, out_tokens, padding_idx):
    """Compute word-deletion labels aligned to ``out_tokens``.

    The last entry of each ``libnat.suggested_ed2_path`` result is taken as
    the per-token deletion label row and right-padded with zeros up to the
    padded target width.  Returns a tensor on ``out_tokens.device``.
    """
    libnat = load_libnat()

    tgt_len = out_tokens.size(1)

    with torch.cuda.device_of(in_tokens):
        src_lists = [[tok for tok in row if tok != padding_idx]
                     for row in in_tokens.tolist()]
        tgt_lists = [[tok for tok in row if tok != padding_idx]
                     for row in out_tokens.tolist()]

    full_labels = libnat.suggested_ed2_path(src_lists, tgt_lists, padding_idx)

    padded_rows = []
    for path in full_labels:
        labels = path[-1]
        padded_rows.append(labels + [0] * (tgt_len - len(labels)))

    # transform to tensor
    return torch.tensor(padded_rows, device=out_tokens.device)
Code example #3
0
def _get_del_targets(in_tokens, out_tokens, padding_idx):
    """Compute word-deletion labels aligned to ``out_tokens``.

    Strips padding, queries ``libnat.suggested_ed2_path`` for edit paths,
    takes each path's last entry as the per-token deletion label row, and
    right-pads it with zeros to the padded target width.

    Raises:
        ImportError: if the compiled ``libnat`` extension is missing.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys

        sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
        raise e
    out_seq_len = out_tokens.size(1)

    in_tokens_list = [
        [t for t in s if t != padding_idx] for s in in_tokens.tolist()
    ]
    out_tokens_list = [
        [t for t in s if t != padding_idx] for s in out_tokens.tolist()
    ]

    full_labels = libnat.suggested_ed2_path(
        in_tokens_list, out_tokens_list, padding_idx
    )
    word_del_targets = [b[-1] for b in full_labels]
    word_del_targets = [
        labels + [0 for _ in range(out_seq_len - len(labels))]
        for labels in word_del_targets
    ]

    # transform to tensor
    # Place on out_tokens' device, matching the other _get_del_targets /
    # _get_del_ins_targets variants; previously this silently produced a
    # CPU tensor even for GPU inputs.
    word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
    return word_del_targets
Code example #4
0
def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
    """Build deletion and insertion targets in a single libnat pass.

    Returns:
        word_del_targets: per-token deletion labels, zero-padded to the
            target width, on ``out_tokens.device``.
        mask_ins_targets: per-slot insertion counts, zero-padded to
            ``in_tokens.size(1) - 1``, on ``in_tokens.device``.
    """
    libnat = load_libnat()

    src_len = in_tokens.size(1)
    tgt_len = out_tokens.size(1)

    with torch.cuda.device_of(in_tokens):
        src_lists = [[tok for tok in row if tok != padding_idx]
                     for row in in_tokens.tolist()]
        tgt_lists = [[tok for tok in row if tok != padding_idx]
                     for row in out_tokens.tolist()]

    full_labels = libnat.suggested_ed2_path(src_lists, tgt_lists, padding_idx)

    # Deletion labels: the last entry of each path, zero-padded.
    word_del_targets = [
        path[-1] + [0] * (tgt_len - len(path[-1])) for path in full_labels
    ]

    # Insertion counts per source slot (0 for padding slots); keep the
    # interior slots and zero-pad to src_len - 1.
    mask_ins_targets = []
    for path in full_labels:
        counts = [0 if group[0] == padding_idx else len(group)
                  for group in path[:-1]]
        inner = counts[1:-1]
        mask_ins_targets.append(inner + [0] * (src_len - 1 - len(inner)))

    # transform to tensor
    mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
    word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
    return word_del_targets, mask_ins_targets
Code example #5
0
def _get_ins_targets(in_tokens,
                     out_tokens,
                     padding_idx,
                     unk_idx,
                     vocab_size,
                     tau=None):
    """Build dense insertion targets over the vocabulary.

    Scatters ``neg_scorer(k, len(label), tau)`` scores into a zero float
    tensor at linearized (batch, slot, word) indices and returns it with
    shape ``(B, T - 1, V)`` where ``B, T = in_tokens.size()`` and
    ``V = vocab_size``.

    Note: ``unk_idx`` is accepted for signature parity with the sibling
    variants but is not used here.
    """
    B = in_tokens.size(0)
    T = in_tokens.size(1)
    V = vocab_size

    with torch.cuda.device_of(in_tokens):
        in_tokens_list = [[t for t in s if t != padding_idx]
                          for s in in_tokens.tolist()]
        out_tokens_list = [[t for t in s if t != padding_idx]
                           for s in out_tokens.tolist()]

    full_labels = libnat.suggested_ed2_path(in_tokens_list, out_tokens_list,
                                            padding_idx)
    insert_labels = [a[:-1] for a in full_labels]

    # numericalize1
    insert_label_tensors = in_tokens.new_zeros(B * (T - 1) * V).float()
    insert_index, insert_labels = zip(*[
        (w + (j + i * (T - 1)) * V, neg_scorer(k, len(label), tau))
        for i, labels in enumerate(insert_labels)
        for j, label in enumerate(labels[1:-1]) for k, w in enumerate(label)
    ])  # HACK 1:-1
    insert_index, insert_labels = [
        torch.tensor(list(a), device=in_tokens.device)
        for a in [insert_index, insert_labels]
    ]
    # Cast scores to the float target dtype before scatter_: torch.tensor()
    # infers int64 when neg_scorer returns integers, and scatter_ requires
    # src and self dtypes to match (same fix as the dtype-safe variant).
    insert_label_tensors.scatter_(0, insert_index.long(),
                                  insert_labels.type_as(insert_label_tensors))
    insert_label_tensors = insert_label_tensors.view(B, T - 1, V)

    return insert_label_tensors
Code example #6
0
def _get_ins_targets(in_tokens,
                     out_tokens,
                     padding_idx,
                     unk_idx,
                     vocab_size,
                     tau=None):
    """Dense insertion targets over the vocabulary (dtype-safe variant).

    Same contract as the original implementation, but the scores are cast
    to the target tensor's dtype before ``scatter_`` so a non-``None``
    ``tau`` does not trigger a dtype mismatch.  Returns a float tensor of
    shape ``(B, T - 1, V)``.
    """

    try:
        from fairseq import libnat
    except ImportError as e:
        import sys

        sys.stderr.write(
            "ERROR: missing libnat. run `pip install --editable .`\n")
        raise e

    n_batch = in_tokens.size(0)
    n_src = in_tokens.size(1)
    n_vocab = vocab_size

    with torch.cuda.device_of(in_tokens):
        src_lists = [[tok for tok in row if tok != padding_idx]
                     for row in in_tokens.tolist()]
        tgt_lists = [[tok for tok in row if tok != padding_idx]
                     for row in out_tokens.tolist()]

    full_labels = libnat.suggested_ed2_path(src_lists, tgt_lists,
                                            padding_idx)
    insert_labels = [path[:-1] for path in full_labels]

    # Flatten (batch, slot, word) triples into linear indices plus scores.
    target = in_tokens.new_zeros(n_batch * (n_src - 1) * n_vocab).float()
    triples = [
        (w + (j + i * (n_src - 1)) * n_vocab, neg_scorer(k, len(label), tau))
        for i, labels in enumerate(insert_labels)
        for j, label in enumerate(labels[1:-1])  # HACK 1:-1
        for k, w in enumerate(label)
    ]
    flat_index, flat_score = zip(*triples)
    flat_index = torch.tensor(list(flat_index), device=in_tokens.device)
    flat_score = torch.tensor(list(flat_score), device=in_tokens.device)

    target.scatter_(0, flat_index.long(), flat_score.type_as(target))
    return target.view(n_batch, n_src - 1, n_vocab)
Code example #7
0
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
    """Build mask-insertion training targets.

    Returns:
        masked_tgt_masks: bool tensor flagging target positions that
            correspond to inserted tokens, zero-padded to the target width.
        masked_tgt_tokens: ``out_tokens`` with those positions replaced by
            ``unk_idx``.
        mask_ins_targets: per-slot insertion counts, zero-padded to
            ``in_tokens.size(1) - 1``.

    Raises:
        ImportError: if the compiled ``libnat`` extension is missing.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys
        sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
        raise e

    src_len = in_tokens.size(1)
    tgt_len = out_tokens.size(1)

    with torch.cuda.device_of(in_tokens):
        src_lists = [[tok for tok in row if tok != padding_idx]
                     for row in in_tokens.tolist()]
        tgt_lists = [[tok for tok in row if tok != padding_idx]
                     for row in out_tokens.tolist()]

    full_labels = libnat.suggested_ed2_path(src_lists, tgt_lists, padding_idx)
    # Insertion counts per source slot; padding slots count as zero.
    ins_counts = [
        [0 if group[0] == padding_idx else len(group) for group in path[:-1]]
        for path in full_labels
    ]

    # generate labels
    masked_tgt_masks = []
    for counts in ins_counts:
        row = []
        for n_ins in counts[1:-1]:  # HACK 1:-1
            row.append(0)
            row.extend([1] * n_ins)
        row.extend([0] * (tgt_len - len(row)))
        masked_tgt_masks.append(row)
    mask_ins_targets = [
        counts[1:-1] + [0] * (src_len - 1 - len(counts[1:-1]))
        for counts in ins_counts
    ]

    # transform to tensor
    masked_tgt_masks = torch.tensor(masked_tgt_masks,
                                    device=out_tokens.device).bool()
    mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
    masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
    return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
Code example #8
0
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tgt_dict, tau=None):
    """Dense, dictionary-weighted insertion targets.

    Scatters ``neg_scorer(w, tgt_dict, tau)`` scores into a zero float
    tensor at linearized (batch, slot, word) indices, then row-normalizes
    over the vocabulary axis so each slot forms a distribution.
    Returns a float tensor of shape ``(B, T - 1, V)``.

    Note: ``unk_idx`` is accepted for signature parity with the sibling
    variants but is not used here.

    Raises:
        ImportError: if the compiled ``libnat`` extension is missing.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys
        sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
        raise e

    B = in_tokens.size(0)
    T = in_tokens.size(1)
    V = vocab_size

    with torch.cuda.device_of(in_tokens):
        in_tokens_list = [
            [t for t in s if t != padding_idx] for s in in_tokens.tolist()
        ]
        out_tokens_list = [
            [t for t in s if t != padding_idx] for s in out_tokens.tolist()
        ]

    full_labels = libnat.suggested_ed2_path(
        in_tokens_list, out_tokens_list, padding_idx
    )
    insert_labels = [a[:-1] for a in full_labels]

    # numericalize1
    insert_label_tensors = in_tokens.new_zeros(B * (T - 1) * V).float()
    insert_index, insert_labels = zip(
        *[
            (w + (j + i * (T - 1)) * V, neg_scorer(w, tgt_dict, tau))
            for i, labels in enumerate(insert_labels)
            for j, label in enumerate(labels[1:-1])
            for k, w in enumerate(label)
        ]
    )  # HACK 1:-1
    insert_index, insert_labels = [
        torch.tensor(list(a), device=in_tokens.device)
        for a in [insert_index, insert_labels]
    ]
    # Cast scores to the float target dtype before scatter_: torch.tensor()
    # infers int64 when neg_scorer returns integers, and scatter_ requires
    # src and self dtypes to match (same fix as the dtype-safe variant).
    insert_label_tensors.scatter_(0, insert_index.long(),
                                  insert_labels.type_as(insert_label_tensors))
    insert_label_tensors = insert_label_tensors.view(B, T - 1, V)
    # NOTE(review): slots that received no score sum to 0, so this 0/0
    # normalization yields NaN rows — confirm callers mask those slots.
    insert_label_tensors = insert_label_tensors / torch.sum(insert_label_tensors, -1).unsqueeze(-1)  # NegativeWeightScorerFreqWeight

    return insert_label_tensors
Code example #9
0
def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
    """Build deletion and insertion targets in a single libnat pass.

    Returns:
        word_del_targets: per-token deletion labels, zero-padded to the
            target width, on ``out_tokens.device``.
        mask_ins_targets: per-slot insertion counts, zero-padded to
            ``in_tokens.size(1) - 1``, on ``in_tokens.device``.

    Raises:
        ImportError: if the compiled ``libnat`` extension is missing.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys
        sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
        raise e

    src_len = in_tokens.size(1)
    tgt_len = out_tokens.size(1)

    with torch.cuda.device_of(in_tokens):
        src_lists = [[tok for tok in row if tok != padding_idx]
                     for row in in_tokens.tolist()]
        tgt_lists = [[tok for tok in row if tok != padding_idx]
                     for row in out_tokens.tolist()]

    full_labels = libnat.suggested_ed2_path(src_lists, tgt_lists, padding_idx)

    # Deletion labels: the last entry of each path, zero-padded.
    word_del_targets = [
        path[-1] + [0] * (tgt_len - len(path[-1])) for path in full_labels
    ]

    # Insertion counts per source slot (0 for padding slots); keep the
    # interior slots and zero-pad to src_len - 1.
    mask_ins_targets = []
    for path in full_labels:
        counts = [0 if group[0] == padding_idx else len(group)
                  for group in path[:-1]]
        inner = counts[1:-1]
        mask_ins_targets.append(inner + [0] * (src_len - 1 - len(inner)))

    # transform to tensor
    mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
    word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
    return word_del_targets, mask_ins_targets