示例#1
0
def batch_data_variable(batch, vocab, tokenizer):
    """Pack a batch of instances into a padded ``TensorInstances``.

    The padded widths are the largest ``src_len`` / ``tgt_len`` attributes
    found in ``batch``. For row ``b``, the token-id lists yielded by
    ``insts_numberize`` for the b-th instance are copied into
    ``src_bert_indice`` / ``tgt_bert_indice``. The row also gets a 1 in the
    mask for every filled position, plus its lengths and tag id.

    Args:
        batch: non-empty indexable sequence of instances exposing ``src_len``
            and ``tgt_len`` (an empty batch raises IndexError).
        vocab: forwarded to ``insts_numberize``.
        tokenizer: forwarded to ``insts_numberize``.

    Returns:
        The filled ``TensorInstances``.
    """
    # Seed with the first instance, then widen to the batch maximum.
    max_src, max_tgt = batch[0].src_len, batch[0].tgt_len
    for inst in batch:
        max_src = max(max_src, inst.src_len)
        max_tgt = max(max_tgt, inst.tgt_len)

    tinst = TensorInstances(len(batch), max_src, max_tgt)

    numbered = insts_numberize(batch, vocab, tokenizer)
    # Each item is ((src ids, _, _), (tgt ids, _, _), tag id); only the first
    # entry of each triple is used here.
    for row, ((src_ids, _, _), (tgt_ids, _, _), tag_id) in enumerate(numbered):
        tinst.tags[row] = tag_id
        n_src, n_tgt = len(src_ids), len(tgt_ids)
        tinst.src_lens[row] = n_src
        tinst.tgt_lens[row] = n_tgt
        # Convert list -> numpy array -> tensor; the original author notes that
        # building the tensor directly from a Python list leaked memory.
        tinst.src_bert_indice[row][:n_src] = torch.LongTensor(np.array(src_ids))
        tinst.tgt_bert_indice[row][:n_tgt] = torch.LongTensor(np.array(tgt_ids))
        tinst.src_masks[row][:n_src] = 1
        tinst.tgt_masks[row][:n_tgt] = 1
    return tinst
示例#2
0
def batch_data_variable(batch, vocab):
    """Pack word-id sequences of a batch into a padded ``TensorInstances``.

    Padding widths are the longest ``src_words`` / ``tgt_words`` in ``batch``.
    For row ``b``, the ids yielded by ``insts_numberize`` are written into
    ``src_words`` / ``tgt_words`` with a 1 in the matching mask positions.
    The tag id and the id-sequence lengths are recorded as well.

    Args:
        batch: non-empty indexable sequence of instances exposing
            ``src_words`` and ``tgt_words`` (an empty batch raises IndexError).
        vocab: forwarded to ``insts_numberize``.

    Returns:
        ``(tinst, src_words, tgt_words)``: the filled tensors plus, per
        instance, the ``src_words`` / ``tgt_words`` of the yielded ``inst``.
    """
    max_src = len(batch[0].src_words)
    max_tgt = len(batch[0].tgt_words)
    for inst in batch:
        max_src = max(max_src, len(inst.src_words))
        max_tgt = max(max_tgt, len(inst.tgt_words))

    tinst = TensorInstances(len(batch), max_src, max_tgt)
    raw_src, raw_tgt = [], []

    for row, (src_ids, tgt_ids, tag_id, inst) in enumerate(
            insts_numberize(batch, vocab)):
        tinst.tags[row] = tag_id
        tinst.src_lens[row] = len(src_ids)
        tinst.tgt_lens[row] = len(tgt_ids)

        for col, word_id in enumerate(src_ids):
            tinst.src_words[row, col] = word_id
            tinst.src_masks[row, col] = 1
        for col, word_id in enumerate(tgt_ids):
            tinst.tgt_words[row, col] = word_id
            tinst.tgt_masks[row, col] = 1

        raw_src.append(inst.src_words)
        raw_tgt.append(inst.tgt_words)

    return tinst, raw_src, raw_tgt
