Example #1
import torch
from typing import Union

# S2S_basic, S2S_attention, Vocab, combine_bidir_hidden_state and
# decode_batch are project-local names, assumed importable from the
# surrounding package.
def translate_rnn(line: str, line_number: int,
                  s2s: Union[S2S_basic.S2S, S2S_attention.S2S],
                  src_vocab: Vocab, tgt_vocab: Vocab,
                  lang_vec: dict, device: torch.device):

    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])

    line = line.split()

    # the language token, if present, follows the start token
    lang_token = line[1]

    # inputs: (input_length,)
    inputs = torch.tensor([src_vocab.get_index(token) for token in line],
                          device=device)

    # inputs: (input_length, 1)
    inputs = inputs.view(-1, 1)

    if lang_token.startswith("<") and lang_token.endswith(">"):
        # add language vector
        # input_embedding: (input_length, 1, embedding_size)
        # lang_encoding: (embedding_size, )
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        input_embedding = s2s.encoder.embedding(inputs) + lang_encoding
    else:
        input_embedding = s2s.encoder.embedding(inputs)
        print("line {} does not add language embedding".format(line_number))

    encoder_output, encoder_hidden_state = s2s.encoder.rnn(input_embedding)

    decoder_hidden_state = combine_bidir_hidden_state(s2s,
                                                      encoder_hidden_state)

    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)

    # cap the output length at 3x the source length (excluding <s> and </s>)
    max_length = (inputs.size(0) - 2) * 3

    pred_line = []

    for i in range(max_length):

        # decoder_output: (1, 1, vocab_size)
        # decoder_hidden_state: (num_layers * num_directions, batch_size, hidden_size)
        decoder_output, decoder_hidden_state = decode_batch(
            s2s, decoder_input, decoder_hidden_state, encoder_output)

        # pred: (1, 1)
        pred = torch.argmax(decoder_output, dim=2)

        if tgt_vocab.get_token(pred[0, 0].item()) == tgt_vocab.end_token:
            break

        decoder_input = pred

        pred_line.append(tgt_vocab.get_token(pred[0, 0].item()))

    return pred_line
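
A minimal, self-contained sketch of the broadcast used above (shapes are illustrative): adding a language vector of shape (embedding_size,) to an embedding tensor of shape (input_length, 1, embedding_size) applies it at every position.

import torch

input_embedding = torch.zeros(5, 1, 4)          # (input_length, 1, embedding_size)
lang_encoding = torch.tensor([1., 2., 3., 4.])  # (embedding_size,)
print((input_embedding + lang_encoding).shape)  # torch.Size([5, 1, 4])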
Example #2
from typing import List

# Vocab is a project-local class, assumed importable from the
# surrounding package.
def convert_index_to_token(pred_list: List[List[int]], tgt_vocab: Vocab,
                           batch_size: int, end_token_index: int):

    # pred_list: (tgt_length, batch_size) token indices
    # pred_line: (batch_size, tgt_length) tokens, cut at the end token
    pred_line = []

    for j in range(batch_size):
        line = []
        for i in range(len(pred_list)):
            if pred_list[i][j] == end_token_index:
                break
            line.append(tgt_vocab.get_token(pred_list[i][j]))
        pred_line.append(line)

    return pred_line
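
A self-contained check of the cut-at-end-token behaviour, using a plain dict in place of the project's Vocab (the id_to_token table below is hypothetical):

id_to_token = {0: "<s>", 1: "hello", 2: "world", 3: "</s>"}
end_token_index = 3

# pred_list: (tgt_length, batch_size); batch element 1 ends early
pred_list = [[1, 1],
             [2, 3],
             [3, 2]]

pred_line = []
for j in range(2):  # batch_size = 2
    line = []
    for i in range(len(pred_list)):
        if pred_list[i][j] == end_token_index:
            break
        line.append(id_to_token[pred_list[i][j]])
    pred_line.append(line)

print(pred_line)  # [['hello', 'world'], ['hello']]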
Example #3
import torch
import torch.nn.functional as F

# transformer and Vocab are project-local names, assumed importable
# from the surrounding package.
def beam_search_transformer(s2s: transformer.S2S, data_tensor: torch.Tensor,
                            tgt_vocab: Vocab, beam_size: int,
                            device: torch.device):

    # src: (1, input_length)
    src = data_tensor
    src = src.expand(beam_size, -1)
    src_mask = s2s.make_src_mask(src)

    encoder_src = s2s.encoder(src, src_mask)

    max_length = src.size(1) * 3

    # tgt: (1, 1)
    tgt = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                       device=device)

    # tgt: (beam_size, 1)
    tgt = tgt.expand(beam_size, -1)
    scores = torch.zeros(beam_size, device=device)

    complete_seqs = []
    complete_seqs_scores = []

    step = 1

    while True:

        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (beam_size, step, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)

        # output: (beam_size, vocab_size)
        output = output[:, -1, :]

        # output: (beam_size, vocab_size)
        output = F.log_softmax(output, dim=-1)

        # sub_sentence_scores: (beam_size, vocab_size)
        sub_sentence_scores = output + scores.unsqueeze(1)

        if step == 1:
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size)
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)

        # beam_id: (beam_size, )
        beam_id = pred_indices.floor_divide(len(tgt_vocab))
        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)

        # next_tgt: (beam_size, step + 1)
        next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)

        if step == max_length:
            complete_seqs.extend(next_tgt.tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break

        complete_indices = []

        for i, indices in enumerate(token_id):

            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)

        if len(complete_indices) > 0:
            complete_seqs.extend(next_tgt[complete_indices].tolist())

            complete_pred_indices = beam_id[complete_indices] * len(
                tgt_vocab) + token_id[complete_indices]

            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())

                if len(complete_indices) == beam_size:
                    break

                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())

                if len(complete_indices) == beam_size:
                    break

                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)

            # beam_id: (beam_size, )
            beam_id = pred_indices.floor_divide(len(tgt_vocab))
            # token_id: (beam_size, )
            token_id = pred_indices % len(tgt_vocab)
            # next_tgt: (beam_size, step + 1)
            next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)

        step += 1

        tgt = next_tgt
        scores = pred_prob

    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i

    best_sentence = complete_seqs[best_sentence_id]

    best_sentence = [
        tgt_vocab.get_token(index) for index in best_sentence[1:-1]
    ]

    return best_sentence
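
A minimal sketch of the index arithmetic above: after flattening the (beam_size, vocab_size) score matrix, each flat top-k index decomposes into the beam it extends and the token it appends.

import torch

beam_size, vocab_size = 3, 5
flat_scores = torch.arange(beam_size * vocab_size, dtype=torch.float)

pred_prob, pred_indices = flat_scores.topk(beam_size, dim=-1)
beam_id = pred_indices.floor_divide(vocab_size)  # which beam each candidate extends
token_id = pred_indices % vocab_size             # which token it appends

print(beam_id.tolist(), token_id.tolist())  # [2, 2, 2] [4, 3, 2]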
Example #4
import torch
import torch.nn.functional as F
from typing import Union

# S2S_attention, S2S_basic, Vocab, combine_bidir_hidden_state and
# decode_batch are project-local names, assumed importable from the
# surrounding package.
def beam_search_rnn(s2s: Union[S2S_attention.S2S, S2S_basic.S2S],
                    data_tensor: torch.Tensor, tgt_vocab: Vocab,
                    beam_size: int, device: torch.device):

    # batch_size == beam_size

    # inputs: (input_length, beam_size)
    inputs = data_tensor
    inputs = inputs.expand(-1, beam_size)

    encoder_output, encoder_hidden_state = s2s.encoder(inputs)

    # decoder_input: (1, beam_size)
    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)
    decoder_input = decoder_input.expand(-1, beam_size)

    # decoder_hidden_state: (num_layers, beam_size, hidden_size)
    decoder_hidden_state = combine_bidir_hidden_state(s2s,
                                                      encoder_hidden_state)

    max_length = inputs.size(0) * 3

    scores = torch.zeros(beam_size, device=device)

    complete_seqs = []
    complete_seqs_scores = []
    step = 1

    while True:

        # output: (1, beam_size, vocab_size)
        # decoder_hidden_state: (num_layers, beam_size, hidden_size)
        output, decoder_hidden_state = decode_batch(
            s2s, decoder_input[-1].unsqueeze(0), decoder_hidden_state,
            encoder_output)

        output = F.log_softmax(output, dim=-1)

        # sub_sentence_scores: (beam_size, vocab_size)
        sub_sentence_scores = scores.unsqueeze(1) + output.squeeze(0)

        if step == 1:
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size)
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)

        # beam_id: (beam_size, )
        beam_id = pred_indices.floor_divide(len(tgt_vocab))

        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)

        # decoder_input: (step, beam_size)
        # next_decoder_input: (step + 1, beam_size)
        next_decoder_input = torch.cat(
            [decoder_input[:, beam_id],
             token_id.unsqueeze(0)], dim=0)

        if step == max_length:
            complete_seqs.extend(next_decoder_input.t().tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break

        complete_indices = []

        for i, indices in enumerate(token_id):

            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)

        if len(complete_indices) > 0:
            complete_seqs.extend(
                next_decoder_input[:, complete_indices].t().tolist())

            complete_pred_indices = beam_id[complete_indices] * len(
                tgt_vocab) + token_id[complete_indices]

            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())

                if len(complete_pred_indices) == beam_size:
                    break

                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())

                if len(complete_pred_indices) == beam_size:
                    break

                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)

            beam_id = pred_indices.floor_divide(len(tgt_vocab))
            token_id = pred_indices % len(tgt_vocab)

            next_decoder_input = torch.cat(
                [decoder_input[:, beam_id],
                 token_id.unsqueeze(0)], dim=0)

        step += 1

        if isinstance(decoder_hidden_state, tuple):
            h, c = decoder_hidden_state
            h = h[:, beam_id]
            c = c[:, beam_id]
            decoder_hidden_state = (h, c)
        else:
            decoder_hidden_state = decoder_hidden_state[:, beam_id]

        decoder_input = next_decoder_input
        scores = pred_prob

    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i

    best_sentence = complete_seqs[best_sentence_id]

    best_sentence = [
        tgt_vocab.get_token(index) for index in best_sentence[1:-1]
    ]

    return best_sentence
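
A minimal sketch of the hidden-state reordering above: after top-k selection, each surviving beam must carry the hidden state of the beam it was extended from, hence the indexing on dim 1 (the beam dimension).

import torch

num_layers, beam_size, hidden_size = 2, 3, 4
h = torch.randn(num_layers, beam_size, hidden_size)
beam_id = torch.tensor([2, 0, 0])  # e.g. two candidates extend beam 0

h = h[:, beam_id]  # (num_layers, beam_size, hidden_size), reordered
print(h.shape)     # torch.Size([2, 3, 4])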
Example #5
import torch

# transformer and Vocab are project-local names, assumed importable
# from the surrounding package.
def translate_transformer(line: str, line_number: int, s2s: transformer.S2S,
                          src_vocab: Vocab, tgt_vocab: Vocab, lang_vec: dict,
                          device: torch.device):

    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])

    line = line.split()

    max_length = (len(line) - 2) * 3

    lang_token = line[1]

    # src: (input_length, )
    src = torch.tensor([src_vocab.get_index(token) for token in line],
                       device=device)
    # src: (1, input_length)
    src = src.view(1, -1)

    src_mask = s2s.make_src_mask(src)

    src = s2s.encoder.token_embedding(src) * s2s.encoder.scale

    # src: (1, input_length, d_model)
    src = s2s.encoder.pos_embedding(src)

    if lang_token.startswith("<") and lang_token.endswith(">"):
        # lang_encoding: (d_model, )
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        src = src + lang_encoding

    else:
        print("line {} does not add language embedding".format(line_number))

    for layer in s2s.encoder.layers:
        src, self_attention = layer(src, src_mask)

    # the per-layer attention weights are not needed at inference time
    del self_attention

    encoder_src = src

    pred_line = [tgt_vocab.get_index(tgt_vocab.start_token)]

    # tgt: (1, 1) to start; it grows along the time dimension each step
    tgt = torch.tensor([pred_line], device=device)

    for i in range(max_length):

        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (1, tgt_input_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)

        # torch.argmax(output, dim=-1): (1, tgt_input_length)
        # pred: 0-d tensor with the index of the predicted next token
        pred = torch.argmax(output, dim=-1)[0, -1]

        if tgt_vocab.get_token(pred.item()) == tgt_vocab.end_token:
            break

        tgt = torch.cat([tgt, pred.unsqueeze(0).unsqueeze(1)], dim=1)
        pred_line.append(pred.item())

    pred_line = [tgt_vocab.get_token(index) for index in pred_line[1:]]
    return pred_line
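
A minimal sketch of how the decoder input grows above: the predicted token index (a 0-d tensor) is reshaped to (1, 1) and appended along the time dimension.

import torch

tgt = torch.tensor([[0, 5, 7]])  # (1, tgt_input_length)
pred = torch.tensor(9)           # 0-d prediction
tgt = torch.cat([tgt, pred.unsqueeze(0).unsqueeze(1)], dim=1)
print(tgt)  # tensor([[0, 5, 7, 9]])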