Example No. 1
from typing import List, Optional

def greedy_decoding_transformer(s2s: transformer.S2S,
                                data_tensor: torch.Tensor,
                                tgt_vocab: Vocab,
                                device: torch.device,
                                tgt_prefix: Optional[List[str]] = None):

    # src: (batch_size, input_length)
    src = data_tensor
    src_mask = s2s.make_src_mask(src)

    batch_size = src.size(0)

    encoder_src = s2s.encoder(src, src_mask)

    tgt = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                       device=device)
    tgt = tgt.expand(batch_size, -1)

    # pred_list: List[List] (tgt_length, batch_size)
    pred_list = []

    if tgt_prefix is not None:
        # one forced prefix token per batch element
        tgt_prefix = [
            tgt_vocab.get_index(prefix_token) for prefix_token in tgt_prefix
        ]
        # tgt_prefix_tensor: (batch_size, )
        tgt_prefix_tensor = torch.tensor(tgt_prefix, device=device)
        # tgt_prefix_tensor: (batch_size, 1)
        tgt_prefix_tensor = tgt_prefix_tensor.unsqueeze(1)
        # tgt: (batch_size, 2)
        tgt = torch.cat([tgt, tgt_prefix_tensor], dim=1)
        pred_list.append(tgt_prefix)

    max_length = src.size(1) * 3

    end_token_index = tgt_vocab.get_index(tgt_vocab.end_token)

    for i in range(0 if tgt_prefix is None else 1, max_length):

        # tgt: (batch_size, i + 1)
        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (batch_size, tgt_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)
        # output: (batch_size, vocab_size)
        output = output[:, -1, :]

        # pred: (batch_size, )
        pred = torch.argmax(output, dim=-1)

        if torch.all(pred == end_token_index).item():
            break

        tgt = torch.cat([tgt, pred.unsqueeze(1)], dim=1)

        pred_list.append(pred.tolist())

    return convert_index_to_token(pred_list, tgt_vocab, batch_size,
                                  end_token_index)
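The core of the routine above is the per-step update: decode, keep only the last position's logits, take the argmax, and append it to the running target. A minimal self-contained sketch of that loop, with random logits standing in for s2s.decoder and an assumed start-token index of 1:

import torch

batch_size, vocab_size = 2, 10
tgt = torch.full((batch_size, 1), 1)  # assumed start-token index

for _ in range(3):
    # stand-in for s2s.decoder(...): (batch_size, tgt_length, vocab_size)
    logits = torch.randn(batch_size, tgt.size(1), vocab_size)
    pred = torch.argmax(logits[:, -1, :], dim=-1)     # (batch_size, )
    tgt = torch.cat([tgt, pred.unsqueeze(1)], dim=1)  # (batch_size, tgt_length + 1)

print(tgt.shape)  # torch.Size([2, 4])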
Example No. 2
from typing import Union

def translate_rnn(line: str, line_number: int,
                  s2s: Union[S2S_basic.S2S, S2S_attention.S2S],
                  src_vocab: Vocab, tgt_vocab: Vocab,
                  lang_vec: dict, device: torch.device):

    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])

    line = line.split()

    # the token right after the start token may be a language tag like "<en>"
    lang_token = line[1]

    # inputs: (input_length,)
    inputs = torch.tensor([src_vocab.get_index(token) for token in line],
                          device=device)

    # inputs: (input_length, 1)
    inputs = inputs.view(-1, 1)

    if lang_token.startswith("<") and lang_token.endswith(">"):
        # add language vector
        # input_embedding: (input_length, 1, embedding_size)
        # lang_encoding: (embedding_size, )
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        input_embedding = s2s.encoder.embedding(inputs) + lang_encoding
    else:
        input_embedding = s2s.encoder.embedding(inputs)
        print("line {} does not add language embedding".format(line_number))

    encoder_output, encoder_hidden_state = s2s.encoder.rnn(input_embedding)

    decoder_hidden_state = combine_bidir_hidden_state(s2s,
                                                      encoder_hidden_state)

    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)

    # cap output at three times the source length (excluding start/end tokens)
    max_length = (inputs.size(0) - 2) * 3

    pred_line = []

    for i in range(max_length):

        # decoder_output: (1, 1, vocab_size)
        # decoder_hidden_state: (num_layers * num_directions, batch_size, hidden_size)
        decoder_output, decoder_hidden_state = decode_batch(
            s2s, decoder_input, decoder_hidden_state, encoder_output)

        # pred: (1, 1)
        pred = torch.argmax(decoder_output, dim=2)

        if tgt_vocab.get_token(pred[0, 0].item()) == tgt_vocab.end_token:
            break

        decoder_input = pred

        pred_line.append(tgt_vocab.get_token(pred[0, 0].item()))

    return pred_line
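The language-vector addition relies on broadcasting: a (embedding_size, ) vector is added at every position of a (input_length, 1, embedding_size) embedding. A minimal sketch, assuming only torch:

import torch

input_embedding = torch.zeros(5, 1, 8)  # (input_length, batch=1, embedding_size)
lang_encoding = torch.ones(8)           # (embedding_size, )

# broadcasts over the first two dimensions
combined = input_embedding + lang_encoding
print(combined.shape)  # torch.Size([5, 1, 8])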
Example No. 3
def convert_data_to_index(data: List[str], vocab: Vocab):
    data2index = []

    for sentence in data:
        sentence = " ".join([vocab.start_token, sentence, vocab.end_token])
        data2index.append(
            [vocab.get_index(token) for token in sentence.split()])

    return data2index
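A hypothetical usage sketch; ToyVocab is a stand-in for the real Vocab class and implements only what convert_data_to_index touches:

class ToyVocab:
    start_token, end_token = "<s>", "</s>"

    def __init__(self, tokens):
        self._index = {token: i for i, token in enumerate(tokens)}

    def get_index(self, token):
        return self._index[token]

vocab = ToyVocab(["<s>", "</s>", "hello", "world"])
print(convert_data_to_index(["hello world"], vocab))  # [[0, 2, 3, 1]]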
Example No. 4
from typing import Union

def greedy_decoding_rnn(s2s: Union[S2S_basic.S2S, S2S_attention.S2S],
                        data_tensor: torch.Tensor, tgt_vocab: Vocab,
                        device: torch.device):

    # inputs: (input_length, batch_size)
    inputs = data_tensor

    batch_size = inputs.size(1)

    encoder_output, encoder_hidden_state = s2s.encoder(inputs)

    decoder_hidden_state = combine_bidir_hidden_state(s2s,
                                                      encoder_hidden_state)

    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)
    decoder_input = decoder_input.expand(-1, batch_size)

    max_length = inputs.size(0) * 3

    # pred_list: List[List] (tgt_length, batch_size)
    pred_list = []

    end_token_index = tgt_vocab.get_index(tgt_vocab.end_token)

    for i in range(max_length):

        # decoder_output: (1, batch_size, vocab_size)
        # decoder_hidden_state: (num_layers * num_directions, batch_size, hidden_size)
        decoder_output, decoder_hidden_state = decode_batch(
            s2s, decoder_input, decoder_hidden_state, encoder_output)

        # pred: (1, batch_size)
        pred = torch.argmax(decoder_output, dim=2)

        if torch.all(pred == end_token_index).item():
            break

        decoder_input = pred

        pred_list.append(pred.squeeze(0).tolist())

    return convert_index_to_token(pred_list, tgt_vocab, batch_size,
                                  end_token_index)
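Two details worth isolating: expand repeats the single start token across the batch without copying memory, and decoding stops only once every sequence in the batch has emitted the end token. A minimal sketch with assumed token indices (start = 2, end = 1):

import torch

start = torch.tensor([[2]])            # (1, 1)
decoder_input = start.expand(-1, 4)    # (1, batch_size), no data copy
print(decoder_input.tolist())          # [[2, 2, 2, 2]]

pred = torch.tensor([[1, 1, 1, 1]])    # every sequence predicted the end token
print(torch.all(pred == 1).item())     # True -> break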
Example No. 5
def load_corpus_data(data_path,
                     language_name,
                     start_token,
                     end_token,
                     mask_token,
                     vocab_path,
                     rebuild_vocab,
                     unk="UNK",
                     threshold=0):
    if rebuild_vocab:
        v = Vocab(language_name,
                  start_token,
                  end_token,
                  mask_token,
                  threshold=threshold)

    corpus = []

    with open(data_path) as f:

        data = f.read().strip().split("\n")

        for line in data:
            line = line.strip()
            line = " ".join([start_token, line, end_token])

            if rebuild_vocab:
                v.add_sentence(line)

            corpus.append(line)

    data2index = []

    if rebuild_vocab:
        v.add_unk(unk)
        v.save(vocab_path)
    else:
        v = Vocab.load(vocab_path)

    for line in corpus:
        data2index.append([v.get_index(token) for token in line.split()])

    return data2index, v
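A hypothetical call; the paths, language name, and special tokens below are placeholders, not values from this repository:

data2index, vocab = load_corpus_data(
    "data/train.en", "en",
    start_token="<s>", end_token="</s>", mask_token="<mask>",
    vocab_path="data/vocab.en", rebuild_vocab=True, threshold=2)

# data2index[i] is the i-th sentence as vocabulary indices,
# wrapped in start/end tokens; vocab maps tokens <-> indices.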
Example No. 6
def beam_search_transformer(s2s: transformer.S2S, data_tensor: torch.Tensor,
                            tgt_vocab: Vocab, beam_size: int,
                            device: torch.device):

    # src: (1, input_length)
    src = data_tensor
    # src: (beam_size, input_length), the single source repeated per beam
    src = src.expand(beam_size, -1)
    src_mask = s2s.make_src_mask(src)

    encoder_src = s2s.encoder(src, src_mask)

    max_length = src.size(1) * 3

    # tgt: (1, 1)
    tgt = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                       device=device)

    # tgt: (beam_size, 1)
    tgt = tgt.expand(beam_size, -1)
    scores = torch.zeros(beam_size, device=device)

    complete_seqs = []
    complete_seqs_scores = []

    step = 1

    while True:

        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (beam_size, tgt_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)

        # output: (beam_size, vocab_size)
        output = output[:, -1, :]

        # log-probabilities keep beam scores additive
        output = F.log_softmax(output, dim=-1)

        # sub_sentence_scores: (beam_size, vocab_size)
        sub_sentence_scores = output + scores.unsqueeze(1)

        if step == 1:
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size, )
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)

        # beam_id: (beam_size, ), the beam each candidate extends
        beam_id = torch.div(pred_indices, len(tgt_vocab),
                            rounding_mode="floor")
        # token_id: (beam_size, ), the new token within the vocabulary
        token_id = pred_indices % len(tgt_vocab)

        # next_tgt: (beam_size, step + 1)
        next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)

        if step == max_length:
            complete_seqs.extend(next_tgt.tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break

        complete_indices = []

        for i, indices in enumerate(token_id):

            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)

        if len(complete_indices) > 0:
            complete_seqs.extend(next_tgt[complete_indices].tolist())

            complete_pred_indices = beam_id[complete_indices] * len(
                tgt_vocab) + token_id[complete_indices]

            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())

                if len(complete_indices) == beam_size:
                    break

                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())

                if len(complete_indices) == beam_size:
                    break

                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)

            # beam_id: (beam_size, )
            beam_id = torch.div(pred_indices, len(tgt_vocab),
                                rounding_mode="floor")
            # token_id: (beam_size, )
            token_id = pred_indices % len(tgt_vocab)
            # next_tgt: (beam_size, step + 1)
            next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)

        step += 1

        tgt = next_tgt
        scores = pred_prob

    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i

    best_sentence = complete_seqs[best_sentence_id]

    best_sentence = [
        tgt_vocab.get_token(index) for index in best_sentence[1:-1]
    ]

    return best_sentence
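The bookkeeping above hinges on one trick: flattening the (beam_size, vocab_size) score matrix so a single topk ranks all beam-token continuations at once, then recovering which beam and which token each winner came from. A self-contained sketch:

import torch

beam_size, vocab_size = 2, 5
sub_sentence_scores = torch.randn(beam_size, vocab_size)

pred_prob, pred_indices = sub_sentence_scores.view(-1).topk(beam_size, dim=-1)

# row (source beam) and column (token) of each flattened index
beam_id = torch.div(pred_indices, vocab_size, rounding_mode="floor")
token_id = pred_indices % vocab_size
print(beam_id.tolist(), token_id.tolist())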
Example No. 7
from typing import Union

def beam_search_rnn(s2s: Union[S2S_attention.S2S, S2S_basic.S2S],
                    data_tensor: torch.Tensor, tgt_vocab: Vocab,
                    beam_size: int, device: torch.device):

    # batch_size == beam_size

    # inputs: (input_length, beam_size)
    inputs = data_tensor
    inputs = inputs.expand(-1, beam_size)

    encoder_output, encoder_hidden_state = s2s.encoder(inputs)

    # decoder_input: (1, beam_size)
    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)
    decoder_input = decoder_input.expand(-1, beam_size)

    # decoder_hidden_state: (num_layers, beam_size, hidden_size)
    decoder_hidden_state = combine_bidir_hidden_state(s2s,
                                                      encoder_hidden_state)

    max_length = inputs.size(0) * 3

    scores = torch.zeros(beam_size, device=device)

    complete_seqs = []
    complete_seqs_scores = []
    step = 1

    while True:

        # output: (1, beam_size, vocab_size)
        # decoder_hidden_state: (num_layers, beam_size, hidden_size)
        output, decoder_hidden_state = decode_batch(
            s2s, decoder_input[-1].unsqueeze(0), decoder_hidden_state,
            encoder_output)

        output = F.log_softmax(output, dim=-1)

        # sub_sentence_scores: (beam_size, vocab_size)
        sub_sentence_scores = scores.unsqueeze(1) + output.squeeze(0)

        if step == 1:
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size, )
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)

        # beam_id: (beam_size, )
        beam_id = torch.div(pred_indices, len(tgt_vocab),
                            rounding_mode="floor")

        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)

        # decoder_input: (step, beam_size)
        # next_decoder_input: (step + 1, beam_size)
        next_decoder_input = torch.cat(
            [decoder_input[:, beam_id],
             token_id.unsqueeze(0)], dim=0)

        if step == max_length:
            complete_seqs.extend(next_decoder_input.t().tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break

        complete_indices = []

        for i, indices in enumerate(token_id):

            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)

        if len(complete_indices) > 0:
            complete_seqs.extend(
                next_decoder_input[:, complete_indices].t().tolist())

            complete_pred_indices = beam_id[complete_indices] * len(
                tgt_vocab) + token_id[complete_indices]

            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())

                if len(complete_pred_indices) == beam_size:
                    break

                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())

                if len(complete_pred_indices) == beam_size:
                    break

                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)

            beam_id = torch.div(pred_indices, len(tgt_vocab),
                                rounding_mode="floor")
            token_id = pred_indices % len(tgt_vocab)

            next_decoder_input = torch.cat(
                [decoder_input[:, beam_id],
                 token_id.unsqueeze(0)], dim=0)

        step += 1

        if isinstance(decoder_hidden_state, tuple):
            h, c = decoder_hidden_state
            h = h[:, beam_id]
            c = c[:, beam_id]
            decoder_hidden_state = (h, c)
        else:
            decoder_hidden_state = decoder_hidden_state[:, beam_id]

        decoder_input = next_decoder_input
        scores = pred_prob

    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i

    best_sentence = complete_seqs[best_sentence_id]

    best_sentence = [
        tgt_vocab.get_token(index) for index in best_sentence[1:-1]
    ]

    return best_sentence
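The one RNN-specific step is re-ordering the recurrent state along the beam dimension after each topk, since surviving hypotheses may descend from different beams. For an LSTM the state is an (h, c) tuple, for a GRU a single tensor; a minimal sketch:

import torch

num_layers, beam_size, hidden_size = 2, 3, 4
h = torch.randn(num_layers, beam_size, hidden_size)
c = torch.randn(num_layers, beam_size, hidden_size)

beam_id = torch.tensor([2, 0, 1])    # beams the survivors descend from
h, c = h[:, beam_id], c[:, beam_id]  # shapes unchanged: (2, 3, 4)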
Example No. 8
def translate_transformer(line: str, line_number: int, s2s: transformer.S2S,
                          src_vocab: Vocab, tgt_vocab: Vocab, lang_vec: dict,
                          device: torch.device):

    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])

    line = line.split()

    # cap output at three times the source length (excluding start/end tokens)
    max_length = (len(line) - 2) * 3

    lang_token = line[1]

    # src: (input_length, )
    src = torch.tensor([src_vocab.get_index(token) for token in line],
                       device=device)
    # src: (1, input_length)
    src = src.view(1, -1)

    src_mask = s2s.make_src_mask(src)

    src = s2s.encoder.token_embedding(src) * s2s.encoder.scale

    # src: (1, input_length, d_model)
    src = s2s.encoder.pos_embedding(src)

    if lang_token.startswith("<") and lang_token.endswith(">"):
        # lang_encoding: (d_model, )
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        src = src + lang_encoding

    else:
        print("line {} does not add language embedding".format(line_number))

    for layer in s2s.encoder.layers:
        src, self_attention = layer(src, src_mask)

    # the final layer's attention weights are not needed for decoding
    del self_attention

    encoder_src = src

    pred_line = [tgt_vocab.get_index(tgt_vocab.start_token)]

    # tgt: (1, 1), seeded with the start token
    tgt = torch.tensor([pred_line], device=device)

    for i in range(max_length):

        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (1, tgt_input_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)

        # pred: the most likely token at the last position (0-dim tensor)
        pred = torch.argmax(output, dim=-1)[0, -1]

        if tgt_vocab.get_token(pred.item()) == tgt_vocab.end_token:
            break

        tgt = torch.cat([tgt, pred.unsqueeze(0).unsqueeze(1)], dim=1)
        pred_line.append(pred.item())

    pred_line = [tgt_vocab.get_token(index) for index in pred_line[1:]]
    return pred_line
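make_tgt_mask is not shown in these examples, but autoregressive transformer decoding relies on a causal (lower-triangular) mask so position i attends only to positions <= i; the real method presumably also folds in a padding mask. A sketch of the causal part, offered as an assumption about this repository's implementation:

import torch

tgt_length = 4
causal = torch.tril(torch.ones(tgt_length, tgt_length, dtype=torch.bool))
print(causal)
# row i is True up to column i: each position sees itself and earlier tokens only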