示例#3
0
def batch_data_variable(batch, vocab):
    """Pack word ids of a batch into a padded ``TensorInstances``.

    Row ``b`` gets the tag id of the b-th instance. For each position of that
    instance's ``src_words`` / ``tgt_words``, it also gets the matching id
    yielded by ``insts_numberize`` and a 1 in the mask. Sequence lengths are
    not recorded here.

    Args:
        batch: non-empty indexable sequence of instances exposing
            ``src_words`` and ``tgt_words`` (an empty batch raises IndexError).
        vocab: forwarded to ``insts_numberize``.

    Returns:
        The filled ``TensorInstances``.
    """
    max_src = len(batch[0].src_words)
    max_tgt = len(batch[0].tgt_words)
    for inst in batch:
        max_src = max(max_src, len(inst.src_words))
        max_tgt = max(max_tgt, len(inst.tgt_words))

    tinst = TensorInstances(len(batch), max_src, max_tgt)

    for row, (src_ids, tgt_ids, tag_id, inst) in enumerate(
            insts_numberize(batch, vocab)):
        tinst.tags[row] = tag_id
        # The number of positions written comes from the instance's word
        # lists, not from the length of the id lists.
        for col in range(len(inst.src_words)):
            tinst.src_words[row, col] = src_ids[col]
            tinst.src_masks[row, col] = 1
        for col in range(len(inst.tgt_words)):
            tinst.tgt_words[row, col] = tgt_ids[col]
            tinst.tgt_masks[row, col] = 1
    return tinst
示例#4
0
def batch_data_variable(batch, vocab, dep_vocab):
    """Pack word, external-word and dependency-word ids of a batch.

    Dependency tensors are shifted one column to the right. Column 0 of every
    row holds ``dep_vocab.ROOT`` (word and ext-word) with mask 1, and the id
    for word ``i`` goes to column ``i + 1``. The plain word / ext-word tensors
    are not shifted.

    Args:
        batch: non-empty indexable sequence of instances exposing
            ``src_words`` and ``tgt_words`` (an empty batch raises IndexError).
        vocab: forwarded to ``insts_numberize``.
        dep_vocab: forwarded to ``insts_numberize``; supplies ``ROOT``.

    Returns:
        The filled ``TensorInstances``.
    """
    max_src = len(batch[0].src_words)
    max_tgt = len(batch[0].tgt_words)
    for inst in batch:
        max_src = max(max_src, len(inst.src_words))
        max_tgt = max(max_tgt, len(inst.tgt_words))

    tinst = TensorInstances(len(batch), max_src, max_tgt)

    for row, fields in enumerate(insts_numberize(batch, vocab, dep_vocab)):
        (src_ids, src_extids, tgt_ids, tgt_extids, tag_id,
         src_dep_ids, tgt_dep_ids, src_dep_extids, tgt_dep_extids) = fields

        # Column 0 of every dependency row is the ROOT token.
        tinst.src_dep_words[row, 0] = dep_vocab.ROOT
        tinst.src_dep_extwords[row, 0] = dep_vocab.ROOT
        tinst.src_dep_masks[row, 0] = 1
        tinst.tgt_dep_words[row, 0] = dep_vocab.ROOT
        tinst.tgt_dep_extwords[row, 0] = dep_vocab.ROOT
        tinst.tgt_dep_masks[row, 0] = 1

        n_src, n_tgt = len(src_ids), len(tgt_ids)
        tinst.tags[row] = tag_id
        tinst.src_lens[row] = n_src
        tinst.tgt_lens[row] = n_tgt

        for col in range(n_src):
            dep_col = col + 1  # shifted past the ROOT column
            tinst.src_words[row, col] = src_ids[col]
            tinst.src_extwords[row, col] = src_extids[col]
            tinst.src_masks[row, col] = 1
            tinst.src_dep_words[row, dep_col] = src_dep_ids[col]
            tinst.src_dep_extwords[row, dep_col] = src_dep_extids[col]
            tinst.src_dep_masks[row, dep_col] = 1
        for col in range(n_tgt):
            dep_col = col + 1  # shifted past the ROOT column
            tinst.tgt_words[row, col] = tgt_ids[col]
            tinst.tgt_extwords[row, col] = tgt_extids[col]
            tinst.tgt_masks[row, col] = 1
            tinst.tgt_dep_words[row, dep_col] = tgt_dep_ids[col]
            tinst.tgt_dep_extwords[row, dep_col] = tgt_dep_extids[col]
            tinst.tgt_dep_masks[row, dep_col] = 1
    return tinst
