def decode_batch(self, src: torch.Tensor) -> List[str]:
    batch_size = src.size(0)
    max_len = src.size(1)
    # Start every sequence in the batch with the SOS token.
    context = torch.LongTensor([SOS_token] * batch_size).cuda()
    context = context.unsqueeze(1)
    for i in range(max_len):
        # Combine the padding mask with a causal mask truncated to the
        # current decode step.
        context_mask = (context != PAD_token)
        context_mask = context_mask.unsqueeze(-2)
        context_mask = context_mask & \
            subsequent_mask(src.size(-1))[:, :, :i + 1].type_as(context_mask)
        decoder_mask = (src != PAD_token).unsqueeze(-2)
        decoder_mask = decoder_mask & \
            subsequent_mask(src.size(-1)).type_as(decoder_mask)
        decoder_output = self.model(context, src, context_mask, decoder_mask)
        # Greedy choice: take the top-1 token at position i and append it.
        log_probs = self.model.generator(decoder_output[:, i, :])
        topv, topi = torch.topk(log_probs, 1)
        context = torch.cat([context, topi], dim=1)
    # Decode the generated indices back to strings, one per batch element.
    decoded_chs = []
    decoded_str = []
    src_length = torch.sum(src != PAD_token, dim=1)
    for i in range(batch_size):
        code = context[i, :]
        for j in range(1, src_length[i].item() + 1):
            decoded_chs.append(self.lang.index2word[code[j].item()])
        decoded_str.append(' '.join(decoded_chs))
        decoded_chs.clear()
    return decoded_str

def decode_batch(self, src: torch.Tensor) -> List[str]:
    batch_size = src.size(0)
    decoder_input = torch.LongTensor([SOS_token] * batch_size).cuda()
    decoder_input = decoder_input.unsqueeze(1)
    src_mask = (src != PAD_token).unsqueeze(-2)
    for i in range(self.max_len):
        # Causal mask over the tokens generated so far.
        tar_mask = (decoder_input != PAD_token)
        tar_mask = tar_mask.unsqueeze(-2)
        tar_mask = tar_mask & \
            subsequent_mask(decoder_input.size(-1)).type_as(tar_mask.data)
        decoder_output = self.model(src, decoder_input, src_mask, tar_mask)
        # Greedy choice: take the top-1 token at the last position.
        log_probs = self.model.generator(decoder_output[:, -1, :])
        topv, topi = torch.topk(log_probs, 1)
        decoder_input = torch.cat([decoder_input, topi], dim=1)
    # Decode the generated indices back to strings, stopping at EOS.
    decoded_chs = []
    decoded_str = []
    for i in range(batch_size):
        code = decoder_input[i, :]
        for j in range(1, self.max_len):
            if code[j] == EOS_token:
                decoded_str.append(' '.join(decoded_chs))
                decoded_chs.clear()
                break
            decoded_chs.append(self.lang.index2word[code[j].item()])
        else:
            # No EOS was generated; keep the full sequence.
            decoded_str.append(' '.join(decoded_chs))
            decoded_chs.clear()
    return decoded_str

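# Every snippet here assumes a `subsequent_mask` helper whose implementation
# is not shown. Below is a minimal sketch in the style of the Annotated
# Transformer: it returns a (1, size, size) boolean mask that is True at the
# positions a query may attend to (itself and earlier positions).
import torch

def subsequent_mask(size):
    "Mask out subsequent (future) positions."
    attn_shape = (1, size, size)
    # The upper triangle above the diagonal marks the future positions.
    mask = torch.triu(torch.ones(attn_shape), diagonal=1)
    return mask == 0
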
def make_std_mask(tgt, pad):
    """Create a mask to hide padding and future words.

    Both padding and future words are marked with 0 in the mask.
    """
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & Variable(
        subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask

def make_std_mask(tgt, max_len):
    "Create a mask to hide padding and future words."
    # Valid length per sequence, clipped to max_len, plus one extra position.
    length = [min(max_len, len(x)) + 1 for x in tgt]
    tgt_mask = torch.zeros((len(length), max_len + 1), dtype=torch.bool)
    for i, j in enumerate(length):
        tgt_mask[i, range(j)] = True
    tgt_mask = tgt_mask.unsqueeze(-2) & Variable(
        subsequent_mask(max_len + 1).type_as(tgt_mask.data))
    return tgt_mask

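# A toy check of the pad-based variant above. The values are hypothetical:
# PAD_token is assumed to be 0 and `subsequent_mask` comes from the sketch
# earlier; `Variable` is the legacy torch.autograd wrapper the snippets use.
import torch

PAD_token = 0
tgt = torch.tensor([[5, 7, 2, 0, 0],
                    [3, 4, 6, 8, 2]])
mask = make_std_mask(tgt, PAD_token)
print(mask.shape)  # torch.Size([2, 5, 5]); padding and future positions are 0
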
def fit_batch(self, batch) -> float:
    decoder_input = batch['src']
    tar = batch['tar_oneway']
    # Shift the target right by prepending SOS to build the context.
    sos = torch.LongTensor([SOS_token] * tar.size(0)).unsqueeze(1).cuda()
    context = torch.cat([sos, tar[:, :-1]], dim=1)
    context_mask = (context != PAD_token).unsqueeze(-2)
    context_mask = context_mask & \
        subsequent_mask(decoder_input.size(-1)).type_as(context_mask)
    decoder_mask = (decoder_input != PAD_token).unsqueeze(-2)
    decoder_mask = decoder_mask & \
        subsequent_mask(decoder_input.size(-1)).type_as(decoder_mask)
    decoder_output = self.model(context, decoder_input, context_mask,
                                decoder_mask)
    log_probs = self.model.generator(decoder_output)
    # Exclude the EOS position from each target length.
    length = [x - 1 for x in batch['tar_length']]
    loss = masked_cross_entropy(log_probs.contiguous(), tar, length)
    loss.backward()
    self.opt.step()
    self.opt.zero_grad()
    return loss.item()

def generate_response(self, src: str, wait=10) -> Tuple[str, bool]:
    import time
    t0 = time.time()
    length = len(src)
    src_text = src
    src = self.lang.words2indices(src)
    src = torch.tensor(src).unsqueeze(0).cuda()
    valid = False
    success = True
    decoded_str = []
    # Resample until the output is valid or the time budget is exhausted.
    while not valid:
        if time.time() - t0 > wait:
            success = False
            break
        context = torch.tensor([[SOS_token]]).cuda()
        for i in range(length):
            context_mask = subsequent_mask(
                src.size(-1))[:, :, :i + 1].type(torch.bool).cuda()
            decoder_mask = subsequent_mask(length).type(torch.bool).cuda()
            decoder_output = self.model(context, src, context_mask,
                                        decoder_mask)
            log_probs = self.model.generator(decoder_output[:, i, :])
            # Sample the next token from the predicted distribution.
            probs = torch.exp(log_probs)
            sampler = torch.distributions.categorical.Categorical(
                probs=probs.squeeze())
            index = sampler.sample().item()
            context = torch.cat(
                [context, torch.tensor([[index]]).cuda()], dim=1)
        decoded_str = []
        # Reject the sample if commas in the output do not line up with
        # commas in the input.
        for j in range(1, context.size(1)):
            decoded_str.append(self.lang.index2word[context[0, j].item()])
            if (src_text[j - 1] == ',' or decoded_str[-1] == ',') \
                    and src_text[j - 1] != decoded_str[-1]:
                break
        else:
            valid = True
    if not success:
        decoded_str = []
    return ''.join(decoded_str), success

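# Hypothetical usage of generate_response. The `responder` object and input
# string are assumptions, not from the original code; the function returns
# the sampled reply plus a success flag that is False on timeout.
reply, ok = responder.generate_response('hello,world', wait=10)
if ok:
    print(reply)
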
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    # ys holds the sequence generated so far. It starts with only the start
    # symbol, and each predicted token is appended to the end.
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        out = model.decode(
            memory, src_mask, Variable(ys),
            Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat(
            [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys

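# A minimal smoke test for greedy_decode, assuming a trained encoder-decoder
# `model` with `encode`, `decode`, and `generator` attributes as used above.
# The source sentence and start symbol index 1 are illustrative only.
src = torch.LongTensor([[1, 2, 3, 4, 5]])
src_mask = torch.ones(1, 1, 5)
print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))
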
def fit_batch(self, batch) -> float:
    src = batch['src']
    tar = batch['tar']
    # Teacher forcing: the decoder input is the target shifted right by one.
    decoder_input = tar[:, :-1]
    tar = tar[:, 1:]
    src_mask = (src != PAD_token).unsqueeze(-2)
    tar_mask = (decoder_input != PAD_token).unsqueeze(-2)
    tar_mask = tar_mask & \
        subsequent_mask(decoder_input.size(-1)).type_as(tar_mask.data)
    decoder_output = self.model(src, decoder_input, src_mask, tar_mask)
    log_probs = self.model.generator(decoder_output)
    length = batch['tar_length']
    loss = masked_cross_entropy(log_probs.contiguous(), tar, length)
    loss.backward()
    self.opt.step()
    self.opt.zero_grad()
    return loss.item()

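# Both fit_batch variants call `masked_cross_entropy`, whose implementation
# is not shown. A minimal sketch under assumed shapes: `log_probs` is
# (batch, seq, vocab) log-softmax output, `target` is (batch, seq) token ids,
# and `length` is a Python list of valid token counts per example.
import torch

def masked_cross_entropy(log_probs, target, length):
    batch, seq_len, vocab = log_probs.size()
    # Negative log-likelihood of each gold token: (batch, seq).
    nll = -log_probs.gather(2, target.unsqueeze(2)).squeeze(2)
    # True at positions before each sequence's valid length.
    mask = torch.arange(seq_len, device=target.device).unsqueeze(0) \
        < torch.tensor(length, device=target.device).unsqueeze(1)
    nll = nll.masked_fill(~mask, 0.0)
    # Average the loss over the valid (non-padding) tokens only.
    return nll.sum() / mask.sum()
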
def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & Variable(
        transformer.subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask