Example #1
# Shared imports assumed by the examples on this page:
import math

import numpy as np
import torch
import torch.nn.functional as F


def beam_search(src, model, SRC, TRG, opt):

    outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, opt)
    eos_tok = TRG.vocab.stoi['<eos>']
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    ind = None
    for i in range(2, opt.max_len):

        trg_mask = nopeak_mask(i, opt)

        out = model.out(
            model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask))

        out = F.softmax(out, dim=-1)

        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             opt.k)

        # Occurrences of end symbols for all input sentences.
        ones = (outputs == eos_tok).nonzero()
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).cuda()
        for vec in ones:
            beam_idx = vec[0]
            if sentence_lengths[beam_idx] == 0:  # First end symbol has not been found yet
                sentence_lengths[beam_idx] = vec[1]  # Position of first end symbol

        num_finished_sentences = len([s for s in sentence_lengths if s > 0])

        if num_finished_sentences == opt.k:
            alpha = 0.7
            div = 1 / (sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break

    if ind is None:
        length = (outputs[0] == eos_tok).nonzero()[0]
        return ' '.join([TRG.vocab.itos[tok] for tok in outputs[0][1:length]])

    else:
        length = (outputs[ind] == eos_tok).nonzero()[0]
        return ' '.join(
            [TRG.vocab.itos[tok] for tok in outputs[ind][1:length]])
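Every example on this page delegates beam expansion to a k_best_outputs helper that is not reproduced here. The sketch below is an assumption based on how the helper is called (the variants that also receive TRG/SRC additionally build decoded strings, which is omitted); it is not the authors' code, only a minimal illustration of the usual top-k expansion step:

import math

import torch


def k_best_outputs(outputs, out, log_scores, i, k):
    # out: [k, i, vocab] softmax output; only the distribution at the last position matters.
    probs, ix = out[:, -1].data.topk(k)  # [k, k] best next tokens for each beam
    # Accumulate log-probabilities with the running beam scores (kept on CPU,
    # matching the log_scores tensor built in init_vars below).
    log_probs = (torch.Tensor([math.log(p) for p in probs.view(-1)]).view(k, -1)
                 + log_scores.transpose(0, 1))
    k_probs, k_ix = log_probs.view(-1).topk(k)  # best k of the k*k candidates
    row = k_ix // k  # which beam each winning candidate extends
    col = k_ix % k   # which of that beam's top-k tokens was chosen
    outputs[:, :i] = outputs[row, :i]  # carry over the surviving prefixes
    outputs[:, i] = ix[row, col]       # append the chosen tokens
    log_scores = k_probs.unsqueeze(0)  # updated beam scores, shape [1, k]
    return outputs, log_scores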
Example #2
def init_vars(src, model, SRC, TRG, opt):
    init_tok = TRG.vocab.stoi['<sos>']
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    # this is the output from the encoder
    e_output = model.encoder(src, src_mask)
    # this is initializing the outputs
    outputs = torch.LongTensor([[init_tok]])
    if opt.device == 0:
        outputs = outputs.cuda()

    trg_mask = nopeak_mask(1, opt)
    src_mask = src_mask.cuda()
    trg_mask = trg_mask.cuda()
    outputs = outputs.cuda()
    e_output = e_output.cuda()

    out = model.out(model.decoder(outputs, e_output, src_mask, trg_mask))
    out = F.softmax(out, dim=-1)

    probs, ix = out[:, -1].data.topk(opt.k)
    preds_token_ids = ix.view(ix.size(0), -1)
    pred_strings = [
        ' '.join([TRG.vocab.itos[ind] for ind in ex]) for ex in preds_token_ids
    ]

    # print (pred_strings)

    log_scores = torch.Tensor([math.log(prob)
                               for prob in probs.data[0]]).unsqueeze(0)

    outputs = torch.zeros(opt.k, opt.max_len).long()
    if opt.device == 0:
        outputs = outputs.cuda()
    outputs[:, 0] = init_tok
    outputs[:, 1] = ix[0]

    e_outputs = torch.zeros(opt.k, e_output.size(-2), e_output.size(-1))
    if opt.device == 0:
        e_outputs = e_outputs.cuda()
    e_outputs[:, :] = e_output[0]

    return outputs, e_outputs, log_scores
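Both init_vars and beam_search also depend on nopeak_mask, which is not shown on this page either. A minimal sketch, assuming the usual upper-triangular causal mask moved to the GPU when opt.device == 0:

import numpy as np
import torch


def nopeak_mask(size, opt):
    # Mask out positions above the diagonal so a token cannot attend to future positions.
    np_mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    np_mask = torch.from_numpy(np_mask) == 0  # True where attention is allowed
    if opt.device == 0:
        np_mask = np_mask.cuda()
    return np_mask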
Example #3
def init_vars(src, model, SRC, TRG, opt, recursive=False, query_mask=None):
    init_tok = TRG.vocab.stoi['<sos>']
    if recursive:
        src_mask = None
        e_holder = []
        for i, c in enumerate(src):
            e_holder.append(
                model.encoder([c[0], [e_holder[x] for x in c[1]]], src_mask))
        e_output = e_holder[-1]
    elif query_mask is not None:
        src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
        e_output = model.encoder(src, src_mask, query_mask=query_mask)
    else:
        src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
        e_output = model.encoder(src, src_mask)

    outputs = torch.LongTensor([[init_tok]])
    if opt.device == 0:
        outputs = outputs.cuda()

    trg_mask = nopeak_mask(1, opt)

    out = model.out(model.decoder(outputs, e_output, src_mask, trg_mask))
    out = F.softmax(out, dim=-1)

    probs, ix = out[:, -1].data.topk(opt.k)
    log_scores = torch.Tensor([math.log(prob)
                               for prob in probs.data[0]]).unsqueeze(0)

    outputs = torch.zeros(opt.k, opt.max_len).long()
    if opt.device == 0:
        outputs = outputs.cuda()
    outputs[:, 0] = init_tok
    outputs[:, 1] = ix[0]

    e_outputs = torch.zeros(opt.k, e_output.size(-2), e_output.size(-1))
    if opt.device == 0:
        e_outputs = e_outputs.cuda()
    e_outputs[:, :] = e_output[0]

    return outputs, e_outputs, log_scores
Example #4
def beam_search(src, model, SRC, TRG, opt, recursive=False, query_mask=None):
    outputs, e_outputs, log_scores = init_vars(src,
                                               model,
                                               SRC,
                                               TRG,
                                               opt,
                                               query_mask=query_mask)
    eos_tok = TRG.vocab.stoi['<eos>']

    if recursive:
        src_mask = None
    else:
        src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    ind = None
    for i in range(2, opt.max_len):
        trg_mask = nopeak_mask(i, opt)
        out = model.out(
            model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask))
        out = F.softmax(out, dim=-1)

        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             opt.k)
        ones = torch.nonzero(torch.eq(outputs, eos_tok), as_tuple=False)
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).cuda()
        for vec in ones:
            beam_idx = vec[0]
            if sentence_lengths[beam_idx] == 0:  # First end symbol has not been found yet
                sentence_lengths[beam_idx] = vec[1]  # Position of first end symbol

        num_finished_sentences = len([s for s in sentence_lengths if s > 0])

        if num_finished_sentences == opt.k:
            alpha = 0.7
            div = 1 / (sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break

    if ind is None:
        try:
            length = torch.nonzero(torch.eq(outputs[0], eos_tok),
                                   as_tuple=False)[0]
        except IndexError:
            length = opt.max_len
        built_string = ''
        for tok in outputs[0][1:length]:
            temp_tok = TRG.vocab.itos[tok]
            built_string += ' ' + temp_tok
        return built_string.strip()

    else:
        try:
            length = torch.nonzero(torch.eq(outputs[ind], eos_tok),
                                   as_tuple=False)[0]
        except IndexError:
            length = opt.max_len
        built_string = ''
        for tok in outputs[ind][1:length]:
            temp_tok = TRG.vocab.itos[tok]
            built_string += ' ' + temp_tok
        return built_string.strip()
Example #5
def beam_search(src, model, SRC, TRG, opt):
    outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, opt)
    eos_tok = TRG.vocab.stoi['<eos>']
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    ind = None
    query = {}
    query_tokens = []
    for i in range(2, opt.max_len):

        trg_mask = nopeak_mask(i, opt)
        src_mask = src_mask.cuda()
        trg_mask = trg_mask.cuda()

        out = model.out(
            model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask))
        # print (outputs.size())
        # print (out.size())
        out = F.softmax(out, dim=-1)

        # print("output data shape")
        # print(out.data.shape)

        outputs, log_scores, pred_strings, pred_strings_dict = k_best_outputs(
            outputs, out, log_scores, i, opt.k, TRG)

        # This part is another way of forming the query dictionary.
        for pred_string in pred_strings:
            pred_string_splitted = pred_string.split()
            for st in pred_string_splitted:
                query.setdefault(st, 1.0)
                query[st] = query[st] + 1
            query_tokens.extend(pred_string_splitted)

        for term in pred_strings_dict:
            if term in query:
                if pred_strings_dict[term] > query[term]:
                    query[term] = pred_strings_dict[term]
            else:
                query[term] = pred_strings_dict[term]

        if (outputs == eos_tok).nonzero().size(0) == opt.k:
            alpha = 0.7
            div = 1 / ((outputs == eos_tok).nonzero()[:, 1].type_as(log_scores)
                       **alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break
    # print("query")
    # print(query)
    # if ind is None:
    #     length = (outputs[0] == eos_tok).nonzero()[0]
    #     return ' '.join([TRG.vocab.itos[tok] for tok in outputs[0][1:length]])
    #
    # else:
    #     length = (outputs[ind] == eos_tok).nonzero()[0]
    # return ' '.join([TRG.vocab.itos[tok] for tok in outputs[ind][1:length]])

    if ind is None:
        query_list = []
        # print("value of k is " + str(opt.k))
        for i in np.arange(opt.k):
            if eos_tok in outputs[i]:
                length = (outputs[i] == eos_tok).nonzero()[0]
            else:
                length = opt.max_len
            query_list.append(' '.join(
                [TRG.vocab.itos[tok] for tok in outputs[i][1:length]]))
        return query_list, query, query_tokens

        # if (outputs[0]==eos_tok).nonzero().size(0) >= 1:
        #     length = (outputs[0]==eos_tok).nonzero()[0]
        #     return ' '.join([TRG.vocab.itos[tok] for tok in outputs[0][1:length]])
        # else:
        #     return ' '

    else:
        # if (outputs[ind] == eos_tok).nonzero().size(0) >= 1:
        #     length = (outputs[ind]==eos_tok).nonzero()[0]
        #     return ' '.join([TRG.vocab.itos[tok] for tok in outputs[ind][1:length]])
        # else:
        #     return ' '
        query_list = []
        # print("value of k is " + str(opt.k))
        for i in np.arange(opt.k):
            if eos_tok in outputs[i]:
                length = (outputs[i] == eos_tok).nonzero()[0]
            else:
                length = opt.max_len
            query_list.append(' '.join(
                [TRG.vocab.itos[tok] for tok in outputs[i][1:length]]))
        return query_list, query, query_tokens
Example #6
def beam_search(src, model, vocab, opt):

    outputs, e_outputs, log_scores = init_vars(src, model, vocab, opt)

    #print('\nbeam_search start\n')
    #print(f'outputs: {outputs.size()}')
    #print(f'e_outputs: {e_outputs.size()}')
    #print(f'log_scores: {log_scores.size()}')
    eos_tok = vocab.stoi['<eos>']
    src_mask = (src != vocab.stoi['<pad>']).unsqueeze(-2)
    ind = None
    for i in range(2, opt.max_len):

        trg_mask = nopeak_mask(i, opt)

        out = model.out(
            model.decoder(outputs[:, :i],
                          e_outputs,
                          src_mask,
                          trg_mask,
                          policy=True))

        out = F.softmax(out, dim=-1)

        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             opt.k)

        # Occurrences of end symbols for all input sentences.
        ones = (outputs == eos_tok).nonzero()
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).cuda()
        for vec in ones:
            beam_idx = vec[0]
            if sentence_lengths[beam_idx] == 0:  # First end symbol has not been found yet
                sentence_lengths[beam_idx] = vec[1]  # Position of first end symbol

        num_finished_sentences = len([s for s in sentence_lengths if s > 0])

        if num_finished_sentences == opt.k:
            alpha = 0.7
            div = 1 / (sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break

    length = []

    if ind is None:
        for output in outputs:
            length.append((output == eos_tok).nonzero()[0])

        result = []
        for idx in range(opt.k):
            result.append(' '.join(
                [vocab.itos[tok] for tok in outputs[idx][1:length[idx]]]))
        #return ' '.join([vocab.itos[tok] for tok in outputs[ind][1:length]])

        return result

    else:

        for output in outputs:
            length.append((output == eos_tok).nonzero()[0])

        result = []
        for idx in range(opt.k):
            result.append(' '.join(
                [vocab.itos[tok] for tok in outputs[idx][1:length[idx]]]))
        #return ' '.join([vocab.itos[tok] for tok in outputs[ind][1:length]])

        return result
Example #7
def beam_search(src, model, src_vocab, trg_vocab, opt):

    model.eval()
    outputs, e_outputs, log_scores = init_vars(src, model, src_vocab,
                                               trg_vocab, opt)
    eos_tok = trg_vocab.eos_idx
    src_mask = (src != src_vocab.pad_idx).unsqueeze(-2)
    ind = None
    for i in range(2, opt.max_trg_len):

        trg_mask = nopeak_mask(i, opt)

        out = model.out(
            model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask))

        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             opt.beam_size)
        ones = (outputs == eos_tok).nonzero(
            as_tuple=True
        )  # Occurrences of end symbols for all input sentences.
        x, y = ones
        ones = list(zip(x.detach().cpu().numpy(), y.detach().cpu().numpy()))
        sentence_lengths = torch.zeros(len(outputs),
                                       dtype=torch.long).to(opt.device)

        for vec in ones:
            beam_idx = vec[0]
            if sentence_lengths[beam_idx] == 0:  # First end symbol has not been found yet
                sentence_lengths[beam_idx] = vec[1]  # Position of first end symbol

        num_finished_sentences = len([s for s in sentence_lengths if s > 0])

        if num_finished_sentences == opt.beam_size:
            alpha = 0.7
            div = 1 / (sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break

    pad_token = trg_vocab.pad_idx

    if ind is None:
        length = (outputs[0] == eos_tok).nonzero(as_tuple=True)[0]
        outputs = outputs.detach().cpu().numpy()
        try:
            return ' '.join(
                [trg_vocab.itos[tok] for tok in outputs[0][1:length]])
        except Exception:
            return ' '.join([trg_vocab.itos[tok] for tok in outputs[0][1:]])

    else:
        length = (outputs[ind] == eos_tok).nonzero(as_tuple=True)[0]
        outputs = outputs.detach().cpu().numpy()
        try:
            return ' '.join([
                trg_vocab.itos[tok] for tok in outputs[ind][1:length]
                if tok != pad_token and tok != eos_tok
            ])
        except Exception:
            return ' '.join([
                trg_vocab.itos[tok] for tok in outputs[ind][1:]
                if tok != pad_token and tok != eos_tok
            ])
Example #8
def beam_search(src, model, SRC, TRG, opt):
    eos_tok = TRG.vocab.stoi['<eos>']
    pad_tok = SRC.vocab.stoi['<pad>']
    ind = None
    if opt.nmt_model_type == 'rnn_naive_model':
        # Pad the source up to max_len; dtype must be long to match the token ids in src.
        tensor_to_fill_max_len = torch.full(
            (1, opt.max_len - src.shape[1]), pad_tok,
            dtype=torch.long).to(opt.device)
        src = torch.cat((src, tensor_to_fill_max_len), dim=1)
        outputs, encoder_outputs, log_scores = init_vars(src, model, SRC, TRG, opt)
        encoder_hidden = encoder_outputs
        decoder_hidden = encoder_hidden
    elif opt.nmt_model_type == 'transformer':
        src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
        outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, opt) #  [SRC.vocab.itos[i] for i in src.tolist()[0]] to debug

    for i in range(2, opt.max_len):  # init_tok and the first predicted tokens were already filled by init_vars
    
        if opt.nmt_model_type == 'transformer': # keep increasing size of sentence
            trg_mask = nopeak_mask(i, opt)
            out = model.out(model.decoder(outputs[:,:i],
                    e_outputs, src_mask, trg_mask)) # [[3, 2], [3, 7, 300], [1, 1, 7], [1, 2, 2]] -> [3, 2, 300] -> [3, 2, 11436]
            out = F.softmax(out, dim=-1)
        elif opt.nmt_model_type == 'rnn_naive_model':
            decoder_input = torch.zeros(opt.k, src.shape[1]).long().to(opt.device) # TODO change to opt.max_len
            decoder_input[:, :i] = outputs[:, :i]
            # OPTION 2 - input a tensor of size src.shape[1] and fill up the other numbers with <unk>
            for j in range(opt.k):
                out_piece, decoder_hidden_piece = model.decoder(decoder_input[j, :], 
                                            decoder_hidden[j, :].unsqueeze(0), encoder_outputs[j, :].unsqueeze(0))
                if j == 0:
                    out = out_piece[:i, :].unsqueeze(0)
                    decoder_hidden_carry = decoder_hidden_piece[:i, :]
                else:
                    out = torch.cat([out, out_piece[:i, :].unsqueeze(0)], dim=0) # final shape: [src.shape[1]*3, vocab_size]
                    decoder_hidden_carry = torch.cat([decoder_hidden_carry, decoder_hidden_piece[:i, :]], dim=0)
            out = F.softmax(out, dim=-1)
            decoder_hidden = decoder_hidden_carry
    
        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i, opt.k, TRG, SRC) # (torch.Size([3, 100]), torch.Size([3, 2, 11436]))
        
        ones = (outputs == eos_tok).nonzero()  # Occurrences of end symbols for all input sentences.
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).to(opt.device)
        for vec in ones:
            beam_idx = vec[0]
            if sentence_lengths[beam_idx] == 0:  # First end symbol has not been found yet
                sentence_lengths[beam_idx] = vec[1]  # Position of first end symbol

        num_finished_sentences = len([s for s in sentence_lengths if s > 0])

        if num_finished_sentences == opt.k:
            alpha = 0.7
            div = 1/(sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break
    
    if ind is None:
        length = (outputs[0]==eos_tok).nonzero()
        if len(length) != 0:
            return ' '.join([TRG.vocab.itos[tok] for tok in outputs[0][1:length[0]]])
        else:
            eos_3_hypothesis = torch.ones([3, 1], dtype=torch.int64).to(opt.device)*eos_tok
            outputs = torch.cat((outputs, eos_3_hypothesis), dim=1)
            return ' '.join([TRG.vocab.itos[tok] for tok in outputs[0][1:]])
    
    else:
        length = (outputs[ind]==eos_tok).nonzero()
        if len(length) != 0:
            return ' '.join([TRG.vocab.itos[tok] for tok in outputs[ind][1:length[0]]])
        else:
            eos_3_hypothesis = torch.ones([3, 1], dtype=torch.int64).to(opt.device)*eos_tok
            outputs = torch.cat((outputs, eos_3_hypothesis), dim=1)
            return ' '.join([TRG.vocab.itos[tok] for tok in outputs[ind][1:]])
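For completeness, a hedged sketch of how beam_search is typically driven from a translate helper. The specifics below (SRC.preprocess from the legacy torchtext Field API, the opt.device == 0 GPU convention) are assumptions based on how the fields and options are used in the examples above, not code taken from them:

import torch


def translate_sentence(sentence, model, SRC, TRG, opt):
    model.eval()
    # Tokenise and numericalise the source sentence with the torchtext SRC field.
    tokens = [SRC.vocab.stoi[tok] for tok in SRC.preprocess(sentence)]
    src = torch.LongTensor([tokens])
    if opt.device == 0:
        src = src.cuda()
    with torch.no_grad():
        return beam_search(src, model, SRC, TRG, opt)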