def backward_decode(model, embed, src, src_mask, max_len, start_symbol=2, raw=False,
                    return_term=-1):
    """Greedy-decode with `model`, accumulating a "soft" embedding sequence.

    At each step the sharpened output distribution is projected back through
    the embedding matrix (``prob @ embed.weight``), giving a differentiable
    embedding-space trajectory alongside the hard token ids.

    Args:
        model: seq2seq model exposing ``encode``/``decode`` and a ``generator``
            with a ``scaled_forward(x, scale=...)`` method.
        embed: embedding module; its ``.weight`` projects probabilities back
            into embedding space.
        src: source tensor; token ids of shape (batch, src_len) when
            ``raw`` is False — assumed, TODO confirm against callers.
        src_mask: source attention mask.
        max_len: controls decoding length; the loop runs ``max_len + 1`` steps.
        start_symbol: BOS token id used to seed the target sequence.
        raw: when True, ``src`` is already embedded and is fed to the encoder
            directly instead of going through ``embed``.
        return_term: selects the return value:
            -1 -> (soft embeddings, token ids)
             0 -> soft embeddings only
             1 -> token ids only
             2 -> (soft embeddings, summed perplexity)
            any other value -> None

    Returns:
        See ``return_term`` above.
    """
    src = src.to(device)
    memory = model.encode(src if raw else embed(src), src_mask)
    ys = torch.ones(src.shape[0], 1, dtype=torch.int64).fill_(start_symbol).to(device)
    # Soft sequence starts from the (hard) embedding of the start symbol.
    ret_back = embed(ys).float()
    ppls = 0  # always initialized so the return path below is safe
    for _ in range(max_len + 1):
        out = model.decode(
            memory, src_mask, embed(Variable(ys)),
            Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        # Sharpened distribution (temperature 1/10) for the soft projection.
        prob = model.generator.scaled_forward(out[:, -1], scale=10.0)
        if return_term == 2:
            # Perplexity is measured on the unscaled distribution.
            ppls += perplexity(
                model.generator.scaled_forward(out[:, -1], scale=1.0))
        back = torch.matmul(prob, embed.weight.data.float())
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)
        ret_back = torch.cat([ret_back, back.unsqueeze(1)], dim=1)
    if return_term == -1:
        return ret_back, ys
    if return_term == 0:
        return ret_back
    if return_term == 1:
        return ys
    if return_term == 2:
        return ret_back, ppls
    return None
def prob_backward(model, embed, src, src_mask, max_len, start_symbol=2, raw=False):
    """Greedy-decode with `model` and return the per-step output distributions.

    Args:
        model: seq2seq model exposing ``encode``/``decode`` and a ``generator``.
        embed: embedding module applied to source (unless ``raw``) and to the
            growing target prefix.
        src: source tensor; token ids when ``raw`` is False — assumed,
            TODO confirm against callers.
        src_mask: source attention mask.
        max_len: controls decoding length; the loop runs ``max_len + 1`` steps.
        start_symbol: BOS token id used to seed the target sequence.
        raw: when True, ``src`` is already embedded and is fed to the encoder
            directly instead of going through ``embed``.

    Returns:
        Tensor of shape (batch, max_len + 1, vocab): the generator output at
        each greedy decoding step, concatenated along the time axis.
    """
    src = src.to(device)
    memory = model.encode(src if raw else embed(src), src_mask)
    ys = torch.ones(src.shape[0], 1, dtype=torch.int64).fill_(start_symbol).to(device)
    probs = []
    for _ in range(max_len + 1):
        out = model.decode(
            memory, src_mask, embed(Variable(ys)),
            Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        prob = model.generator(out[:, -1])
        probs.append(prob.unsqueeze(1))
        # Extend the prefix with the argmax token for the next step.
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)
    return torch.cat(probs, dim=1)
def reconstruct(model, src, max_len, start_symbol=2):
    """Greedy-decode from an already-embedded source, returning a soft sequence.

    The source goes through ``model.src_embed[1]`` (not the full embedding
    stack — presumably ``src`` is already in embedding space; TODO confirm),
    and each step's output distribution is projected back through the target
    embedding table to build a differentiable embedding-space sequence.

    Args:
        model: seq2seq model with ``encoder``/``decode``/``generator`` and
            ``src_embed``/``tgt_embed`` module lists.
        src: source tensor fed to ``model.src_embed[1]``.
        max_len: the loop runs ``max_len - 1`` steps.
        start_symbol: BOS token id seeding the target prefix.

    Returns:
        Float tensor of soft (probability-weighted) target embeddings,
        shape (batch, max_len, emb_dim) — assumed, confirm against callers.
    """
    enc_out = model.encoder(model.src_embed[1](src), None)
    tokens = torch.ones(src.shape[0], 1).fill_(start_symbol).long().to(device)
    # Seed the soft sequence with the plain embedding of the start symbol.
    soft_seq = model.tgt_embed[0].pure_emb(tokens).float()
    for _ in range(max_len - 1):
        step_mask = Variable(subsequent_mask(tokens.size(1)).type_as(src.data))
        dec_out = model.decode(enc_out, None, Variable(tokens), step_mask)
        dist = model.generator(dec_out[:, -1])
        # Probability-weighted mix of target embeddings for this step.
        projected = torch.matmul(dist, model.tgt_embed[0].lut.weight.data.float())
        _, chosen = torch.max(dist, dim=1)
        tokens = torch.cat([tokens, chosen.unsqueeze(-1)], dim=1)
        soft_seq = torch.cat([soft_seq, projected.unsqueeze(1)], dim=1)
    return soft_seq