Example #1

import torch
import torch.nn.functional as F
from torch.distributions import Poisson

# GenerativeLM, MADEConditioner and AutoregressiveLikelihood are project-local
# modules; their import paths are assumed from the surrounding codebase.
class CorrelatedPoissonsLM(GenerativeLM):
    """
    This parameterises an autoregressive product of Poisson distributions,
        P(x|z) = \prod_{v=1}^V Poisson(c_v(x)|b_v(z, x))
    where V is the vocabulary size, c_v(x) counts the occurrences of v in x,
    and b(z, x) \in (0, \infty)^V is autoregressive in x (we use a MADE).
    """
    def __init__(self,
                 vocab_size,
                 latent_size,
                 hidden_sizes,
                 pad_idx,
                 num_masks=1,
                 resample_mask_every=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.resample_every = resample_mask_every
        self.counter = resample_mask_every
        self.vocab_size = vocab_size
        self.made_conditioner = MADEConditioner(
            input_size=vocab_size + latent_size,
            output_size=vocab_size,
            context_size=latent_size,
            hidden_sizes=hidden_sizes,
            num_masks=num_masks)
        self.product_of_poissons = AutoregressiveLikelihood(
            event_size=vocab_size,
            dist_type=Poisson,
            conditioner=self.made_conditioner)

    def make_counts(self, x):
        """Return a vocab_size-dimensional count-vector view of x."""
        # Convert ids to V-dimensional one-hot vectors and sum over the time
        # dimension to obtain word counts: [B, T] -> [B, T, V] -> [B, V]
        word_counts = F.one_hot(x, self.vocab_size).sum(1)
        # Zero out the padding symbol; keeping its count would amount to
        # (crudely) modelling sequence length.
        word_counts[:, self.pad_idx] = 0
        return word_counts.float()

    def forward(self, x, z, state=dict()) -> Poisson:
        """
        Return Poisson distributions
            c_v(x)|z, c_{<v}(x) ~ Poisson(b_v(z, c_{<v}(x)))
        with batch shape [B, V], where V = |\Sigma| is the vocabulary size and
        c_{<v}(x) denotes the counts of the words preceding v in the MADE ordering.
        """
        # [B, V] word counts
        counts = self.make_counts(x)
        if self.resample_every > 0:
            # Count down: the MADE masks are resampled on the call where the
            # counter reaches 0, after which it resets to resample_every.
            self.counter = self.counter - 1 if self.counter > 0 else self.resample_every
        return self.product_of_poissons(
            z,
            history=counts,
            resample_mask=self.resample_every > 0 and self.counter == 0)

    def log_prob(self, likelihood: Poisson, x):
        """Sum per-type Poisson log-probabilities to obtain log P(x|z)."""
        # [B, V]
        counts = self.make_counts(x)
        # [B, V] -> [B]
        return likelihood.log_prob(counts).sum(-1)

    def sample(self, z, max_len=None, greedy=False, state=dict()):
        """
        Sample from X|z, where z has shape [B, Dz].
        """
        if greedy:
            raise NotImplementedError("Greedy decoding not implemented for MADE")
        shape = [z.size(0), self.product_of_poissons.event_size]
        x = self.product_of_poissons.sample(
            z, torch.zeros(shape, dtype=z.dtype, device=z.device))
        return x
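
For intuition, the following minimal, self-contained sketch reproduces the count-vector likelihood that make_counts and log_prob compute together (hypothetical toy sizes V=8, pad_idx=0; the constant rates tensor stands in for the MADE output b(z, x)):

import torch
import torch.nn.functional as F
from torch.distributions import Poisson

V, pad_idx = 8, 0
x = torch.tensor([[5, 3, 5, 0, 0]])  # [B=1, T=5] token ids, 0-padded
counts = F.one_hot(x, V).sum(1)      # [1, 8] -> tensor([[2, 0, 0, 1, 0, 2, 0, 0]])
counts[:, pad_idx] = 0               # drop pad counts -> [[0, 0, 0, 1, 0, 2, 0, 0]]
rates = torch.full((1, V), 1.0)      # stand-in for b(z, x) in (0, infty)^V
log_px = Poisson(rates).log_prob(counts.float()).sum(-1)  # [1], log P(x|z)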
Example #2

import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli

# GenerativeLM, MADEConditioner and AutoregressiveLikelihood are assumed to be
# importable from the surrounding project, as in Example #1.
class CorrelatedBernoullisLM(GenerativeLM):
    """
    This parameterises an autoregressive product of Bernoulli distributions,
        P(x|z) = \prod_{v=1}^V Bern([v in x]|b_v(z, x))
    where V is the vocabulary size and b(z,x) \in (0, 1)^V is autoregressive in x (we use a MADE).
    """
    def __init__(self,
                 vocab_size,
                 latent_size,
                 hidden_sizes,
                 pad_idx,
                 num_masks=1,
                 resample_mask_every=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.resample_every = resample_mask_every
        self.counter = resample_mask_every
        self.vocab_size = vocab_size
        self.made_conditioner = MADEConditioner(
            input_size=vocab_size + latent_size,
            output_size=vocab_size,
            context_size=latent_size,
            hidden_sizes=hidden_sizes,
            num_masks=num_masks)
        self.product_of_bernoullis = AutoregressiveLikelihood(
            event_size=vocab_size,
            dist_type=Bernoulli,
            conditioner=self.made_conditioner)

    def make_indicators(self, x):
        """Return a vocab_size-dimensional bit-vector view of x."""
        # Convert ids to V-dimensional one-hot vectors and sum over the time
        # dimension to obtain word counts: [B, T] -> [B, T, V] -> [B, V]
        word_counts = F.one_hot(x, self.vocab_size).sum(1)
        word_counts[:, self.pad_idx] = 0
        # Threshold the counts into presence indicators [v \in x]
        indicators = (word_counts > 0).float()
        return indicators

    def forward(self, x, z, state=dict()) -> Bernoulli:
        """
        Return Bernoulli distributions
            [v \in x]|z, x_{<v} ~ Bernoulli(b_v(z, x_{<v}))
        with batch shape [B, V], where V = |\Sigma| is the vocabulary size and
        x_{<v} denotes the indicators of the words preceding v in the MADE ordering.
        """
        # [B, V] presence indicators
        indicators = self.make_indicators(x)
        if self.resample_every > 0:
            # Count down: the MADE masks are resampled on the call where the
            # counter reaches 0, after which it resets to resample_every.
            self.counter = self.counter - 1 if self.counter > 0 else self.resample_every
        return self.product_of_bernoullis(
            z,
            history=indicators,
            resample_mask=self.resample_every > 0 and self.counter == 0)

    def log_prob(self, likelihood: Bernoulli, x):
        """Sum per-type Bernoulli log-probabilities to obtain log P(x|z)."""
        # [B, V]
        indicators = self.make_indicators(x)
        # [B, V] -> [B]
        return likelihood.log_prob(indicators).sum(-1)

    def sample(self, z, max_len=None, greedy=False, state=dict()):
        """
        Sample from X|z, where z has shape [B, Dz].
        """
        if greedy:
            raise NotImplementedError("Greedy decoding not implemented for MADE")
        shape = [z.size(0), self.product_of_bernoullis.event_size]
        x = self.product_of_bernoullis.sample(
            z, torch.zeros(shape, dtype=z.dtype, device=z.device))
        return x
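
The Bernoulli variant works the same way, except the count vector is thresholded into presence indicators. A minimal sketch under the same hypothetical toy sizes (V=8, pad_idx=0; the constant probs tensor stands in for the MADE output b(z, x)):

import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli

V, pad_idx = 8, 0
x = torch.tensor([[5, 3, 5, 0, 0]])          # [B=1, T=5] token ids, 0-padded
counts = F.one_hot(x, V).sum(1)
counts[:, pad_idx] = 0
indicators = (counts > 0).float()            # [[0., 0., 0., 1., 0., 1., 0., 0.]]
probs = torch.full((1, V), 0.5)              # stand-in for b(z, x) in (0, 1)^V
log_px = Bernoulli(probs=probs).log_prob(indicators).sum(-1)  # [1], log P(x|z)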