示例#5
0
def batch_data_variable(batch, vocab):
    """Build BERT and action tensors for a batch of action-annotated instances.

    A word is an action token if it splits on ``'##'`` into exactly three
    fields whose first field is ``'arc'`` or ``'pop'``. Action tokens are
    dropped, and the remaining words are joined with spaces and encoded with
    ``vocab.bert_ids``. For row ``b`` this fills:

    * ``src/tgt_bert_indices`` and ``src/tgt_bert_segments`` with the encoded
      token and segment ids;
    * ``src/tgt_bert_pieces[b, w, p]`` with ``1 / n`` for each of the ``n``
      piece positions ``p`` listed for kept word ``w``;
    * ``src/tgt_actions``, ``src/tgt_masks``, ``src/tgt_lens`` and ``tags``
      from the fields yielded by ``insts_numberize``.

    Args:
        batch: non-empty indexable sequence of instances exposing
            ``src_words`` / ``tgt_words`` (lists of strings).
        vocab: forwarded to ``insts_numberize``; must also provide
            ``bert_ids(sentence) -> (token_ids, segment_ids, piece_ids)``.

    Returns:
        ``(tinst, src_word_indexes, tgt_word_indexes)``. For each instance,
        the index lists hold the original positions of the kept (non-action)
        words, so row ``w`` of ``*_bert_pieces`` corresponds to word
        ``*_word_indexes[b][w]``.
    """
    # Skip piece_ids[0] (original note: "remove the first token"); presumably
    # a leading special token added by vocab.bert_ids -- TODO confirm.
    shift_pos = 1

    def _split_actions(words):
        """Return ``(kept_words, kept_positions)`` with arc/pop action tokens removed."""
        kept, positions = [], []
        for position, word in enumerate(words):
            items = word.split('##')
            if len(items) != 3 or items[0] not in ('arc', 'pop'):
                kept.append(word)
                positions.append(position)
        return kept, positions

    def _fill_bert_row(indices, segments, pieces, b, encoded, n_words):
        """Copy one encoded sentence into row ``b`` and fill its word->piece averaging rows."""
        token_ids, segment_ids, piece_ids = encoded
        for index in range(len(token_ids)):
            indices[b, index] = token_ids[index]
            segments[b, index] = segment_ids[index]
        for word_index in range(n_words):
            word_pieces = piece_ids[word_index + shift_pos]
            avg_score = 1.0 / len(word_pieces)
            for piece_index in word_pieces:
                pieces[b, word_index, piece_index] = avg_score

    batch_size = len(batch)
    slen = max(len(inst.src_words) for inst in batch)
    tlen = max(len(inst.tgt_words) for inst in batch)

    # Pass 1: strip action tokens and BERT-encode the remaining words.
    src_word_indexes, tgt_word_indexes = [], []
    src_encoded, tgt_encoded = [], []
    # Only ``inst`` is needed here. ``*_`` absorbs the leading fields, so this
    # loop unpacks the same items that pass 2 unpacks into six names. The
    # original unpacked them as a 2-tuple, which cannot match both loops.
    for *_, inst in insts_numberize(batch, vocab):
        src_kept, src_positions = _split_actions(inst.src_words)
        tgt_kept, tgt_positions = _split_actions(inst.tgt_words)
        src_word_indexes.append(src_positions)
        tgt_word_indexes.append(tgt_positions)

        src_tokens, src_segments, src_pieces = vocab.bert_ids(' '.join(src_kept))
        src_encoded.append((src_tokens, src_segments, src_pieces))
        tgt_tokens, tgt_segments, tgt_pieces = vocab.bert_ids(' '.join(tgt_kept))
        tgt_encoded.append((tgt_tokens, tgt_segments, tgt_pieces))

    sblen = max((len(tokens) for tokens, _, _ in src_encoded), default=0)
    tblen = max((len(tokens) for tokens, _, _ in tgt_encoded), default=0)
    tinst = TensorInstances(batch_size, slen, tlen, sblen, tblen)

    # Pass 2: copy BERT encodings, then actions, masks, lengths and tags.
    for b, (src_ids, src_acids, tgt_ids, tgt_acids, tagid, _inst) in enumerate(
            insts_numberize(batch, vocab)):
        # Bound the word->piece rows by the number of *kept* words. bert_ids
        # only saw those words, so piece_ids has no entries for action tokens;
        # the total word count would index past them.
        _fill_bert_row(tinst.src_bert_indices, tinst.src_bert_segments,
                       tinst.src_bert_pieces, b, src_encoded[b],
                       len(src_word_indexes[b]))
        _fill_bert_row(tinst.tgt_bert_indices, tinst.tgt_bert_segments,
                       tinst.tgt_bert_pieces, b, tgt_encoded[b],
                       len(tgt_word_indexes[b]))

        tinst.tags[b] = tagid
        cur_slen, cur_tlen = len(src_ids), len(tgt_ids)
        tinst.src_lens[b] = cur_slen
        tinst.tgt_lens[b] = cur_tlen

        # Masks cover every one of the len(src_ids) / len(tgt_ids) positions.
        for index in range(cur_slen):
            tinst.src_actions[b, index] = src_acids[index]
            tinst.src_masks[b, index] = 1
        for index in range(cur_tlen):
            tinst.tgt_actions[b, index] = tgt_acids[index]
            tinst.tgt_masks[b, index] = 1

    return tinst, src_word_indexes, tgt_word_indexes
示例#6
0
def batch_data_variable(batch, vocab):
    """Build BERT and relation tensors for a batch of dependency instances.

    Each instance's ``src_sentence`` / ``tgt_sentence`` is encoded with
    ``vocab.bert_ids``. For row ``b`` this fills:

    * ``src/tgt_bert_indices`` and ``src/tgt_bert_segments`` with the encoded
      token and segment ids;
    * ``src/tgt_bert_pieces[b, w, p]`` with ``1 / n`` for each of the ``n``
      piece positions ``p`` listed for word ``w``, for the first
      ``len(src_words)`` / ``len(tgt_words)`` words;
    * ``src/tgt_rels``, ``src/tgt_masks``, ``src/tgt_lens`` and ``tags``
      from the fields yielded by ``insts_numberize``;
    * ``src_heads`` / ``tgt_heads`` by appending each instance's heads.

    Args:
        batch: non-empty indexable sequence of instances exposing
            ``src_words`` / ``tgt_words``, ``src_sentence`` / ``tgt_sentence``
            and ``src_heads`` / ``tgt_heads``.
        vocab: forwarded to ``insts_numberize``; must also provide
            ``bert_ids(sentence) -> (token_ids, segment_ids, piece_ids)``.

    Returns:
        The filled ``TensorInstances``.
    """
    # Skip piece_ids[0] (original note: "remove the first token"); presumably
    # a leading special token added by vocab.bert_ids -- TODO confirm.
    shift_pos = 1

    def _fill_bert_row(indices, segments, pieces, masks, b, encoded, n_words):
        """Copy one encoded sentence into row ``b``; fill word->piece averages and word masks."""
        token_ids, segment_ids, piece_ids = encoded
        for index in range(len(token_ids)):
            indices[b, index] = token_ids[index]
            segments[b, index] = segment_ids[index]
        for word_index in range(n_words):
            word_pieces = piece_ids[word_index + shift_pos]
            avg_score = 1.0 / len(word_pieces)
            masks[b, word_index] = 1
            for piece_index in word_pieces:
                pieces[b, word_index, piece_index] = avg_score

    batch_size = len(batch)
    src_lengths = [len(inst.src_words) for inst in batch]
    tgt_lengths = [len(inst.tgt_words) for inst in batch]
    slen, tlen = max(src_lengths), max(tgt_lengths)

    # Pass 1: BERT-encode every sentence to find the padded BERT widths.
    src_encoded, tgt_encoded = [], []
    # Only ``inst`` is needed here. ``*_`` absorbs the leading fields, so this
    # loop unpacks the same items that pass 2 unpacks into four names. The
    # original unpacked them as a 2-tuple, which cannot match both loops.
    for *_, inst in insts_numberize(batch, vocab):
        src_tokens, src_segments, src_pieces = vocab.bert_ids(inst.src_sentence)
        src_encoded.append((src_tokens, src_segments, src_pieces))
        tgt_tokens, tgt_segments, tgt_pieces = vocab.bert_ids(inst.tgt_sentence)
        tgt_encoded.append((tgt_tokens, tgt_segments, tgt_pieces))

    sblen = max((len(tokens) for tokens, _, _ in src_encoded), default=0)
    tblen = max((len(tokens) for tokens, _, _ in tgt_encoded), default=0)
    tinst = TensorInstances(batch_size, slen, tlen, sblen, tblen)

    # Pass 2: copy BERT encodings, then relations, masks, lengths, tags, heads.
    for b, (src_relids, tgt_relids, tagid, inst) in enumerate(
            insts_numberize(batch, vocab)):
        _fill_bert_row(tinst.src_bert_indices, tinst.src_bert_segments,
                       tinst.src_bert_pieces, tinst.src_masks,
                       b, src_encoded[b], src_lengths[b])
        _fill_bert_row(tinst.tgt_bert_indices, tinst.tgt_bert_segments,
                       tinst.tgt_bert_pieces, tinst.tgt_masks,
                       b, tgt_encoded[b], tgt_lengths[b])

        tinst.tags[b] = tagid
        cur_slen, cur_tlen = len(src_relids), len(tgt_relids)
        tinst.src_lens[b] = cur_slen
        tinst.tgt_lens[b] = cur_tlen
        tinst.src_heads.append(inst.src_heads)
        tinst.tgt_heads.append(inst.tgt_heads)

        for index in range(cur_slen):
            tinst.src_rels[b, index] = src_relids[index]
            tinst.src_masks[b, index] = 1
        for index in range(cur_tlen):
            tinst.tgt_rels[b, index] = tgt_relids[index]
            tinst.tgt_masks[b, index] = 1

    return tinst
