from functools import partial

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from reformer_pytorch.reformer_pytorch import ReformerLM
from reformer_pytorch.autopadder import Autopadder

# `top_k` is the default logits-filtering helper defined alongside this wrapper
# in the same module (not reproduced here).

class TrainingWrapper(nn.Module):
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        assert isinstance(net, ReformerLM), 'generative trainer wrapper can only accept ReformerLM class'
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        # Autopadder pads inputs out to the length the Reformer's LSH attention expects
        self.net = Autopadder(net)
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        # promote a single prompt of shape (seq,) to a batch of one (1, seq)
        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens

        for _ in range(seq_len):
            # feed at most the last `max_seq_len` tokens of the running context
            x = out[:, -self.max_seq_len:]

            # take the logits at the final position, filter, then sample the next token
            logits = self.net(x, **kwargs)[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim = -1)
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim = -1)

            # stop early once every sequence in the batch has emitted the EOS token
            if eos_token is not None and (sample == eos_token).all():
                break

        # return only the newly generated tokens, dropping the prompt
        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out

    def forward(self, x, return_loss = False, **kwargs):
        pad = partial(pad_sequence, batch_first = True, padding_value = self.pad_value)

        if not return_loss:
            # plain forward pass; pad a list of variable-length sequences into a batch
            if not isinstance(x, torch.Tensor):
                x = pad(x)
            return self.net(x, **kwargs)

        # build shifted input / target pairs for next-token prediction
        if isinstance(x, torch.Tensor):
            xi = x[:, :-1]
            xo = x[:, 1:]
        else:
            xi = pad(list(map(lambda t: t[:-1], x)))
            xo = pad(list(map(lambda t: t[1:], x)))

        out = self.net(xi, **kwargs)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
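# --- Illustrative usage sketch (not part of the original module) ---
# The model hyperparameters, random token data, and sequence lengths below are
# assumptions chosen purely for demonstration; they show how the wrapper handles
# variable-length training batches (return_loss = True) and autoregressive sampling.
if __name__ == '__main__':
    model = ReformerLM(
        num_tokens = 20000,
        dim = 512,
        depth = 6,
        max_seq_len = 4096
    )
    wrapper = TrainingWrapper(model, ignore_index = -100, pad_value = 0)
    optimizer = torch.optim.Adam(wrapper.parameters(), lr = 1e-4)

    # a list of variable-length sequences; the wrapper pads them itself
    data = [torch.randint(0, 20000, (length,)) for length in (1024, 2048, 4096)]

    loss = wrapper(data, return_loss = True)   # shifted next-token cross entropy
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # sample 64 new tokens from a short prompt; the prompt itself is not returned
    prompt = torch.randint(0, 20000, (1, 16))
    generated = wrapper.generate(prompt, seq_len = 64, filter_thres = 0.9)
    print(generated.shape)  # (1, 64); shorter only if an eos_token were supplied and hit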