Example #1
 def decode_batch(self, src: torch.Tensor) -> List[str]:
     """Greedily decode a whole batch, one position per step."""
     batch_size = src.size(0)
     max_len = src.size(1)
     # Start every sequence with the SOS token.
     context = torch.LongTensor([SOS_token] * batch_size).cuda()
     context = context.unsqueeze(1)
     for i in range(max_len):
         # Mask padding in the generated prefix and hide future positions;
         # the mask must match the current prefix length, not the source.
         context_mask = (context != PAD_token).unsqueeze(-2)
         context_mask = context_mask & \
             subsequent_mask(context.size(-1)).type_as(context_mask)
         decoder_mask = (src != PAD_token).unsqueeze(-2)
         decoder_mask = decoder_mask & \
             subsequent_mask(src.size(-1)).type_as(decoder_mask)
         decoder_output = self.model(context, src, context_mask,
                                     decoder_mask)
         # Greedy step: keep the single most likely token at position i.
         log_probs = self.model.generator(decoder_output[:, i, :])
         _, topi = torch.topk(log_probs, 1)
         context = torch.cat([context, topi], dim=1)
     # Decode indices back to strings, one per batch element.
     decoded_chs = []
     decoded_str = []
     src_length = torch.sum(src != PAD_token, dim=1)
     for i in range(batch_size):
         code = context[i, :]
         for j in range(1, src_length[i].item() + 1):
             decoded_chs.append(self.lang.index2word[code[j].item()])
         decoded_str.append(' '.join(decoded_chs))
         decoded_chs.clear()
     return decoded_str
Example #2
 def decode_batch(self, src: torch.Tensor) -> List[str]:
     """Greedily decode a batch, truncating each string at EOS."""
     batch_size = src.size(0)
     # Start every sequence with the SOS token.
     decoder_input = torch.LongTensor([SOS_token] * batch_size).cuda()
     decoder_input = decoder_input.unsqueeze(1)
     src_mask = (src != PAD_token).unsqueeze(-2)
     for i in range(self.max_len):
         # Mask padding in the prefix and hide future positions.
         tar_mask = (decoder_input != PAD_token).unsqueeze(-2)
         tar_mask = tar_mask & \
             subsequent_mask(decoder_input.size(-1)).type_as(tar_mask.data)
         decoder_output = self.model(src, decoder_input, src_mask, tar_mask)
         # Greedy step: take the most likely next token.
         log_probs = self.model.generator(decoder_output[:, -1, :])
         _, topi = torch.topk(log_probs, 1)
         decoder_input = torch.cat([decoder_input, topi], dim=1)
     # Decode indices back to strings, stopping each sequence at EOS.
     decoded_chs = []
     decoded_str = []
     for i in range(batch_size):
         code = decoder_input[i, :]
         for j in range(1, self.max_len):
             if code[j] == EOS_token:
                 decoded_str.append(' '.join(decoded_chs))
                 decoded_chs.clear()
                 break
             decoded_chs.append(self.lang.index2word[code[j].item()])
         else:
             # No EOS was produced: keep the full decoded sequence.
             decoded_str.append(' '.join(decoded_chs))
             decoded_chs.clear()
     return decoded_str
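
Both decode_batch variants above, and every snippet below, call a subsequent_mask helper that none of them defines. A minimal sketch of the canonical Annotated Transformer version, which is assumed here to be what these projects import:

import torch

def subsequent_mask(size):
    "Mask out subsequent positions: True on and below the diagonal."
    attn_shape = (1, size, size)
    # Ones strictly above the diagonal mark the future positions to hide.
    mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
    return mask == 0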
Example #3
 def make_std_mask(tgt, pad):
     """
     Create a mask to hide padding and future words.
     Both padding and future words are marked with 0 in the mask.
     """
     tgt_mask = (tgt != pad).unsqueeze(-2)
     # Variable is a legacy torch.autograd wrapper; it is a no-op on
     # modern PyTorch tensors.
     tgt_mask = tgt_mask & Variable(
         subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
     return tgt_mask

 # Variant that builds the padding mask from sequence lengths instead
 # of a pad token.
 def make_std_mask(tgt, max_len):
     "Create a mask to hide padding and future words."
     length = [min(max_len, len(x)) + 1 for x in tgt]
     # Boolean dtype so the bitwise AND with subsequent_mask is valid.
     tgt_mask = torch.zeros((len(length), max_len + 1), dtype=torch.bool)
     for i, j in enumerate(length):
         tgt_mask[i, :j] = True
     tgt_mask = tgt_mask.unsqueeze(-2)
     tgt_mask = tgt_mask & Variable(
         subsequent_mask(max_len + 1).type_as(tgt_mask.data))
     return tgt_mask
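
A small hypothetical check of the first variant, assuming only that definition is in scope, the subsequent_mask sketch shown earlier, and a pad index of 0:

tgt = torch.tensor([[5, 7, 0]])   # one sequence whose last position is padding
print(make_std_mask(tgt, pad=0))
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True, False]]])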
Example #5
 def fit_batch(self, batch) -> float:
     """Run one teacher-forced training step and return the loss."""
     decoder_input = batch['src']
     tar = batch['tar_oneway']
     # Teacher forcing: the decoder context is the target shifted right
     # by one position, with SOS prepended.
     sos = torch.LongTensor([SOS_token] * tar.size(0)).unsqueeze(1).cuda()
     context = torch.cat([sos, tar[:, :-1]], dim=1)
     context_mask = (context != PAD_token).unsqueeze(-2)
     context_mask = context_mask & \
         subsequent_mask(context.size(-1)).type_as(context_mask)
     decoder_mask = (decoder_input != PAD_token).unsqueeze(-2)
     decoder_mask = decoder_mask & \
         subsequent_mask(decoder_input.size(-1)).type_as(decoder_mask)
     decoder_output = self.model(context, decoder_input, context_mask,
                                 decoder_mask)
     log_probs = self.model.generator(decoder_output)
     # Exclude the EOS position from each target length.
     length = [x - 1 for x in batch['tar_length']]
     loss = masked_cross_entropy(log_probs.contiguous(), tar, length)
     loss.backward()
     self.opt.step()
     self.opt.zero_grad()
     return loss.item()
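
The torch.cat([sos, tar[:, :-1]], dim=1) line above is the standard teacher-forcing shift: the decoder sees SOS plus all target tokens but the last, and is trained to predict the unshifted target. A hypothetical illustration, assuming SOS_token = 1 and an EOS index of 2:

tar = torch.tensor([[4, 5, 6, 2]])             # target ending in EOS (= 2)
sos = torch.tensor([[1]])                      # SOS column for a batch of one
context = torch.cat([sos, tar[:, :-1]], dim=1)
print(context)                                 # tensor([[1, 4, 5, 6]])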
Example #6
 def generate_response(self, src: str, wait=10) -> Tuple[str, bool]:
     """Sample responses until one passes the validity check or the
     time budget of `wait` seconds runs out."""
     import time
     t0 = time.time()
     length = len(src)
     src_text = src
     src = self.lang.words2indices(src)
     src = torch.tensor(src).unsqueeze(0).cuda()
     valid = False
     success = True
     decoded_str = []
     while not valid:
         # Give up once the time budget is exhausted.
         if time.time() - t0 > wait:
             success = False
             break
         context = torch.tensor([[SOS_token]]).cuda()
         for i in range(length):
             # Hide future positions in the generated prefix.
             context_mask = subsequent_mask(
                 context.size(-1)).type(torch.bool).cuda()
             decoder_mask = subsequent_mask(length).type(torch.bool).cuda()
             decoder_output = self.model(context, src, context_mask,
                                         decoder_mask)
             log_probs = self.model.generator(decoder_output[:, i, :])
             # Sample the next token from the model's distribution.
             probs = torch.exp(log_probs)
             sampler = torch.distributions.categorical.Categorical(
                 probs=probs.squeeze())
             index = sampler.sample().item()
             context = torch.cat(
                 [context, torch.tensor([[index]]).cuda()], dim=1)
         decoded_str = []
         for j in range(1, context.size(1)):
             decoded_str.append(self.lang.index2word[context[0, j].item()])
             # Reject the sample if commas in the source and the
             # response do not line up.
             if (src_text[j-1] == ',' or decoded_str[-1] == ',') \
                     and src_text[j-1] != decoded_str[-1]:
                 break
         else:
             valid = True
     if not success:
         decoded_str = []
     return ''.join(decoded_str), success
Example #7
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    # ys is the sequence generated so far; it starts as just the start
    # symbol, and each step's prediction is appended to the end.
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        out = model.decode(
            memory, src_mask, Variable(ys),
            Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        # Pick the highest-probability token at the last position.
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat(
            [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys
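
A hypothetical call, assuming a trained model and a vocabulary whose start symbol is 1; this mirrors how the Annotated Transformer exercises the same function:

src = torch.LongTensor([[1, 2, 3, 4, 5]])
src_mask = torch.ones(1, 1, src.size(1))   # toy source with no padding
print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))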
Example #8
 def fit_batch(self, batch) -> float:
     """Run one teacher-forced training step and return the loss."""
     src = batch['src']
     tar = batch['tar']
     # Shift: the decoder sees tokens up to t-1 and predicts token t.
     decoder_input = tar[:, :-1]
     tar = tar[:, 1:]
     src_mask = (src != PAD_token).unsqueeze(-2)
     tar_mask = (decoder_input != PAD_token).unsqueeze(-2)
     tar_mask = tar_mask & \
         subsequent_mask(decoder_input.size(-1)).type_as(tar_mask.data)
     decoder_output = self.model(src, decoder_input, src_mask, tar_mask)
     log_probs = self.model.generator(decoder_output)
     length = batch['tar_length']
     loss = masked_cross_entropy(log_probs.contiguous(), tar, length)
     loss.backward()
     self.opt.step()
     self.opt.zero_grad()
     return loss.item()
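
masked_cross_entropy is project-specific and never shown in these snippets. A minimal sketch of one common length-masked formulation, assuming log_probs has shape (batch, seq, vocab) and lengths counts the valid tokens in each sequence:

def masked_cross_entropy(log_probs, target, lengths):
    # Negative log-likelihood of each gold token: shape (batch, seq).
    nll = -log_probs.gather(2, target.unsqueeze(2)).squeeze(2)
    # True at positions that fall within each sequence's real length.
    steps = torch.arange(log_probs.size(1), device=target.device)
    mask = steps.unsqueeze(0) < torch.tensor(
        lengths, device=target.device).unsqueeze(1)
    # Average the per-token loss over real (unmasked) tokens only.
    return (nll * mask.float()).sum() / mask.float().sum()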
Example #9
 def make_std_mask(tgt, pad):
     "Create a mask to hide padding and future words."
     tgt_mask = (tgt != pad).unsqueeze(-2)
     tgt_mask = tgt_mask & Variable(
         transformer.subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
     return tgt_mask