示例#7
0
def batch_data_variable(batch, vocab):
    """Pack word and action ids of a batch and collect its non-action words.

    Row ``b`` receives the ids and action ids yielded by ``insts_numberize``,
    a 1 in the mask for each filled position, the lengths and the tag id.
    Independently, every word that splits on ``'##'`` into exactly three
    fields starting with ``'arc'`` or ``'pop'`` is treated as an action token.
    Action tokens are removed from the instance's word lists, and the
    remaining words and their positions are collected.

    Args:
        batch: non-empty indexable sequence of instances exposing
            ``src_words`` / ``tgt_words`` (an empty batch raises IndexError).
        vocab: forwarded to ``insts_numberize``.

    Returns:
        ``(tinst, src_words, src_word_indexes, tgt_words, tgt_word_indexes)``.
        ``*_words[b]`` are the kept words of instance ``b`` and
        ``*_word_indexes[b]`` their original positions.
    """
    def _without_actions(words):
        """Split ``words`` into (kept words, their positions), skipping arc/pop action tokens."""
        kept, positions = [], []
        for pos, word in enumerate(words):
            fields = word.split('##')
            is_action = len(fields) == 3 and fields[0] in ('arc', 'pop')
            if not is_action:
                kept.append(word)
                positions.append(pos)
        return kept, positions

    max_src = len(batch[0].src_words)
    max_tgt = len(batch[0].tgt_words)
    for inst in batch:
        max_src = max(max_src, len(inst.src_words))
        max_tgt = max(max_tgt, len(inst.tgt_words))

    tinst = TensorInstances(len(batch), max_src, max_tgt)
    kept_src, kept_tgt = [], []
    kept_src_pos, kept_tgt_pos = [], []

    numbered = insts_numberize(batch, vocab)
    for row, (src_ids, src_acids, tgt_ids, tgt_acids, tag_id, inst) in enumerate(numbered):
        tinst.tags[row] = tag_id
        tinst.src_lens[row] = len(src_ids)
        tinst.tgt_lens[row] = len(tgt_ids)

        for col in range(len(src_ids)):
            tinst.src_words[row, col] = src_ids[col]
            tinst.src_actions[row, col] = src_acids[col]
            tinst.src_masks[row, col] = 1
        for col in range(len(tgt_ids)):
            tinst.tgt_words[row, col] = tgt_ids[col]
            tinst.tgt_actions[row, col] = tgt_acids[col]
            tinst.tgt_masks[row, col] = 1

        src_kept_words, src_kept_positions = _without_actions(inst.src_words)
        tgt_kept_words, tgt_kept_positions = _without_actions(inst.tgt_words)
        kept_src.append(src_kept_words)
        kept_tgt.append(tgt_kept_words)
        kept_src_pos.append(src_kept_positions)
        kept_tgt_pos.append(tgt_kept_positions)

    return tinst, kept_src, kept_src_pos, kept_tgt, kept_tgt_pos