import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli, Poisson

# NOTE: GenerativeLM, MADEConditioner, and AutoregressiveLikelihood are assumed
# to be defined elsewhere in this project; only the standard PyTorch imports
# are added here.


class CorrelatedPoissonsLM(GenerativeLM):
    r"""
    This parameterises an autoregressive product of Poisson distributions,

        P(x|z) = \prod_{v=1}^V Poisson(c_v(x)|b_v(z, x))

    where V is the vocabulary size, c_v(x) counts the occurrences of v in x,
    and b(z, x) \in (0, \infty)^V is autoregressive in x (we use a MADE).
    """

    def __init__(self, vocab_size, latent_size, hidden_sizes, pad_idx,
                 num_masks=1, resample_mask_every=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.resample_every = resample_mask_every
        self.counter = resample_mask_every
        self.vocab_size = vocab_size
        self.made_conditioner = MADEConditioner(
            input_size=vocab_size + latent_size,
            output_size=vocab_size,
            context_size=latent_size,
            hidden_sizes=hidden_sizes,
            num_masks=num_masks)
        self.product_of_poissons = AutoregressiveLikelihood(
            event_size=vocab_size,
            dist_type=Poisson,
            conditioner=self.made_conditioner)

    def make_counts(self, x):
        """Return a vocab_size-dimensional count-vector view of x."""
        # Convert ids to V-dimensional one-hot vectors and reduce-sum the time
        # dimension: this gives us word counts.
        # [B, T] -> [B, T, V] -> [B, V]
        word_counts = F.one_hot(x, self.vocab_size).sum(1)
        # Zero out the padding count; we could leave it in, which would be one
        # way to model sequence length.
        word_counts[:, self.pad_idx] = 0
        return word_counts.float()

    def forward(self, x, z, state=dict()) -> Poisson:
        r"""
        Return Poisson distributions

            c_v(X)|z, \Sigma_{<v} ~ Poisson(b_v(z, \Sigma_{<v}))

        with shape [B, Vx] where Vx = |\Sigma| and \Sigma is the vocabulary.
        """
        # [B, V] count-vector view of x
        counts = self.make_counts(x)
        if self.resample_every > 0:
            self.counter = self.counter - 1 if self.counter > 0 else self.resample_every
        return self.product_of_poissons(
            z, history=counts,
            resample_mask=self.resample_every > 0 and self.counter == 0)

    def log_prob(self, likelihood: Poisson, x):
        # [B, V]
        counts = self.make_counts(x)
        # [B, V] -> [B]
        return likelihood.log_prob(counts).sum(-1)

    def sample(self, z, max_len=None, greedy=False, state=dict()):
        """Sample from X|z where z is [B, Dz]."""
        if greedy:
            raise NotImplementedError("Greedy decoding not implemented for MADE")
        shape = [z.size(0), self.product_of_poissons.event_size]
        x = self.product_of_poissons.sample(
            z, torch.zeros(shape, dtype=z.dtype, device=z.device))
        return x
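# Usage sketch (not part of the original module): a minimal, hedged example of
# driving CorrelatedPoissonsLM end to end, assuming GenerativeLM is an
# nn.Module subclass so the instance is callable. All shapes and
# hyperparameters below are illustrative assumptions; `x` is a batch of
# token-id sequences and `z` a batch of latent codes, as the class's
# forward/sample signatures expect.
def _demo_correlated_poissons():
    batch_size, seq_len, vocab_size, latent_size = 4, 7, 100, 16
    lm = CorrelatedPoissonsLM(vocab_size=vocab_size,
                              latent_size=latent_size,
                              hidden_sizes=[64, 64],
                              pad_idx=0)
    x = torch.randint(1, vocab_size, (batch_size, seq_len))  # token ids [B, T]
    z = torch.randn(batch_size, latent_size)                 # latent codes [B, Dz]
    likelihood = lm(x, z)            # Poisson with batch shape [B, V]
    ll = lm.log_prob(likelihood, x)  # [B] log-likelihood of the count vectors
    counts = lm.sample(z)            # [B, V] sampled count vectors
    return ll, counts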
class CorrelatedBernoullisLM(GenerativeLM):
    r"""
    This parameterises an autoregressive product of Bernoulli distributions,

        P(x|z) = \prod_{v=1}^V Bern([v \in x]|b_v(z, x))

    where V is the vocabulary size and b(z, x) \in (0, 1)^V is autoregressive
    in x (we use a MADE).
    """

    def __init__(self, vocab_size, latent_size, hidden_sizes, pad_idx,
                 num_masks=1, resample_mask_every=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.resample_every = resample_mask_every
        self.counter = resample_mask_every
        self.vocab_size = vocab_size
        self.made_conditioner = MADEConditioner(
            input_size=vocab_size + latent_size,
            output_size=vocab_size,
            context_size=latent_size,
            hidden_sizes=hidden_sizes,
            num_masks=num_masks)
        self.product_of_bernoullis = AutoregressiveLikelihood(
            event_size=vocab_size,
            dist_type=Bernoulli,
            conditioner=self.made_conditioner)

    def make_indicators(self, x):
        """Return a vocab_size-dimensional bit-vector view of x."""
        # Convert ids to V-dimensional one-hot vectors and reduce-sum the time
        # dimension: this gives us word counts, which we then threshold into
        # presence/absence indicators.
        # [B, T] -> [B, T, V] -> [B, V]
        word_counts = F.one_hot(x, self.vocab_size).sum(1)
        word_counts[:, self.pad_idx] = 0
        indicators = (word_counts > 0).float()
        return indicators

    def forward(self, x, z, state=dict()) -> Bernoulli:
        r"""
        Return Bernoulli distributions

            [v \in X]|z, \Sigma_{<v} ~ Bernoulli(b_v(z, \Sigma_{<v}))

        with shape [B, Vx] where Vx = |\Sigma| and \Sigma is the vocabulary.
        """
        # [B, V] indicator view of x
        indicators = self.make_indicators(x)
        if self.resample_every > 0:
            self.counter = self.counter - 1 if self.counter > 0 else self.resample_every
        return self.product_of_bernoullis(
            z, history=indicators,
            resample_mask=self.resample_every > 0 and self.counter == 0)

    def log_prob(self, likelihood: Bernoulli, x):
        # [B, V]
        indicators = self.make_indicators(x)
        # [B, V] -> [B]
        return likelihood.log_prob(indicators).sum(-1)

    def sample(self, z, max_len=None, greedy=False, state=dict()):
        """Sample from X|z where z is [B, Dz]."""
        if greedy:
            raise NotImplementedError("Greedy decoding not implemented for MADE")
        shape = [z.size(0), self.product_of_bernoullis.event_size]
        x = self.product_of_bernoullis.sample(
            z, torch.zeros(shape, dtype=z.dtype, device=z.device))
        return x
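# A minimal, hedged sketch of the Bernoulli variant (again with illustrative
# shapes, not part of the original module). Unlike the Poisson model, which
# scores the per-type counts c_v(x), CorrelatedBernoullisLM scores only the
# presence/absence bit vector make_indicators(x), so repeated occurrences of
# a word contribute no additional likelihood.
def _demo_correlated_bernoullis():
    batch_size, seq_len, vocab_size, latent_size = 4, 7, 100, 16
    lm = CorrelatedBernoullisLM(vocab_size=vocab_size,
                                latent_size=latent_size,
                                hidden_sizes=[64, 64],
                                pad_idx=0)
    x = torch.randint(1, vocab_size, (batch_size, seq_len))  # token ids [B, T]
    z = torch.randn(batch_size, latent_size)                 # latent codes [B, Dz]
    indicators = lm.make_indicators(x)  # [B, V] bits: 1 iff v occurs in x
    likelihood = lm(x, z)               # Bernoulli with batch shape [B, V]
    ll = lm.log_prob(likelihood, x)     # [B]; sums V Bernoulli log-probs
    return indicators, ll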