Example 1
def generate(strategy, encoder_input, models, tgt_dict, length_beam_size, gold_target_len):
    assert len(models) == 1
    model = models[0]
    
    src_tokens = encoder_input['src_tokens']
    # Round-trip through a Python list to get a detached copy of the source tokens.
    src_tokens = src_tokens.new(src_tokens.tolist())
    bsz = src_tokens.size(0)
    
    # Encode the source once, then predict a beam of candidate target lengths.
    encoder_out = model.encoder(**encoder_input)
    beam = predict_length_beam(gold_target_len, encoder_out['predicted_lengths'], length_beam_size)
    
    max_len = beam.max().item()
    # length_mask[b, j, p] is 1 where position p falls outside the j-th candidate length.
    length_mask = torch.triu(src_tokens.new(max_len, max_len).fill_(1).long(), 1)
    length_mask = torch.stack([length_mask[beam[batch] - 1] for batch in range(bsz)], dim=0)
    # Initialise every hypothesis with mask tokens, padded beyond its candidate length.
    tgt_tokens = src_tokens.new(bsz, length_beam_size, max_len).fill_(tgt_dict.mask())
    tgt_tokens = (1 - length_mask) * tgt_tokens + length_mask * tgt_dict.pad()
    tgt_tokens = tgt_tokens.view(bsz * length_beam_size, max_len)
    
    # Tile the encoder output so every length candidate shares the same source encoding.
    duplicate_encoder_out(encoder_out, bsz, length_beam_size)
    hypotheses, lprobs = strategy.generate(model, encoder_out, tgt_tokens, tgt_dict)
    
    hypotheses = hypotheses.view(bsz, length_beam_size, max_len)
    lprobs = lprobs.view(bsz, length_beam_size)
    # Pick, per sentence, the length candidate with the best length-normalised log-probability.
    tgt_lengths = (1 - length_mask).sum(-1)
    avg_log_prob = lprobs / tgt_lengths.float()
    best_lengths = avg_log_prob.max(-1)[1]
    hypotheses = torch.stack([hypotheses[b, l, :] for b, l in enumerate(best_lengths)], dim=0)

    return hypotheses
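
Both examples call two helpers that are not shown here, predict_length_beam and duplicate_encoder_out. The sketch below is only an illustration of how they might behave, assuming encoder_out['predicted_lengths'] holds per-sentence scores over candidate target lengths and encoder_out['encoder_out'] follows a (seq_len, batch, dim) layout; it is an assumption, not the implementation these examples actually import.

import torch

def predict_length_beam(gold_target_len, predicted_lengths, length_beam_size):
    # If gold lengths are given, centre the beam of candidate lengths on them;
    # otherwise take the top-k most probable lengths per sentence.
    if gold_target_len is not None:
        starts = gold_target_len - (length_beam_size - 1) // 2
        ends = gold_target_len + length_beam_size // 2 + 1
        beam = torch.stack(
            [torch.arange(int(starts[b]), int(ends[b])) for b in range(gold_target_len.size(0))],
            dim=0)
    else:
        beam = predicted_lengths.topk(length_beam_size, dim=1)[1]
    beam[beam < 2] = 2  # never decode fewer than two tokens
    return beam

def duplicate_encoder_out(encoder_out, bsz, beam_size):
    # Repeat the (seq_len, batch, dim) encoder states beam_size times along the
    # batch axis so every length candidate shares the same source encoding.
    enc = encoder_out['encoder_out']
    encoder_out['encoder_out'] = (
        enc.unsqueeze(2).repeat(1, 1, beam_size, 1).view(enc.size(0), bsz * beam_size, enc.size(-1)))
    mask = encoder_out.get('encoder_padding_mask')
    if mask is not None:
        encoder_out['encoder_padding_mask'] = (
            mask.unsqueeze(1).repeat(1, beam_size, 1).view(bsz * beam_size, -1))
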
Example 2
def generate(strategy, encoder_input, models, tgt_dict, length_beam_size, gold_target_len):
    assert len(models) == 1
    model = models[0]
    # Masking for inference: just respect the input masks. Very ad hoc; fix this later.
    model.decoder.inference = True
    
    src_tokens = encoder_input['src_tokens']
    src_tokens = src_tokens.new(src_tokens.tolist())
    bsz = src_tokens.size(0)
    
    encoder_out = model.encoder(**encoder_input)
    beam = predict_length_beam(gold_target_len, encoder_out['predicted_lengths'], length_beam_size)
    
    max_len = beam.max().item()
    # Same construction as in Example 1: all-mask hypotheses, padded beyond each candidate length.
    length_mask = torch.triu(src_tokens.new(max_len, max_len).fill_(1).long(), 1)
    length_mask = torch.stack([length_mask[beam[batch] - 1] for batch in range(bsz)], dim=0)
    tgt_tokens = src_tokens.new(bsz, length_beam_size, max_len).fill_(tgt_dict.mask())
    tgt_tokens = (1 - length_mask) * tgt_tokens + length_mask * tgt_dict.pad()
    # Place EOS: optionally at position 0, and always at the final position of each candidate length.
    if strategy.move_eos:
        tgt_tokens[:, :, 0] = tgt_dict.eos()
    eos_idxes = max_len - length_mask.sum(dim=-1) - 1
    tgt_tokens = tgt_tokens.scatter_(-1, eos_idxes.unsqueeze(-1), tgt_dict.eos())
    tgt_tokens = tgt_tokens.view(bsz * length_beam_size, max_len)
    
    # Tile the encoder output across length candidates and expose the beam size to the model.
    duplicate_encoder_out(encoder_out, bsz, length_beam_size)
    model.length_beam_size = length_beam_size
    hypotheses, lprobs = strategy.generate(model, encoder_out, tgt_tokens, tgt_dict)
    
    hypotheses = hypotheses.view(bsz, length_beam_size, max_len)
    lprobs = lprobs.view(bsz, length_beam_size)
    # Rank the length candidates by length-normalised log-probability, as in Example 1.
    tgt_lengths = (1 - length_mask).sum(-1)
    avg_log_prob = lprobs / tgt_lengths.float()
    best_lengths = avg_log_prob.max(-1)[1]
    hypotheses = torch.stack([hypotheses[b, l, :] for b, l in enumerate(best_lengths)], dim=0)
    
    return hypotheses
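
For context, a hypothetical call site might look like the following. All names here (decoding_strategy, model, tgt_dict, src_tokens, src_lengths) are placeholders standing in for whatever the surrounding codebase provides; the only requirements visible in the examples are a single model, a dictionary exposing mask(), pad() and eos(), and a strategy object with a generate method (and, for Example 2, a move_eos flag).

encoder_input = {'src_tokens': src_tokens, 'src_lengths': src_lengths}
hypotheses = generate(
    strategy=decoding_strategy,   # iterative decoding strategy with .generate(...)
    encoder_input=encoder_input,
    models=[model],               # exactly one model is expected (see the assert)
    tgt_dict=tgt_dict,            # dictionary exposing mask(), pad(), eos()
    length_beam_size=5,           # number of candidate target lengths per sentence
    gold_target_len=None,         # or a LongTensor of gold lengths to centre the beam on
)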