Example #1
    def perplexity(self, x, use_c_prior=False, max_ppl=100):
        """Calculating ppl score for input sequence x

        :param x: (n_batch, n_len) of longs, input sentence x
        :return: n_batch of floats, ppl scores
        """

        # Input check
        assert_check(x, [-1, self.n_len], torch.long)

        # Eval mode
        self.eval()

        # Encoder: x -> z, kl_loss
        z, _ = self.forward_encoder(x)

        # Code: x -> c
        if use_c_prior:
            c = self.sample_c_prior(x.size(0))
        else:
            c = F.softmax(self.forward_discriminator(x), dim=1)

        # Decoder
        h_init = torch.cat([z.unsqueeze(0), c.unsqueeze(0)],
                           2)  # (1, n_batch, d_z + d_c)
        x_emb = self.x_emb(x.t())  # (n_len, n_batch, d_emb)
        x_emb = torch.cat([x_emb, h_init.expand(x_emb.shape[0], -1, -1)],
                          2)  # (n_len, n_batch, d_emb + d_z + d_c)
        outputs, _ = self.decoder_rnn(
            x_emb, h_init.repeat(self.decoder_n_layers, 1,
                                 1))  # (n_len, n_batch, d_z + d_c)
        if self.attention:
            outputs = self.decoder_a(outputs)
        n_len, n_batch, _ = outputs.shape  # (n_len, n_batch)
        y = self.decoder_fc(outputs.view(n_len * n_batch, -1)).view(
            n_len, n_batch, -1)  # (n_len, n_batch, n_vocab)
        y = F.softmax(y, dim=2)

        # Calc ppl
        y = y[:-1]  # (n_len - 1, n_batch, n_vocab)
        rx = x.t()[1:].unsqueeze(2)  # (n_len - 1, n_batch, 1)
        ppl = y.gather(2, rx).squeeze(2)
        ppl = ppl.t()
        scores = []
        for i, xl in enumerate(x):
            ppl[i, xl[1:] == self.unk] = 1  # ignore <unk> positions
            eos_inds = (xl == self.eos).nonzero()
            m = self.n_len - 1 if not len(eos_inds) else eos_inds.max().item()
            scores.append(ppl[i, :m].log().sum().exp()**(-1.0 / (m + 1)))
        ppl = torch.tensor(scores, device=x.device)
        ppl[ppl == float('inf')] = max_ppl  # clip infinite perplexities

        # Train mode
        self.train()

        # Output check
        assert_check(ppl, [x.size(0)], torch.float, x.device)

        return ppl
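The loop above computes the usual perplexity, (prod_t p_t) ** (-1 / T), in log space from the probabilities the decoder assigns to the observed tokens (here normalized by m + 1). A minimal standalone sketch of that computation, with made-up per-token probabilities:

import torch

def sequence_perplexity(token_probs):
    """Perplexity of one sequence from its per-token probabilities.

    token_probs: (T,) probabilities assigned to the observed tokens.
    Returns (prod_t p_t) ** (-1 / T), computed in log space for stability.
    """
    return torch.exp(-token_probs.log().mean())

probs = torch.tensor([0.2, 0.5, 0.1, 0.4])  # hypothetical per-token probs
print(sequence_perplexity(probs))           # scalar tensor, the ppl score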
Example #2
    def sample_soft_embed(self, n_batch=1, z=None, c=None, temp=1.0):
        """Generating single soft sample x
        TODO: Not working right now

        :param z: (n_batch, d_z) of floats, latent vector z / None
        :param c: (n_batch, d_c) of floats, code c / None
        :param temp: temperature of softmax
        :param device: device to run
        :return: (n_len, d_emb) of floats, sampled soft x
        """

        # Input check
        assert isinstance(n_batch, int) and n_batch > 0
        if z is not None:
            assert_check(z, [n_batch, self.d_z], torch.float)
        if c is not None:
            assert_check(c, [n_batch, self.d_c], torch.float)
        assert isinstance(temp, float) and 0 < temp <= 1

        # Initial values
        device = self.x_emb.weight.device
        if z is None:
            z = self.sample_z_prior(n_batch)  # (n_batch, d_z)
        if c is None:
            c = self.sample_c_prior(n_batch)  # (n_batch, d_c)
        z, c = z.to(device), c.to(device)  # device change
        z1, c1 = z.unsqueeze(0), c.unsqueeze(0)  # (1, n_batch, d_z / d_c)
        h = torch.cat([z1, c1],
                      dim=2).repeat(self.decoder_n_layers, 1,
                                    1)  # (n_layers, n_batch, d_z + d_c)
        emb = self.x_emb(
            torch.tensor(self.bos,
                         device=device).expand(n_batch))  # (n_batch, d_emb)

        # Cycle, word by word
        outputs = [emb]
        for _ in range(1, self.n_len):
            # Init
            x_emb = emb.unsqueeze(0)  # (1, n_batch, d_emb)
            x_emb = torch.cat([x_emb, z1, c1],
                              2)  # (1, n_batch, d_emb + d_z + d_c)

            # Step
            o, h = self.decoder_rnn(x_emb, h)
            y = self.decoder_fc(o).squeeze(0)
            y = F.softmax(y / temp, dim=1)

            emb = y @ self.x_emb.weight  # (n_batch, d_emb)
            outputs.append(emb)

        # Making x
        x = torch.stack(outputs, dim=1)  # (n_batch, n_len, d_emb)

        # Output check
        assert_check(z, [n_batch, self.d_z], torch.float, device)
        assert_check(c, [n_batch, self.d_c], torch.float, device)
        assert_check(x, [n_batch, self.n_len, self.d_emb], torch.float, device)

        return z, c, x
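The essential step is `emb = y @ self.x_emb.weight`: instead of sampling a discrete token, the next input is the probability-weighted mixture of all embedding rows, which keeps generation differentiable. A small self-contained sketch of that expected-embedding step (vocabulary size and dimensions are made-up values):

import torch
import torch.nn as nn
import torch.nn.functional as F

n_vocab, d_emb = 10, 4              # hypothetical sizes
emb_table = nn.Embedding(n_vocab, d_emb)

logits = torch.randn(1, n_vocab)    # decoder output for one step
y = F.softmax(logits / 1.0, dim=1)  # (1, n_vocab), temperature 1.0

# "Soft" token: expectation of the embedding under y, stays differentiable
soft_emb = y @ emb_table.weight     # (1, d_emb)

# "Hard" alternative for comparison; argmax breaks the gradient path
hard_emb = emb_table(y.argmax(dim=1))  # (1, d_emb)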
Example #3
    def sample_z_prior(self, n_batch):
        """Sampling z ~ p(z) = N(0, I)

        :param n_batch: number of samples to draw (batch size)
        :return: (n_batch, d_z) of floats, sample of latent z
        """

        # Input check
        assert isinstance(n_batch, int) and n_batch > 0

        # Sampling
        device = self.x_emb.weight.device
        z = torch.randn((n_batch, self.d_z), device=device)  # (n_batch, d_z)

        # Output check
        assert_check(z, [n_batch, self.d_z], torch.float, device)

        return z
Example #4
    def forward(self, x, use_c_prior=False):
        """Do the VAE forward step with prior c

        :param x: (n_batch, n_len) of longs, input sentence x
        :return: float, kl term component of loss
        :return: float, recon component of loss
        """

        # Input check
        assert_check(x, [-1, self.n_len], torch.long)

        # Encoder: x -> z, kl_loss
        z, kl_loss = self.forward_encoder(x)

        # Code: x -> c
        if use_c_prior:
            c = self.sample_c_prior(x.size(0))
        else:
            c = F.softmax(self.forward_discriminator(x), dim=1)

        # Decoder: x, z, c -> recon_loss
        recon_loss = self.forward_decoder(x, z, c)

        # Output check
        assert_check(kl_loss, [], torch.float, x.device)
        assert_check(recon_loss, [], torch.float, kl_loss.device)

        return kl_loss, recon_loss
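`forward` only returns the two loss components; combining them is left to the training loop. A minimal sketch of one training step, assuming `model` is an instance of this class and `beta` is a weight on the KL term (both are assumptions, not part of the original code):

import torch

def vae_training_step(model, optimizer, batch, beta=1.0):
    """One optimization step on recon_loss + beta * kl_loss.

    `model`, `optimizer`, `batch` and `beta` are assumptions for
    illustration; `batch` is (n_batch, n_len) of longs.
    """
    kl_loss, recon_loss = model(batch)
    loss = recon_loss + beta * kl_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return kl_loss.item(), recon_loss.item()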
Example #5
    def forward_discriminator(self, x, do_emb=True):
        """Discriminator step, emulating weights for c ~ D(x)

        :param x: (n_batch, n_len) of longs or (n_batch, n_len, d_emb) of
        floats, input sentence x
        :param do_emb: whether to embed x first (True) or treat x as
        embeddings (False)
        :return: (n_batch, d_c) of floats, sample of code c
        """

        # Input check
        if do_emb:
            assert_check(x, [-1, self.n_len], torch.long)
        else:
            assert_check(x, [-1, self.n_len, self.d_emb], torch.float)

        # Emb (n_batch, n_len, d_emb)
        if do_emb:
            x_emb = self.x_emb(x)
        else:
            x_emb = x

        # CNN
        c = self.disc_cnn(x_emb)

        # Output check
        assert_check(c, [x.size(0), self.d_c], torch.float, x.device)

        return c
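The actual `disc_cnn` module is defined elsewhere; below is only a hypothetical, shape-compatible sketch of what a text-CNN discriminator over embeddings might look like. Layer sizes and layout are assumptions, not the original implementation:

import torch
import torch.nn as nn

class TextCNNDiscriminator(nn.Module):
    """Hypothetical sketch of a disc_cnn-style module (not the original)."""

    def __init__(self, d_emb=64, d_c=2, n_filters=32, kernel_sizes=(3, 4, 5)):
        super().__init__()
        self.convs = nn.ModuleList(
            nn.Conv1d(d_emb, n_filters, k) for k in kernel_sizes)
        self.fc = nn.Linear(n_filters * len(kernel_sizes), d_c)

    def forward(self, x_emb):                     # (n_batch, n_len, d_emb)
        h = x_emb.transpose(1, 2)                 # (n_batch, d_emb, n_len)
        # Max-pool each filter bank over time, then map to d_c logits
        pooled = [conv(h).relu().max(dim=2).values for conv in self.convs]
        return self.fc(torch.cat(pooled, dim=1))  # (n_batch, d_c)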
Example #6
    def forward_decoder(self, x, z, c):
        """Decoder step, emulating x ~ G(z, c)

        :param x: (n_batch, n_len) of longs, input sentence x
        :param z: (n_batch, d_z) of floats, latent vector z
        :param c: (n_batch, d_c) of floats, code c
        :return: float, recon component of loss
        """

        # Input check
        assert_check(x, [-1, self.n_len], torch.long)
        assert_check(z, [-1, self.d_z], torch.float, x.device)
        assert_check(c, [-1, self.d_c], torch.float, z.device)

        # Init
        h_init = torch.cat([z, c], 1).unsqueeze(0)  # (1, n_batch, d_z + d_c)

        # Inputs
        x_drop = self.word_dropout(x)  # (n_batch, n_len)
        x_emb = self.x_emb(x_drop.t())  # (n_len, n_batch, d_emb)
        x_emb = torch.cat([x_emb, h_init.expand(x_emb.shape[0], -1, -1)],
                          2)  # (n_len, n_batch, d_emb + d_z + d_c)

        # Rnn step
        outputs, _ = self.decoder_rnn(
            x_emb, h_init.repeat(self.decoder_n_layers, 1,
                                 1))  # (n_len, n_batch, d_z + d_c)

        # Attention
        if self.attention:
            outputs = self.decoder_a(outputs)

        # FC to vocab
        n_len, n_batch, _ = outputs.shape  # (n_len, n_batch)
        y = self.decoder_fc(outputs.view(n_len * n_batch, -1)).view(
            n_len, n_batch, -1)  # (n_len, n_batch, n_vocab)

        # Loss
        recon_loss = F.cross_entropy(y[:-1].view(-1, y.size(2)),
                                     x.t()[1:].contiguous().view(-1))
        # recon_loss = F.cross_entropy(
        #     y.view(-1, y.size(2)),
        #     F.pad(x.t()[1:], (0, 0, 0, 1), 'constant', self.pad).view(-1)
        # )  # 0

        # Output check
        assert_check(recon_loss, [], torch.float, x.device)

        return recon_loss
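The reconstruction loss drops the last prediction (`y[:-1]`) and the first target token (`x.t()[1:]`) so that the output at position t is scored against the token at position t + 1. A small self-contained sketch of that alignment (sizes are made up):

import torch
import torch.nn.functional as F

n_len, n_batch, n_vocab = 5, 2, 7                # hypothetical sizes
y = torch.randn(n_len, n_batch, n_vocab)         # decoder logits, time-major
x = torch.randint(0, n_vocab, (n_batch, n_len))  # input token ids, batch-major

# Output at position t predicts the token at position t + 1
logits = y[:-1].reshape(-1, n_vocab)         # ((n_len - 1) * n_batch, n_vocab)
targets = x.t()[1:].contiguous().view(-1)    # ((n_len - 1) * n_batch,)
recon_loss = F.cross_entropy(logits, targets)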
Example #7
    def sample_c_prior(self, n_batch):
        """Sampling prior, emulating c ~ P(c)

        :param n_batch: number of samples to draw (batch size)
        :return: (n_batch, d_c) of floats, sample of code c
        """

        # Input check
        assert isinstance(n_batch, int) and n_batch > 0

        # Sampling
        device = self.x_emb.weight.device
        inds = torch.multinomial(
            torch.ones(self.d_c, dtype=torch.float, device=device) / self.d_c,
            n_batch,
            replacement=True)
        ones = torch.eye(self.d_c, device=device)
        c = ones.index_select(0, inds)

        # Output check
        assert_check(c, [n_batch, self.d_c], torch.float, device)

        return c
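The `multinomial` + `index_select` pair above is just a uniform draw of one-hot code vectors. An equivalent minimal sketch (d_c and batch size are example values):

import torch

d_c, n_batch = 4, 3                       # hypothetical sizes
inds = torch.randint(0, d_c, (n_batch,))  # uniform categorical over d_c codes
c = torch.eye(d_c)[inds]                  # (n_batch, d_c) one-hot rows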
Example #8
    def word_dropout(self, x):
        """
        Do word dropout: with probability `self.p_word_dropout`, replace a
        word with `self.unk`, as proposed in the original Bowman et al. (2016)
        paper.

        :param x: (n_batch, n_len) of longs, input sentence x
        :return: (n_batch, n_len) of longs, x with drops
        """

        # Input check
        assert_check(x, [-1, self.n_len], torch.long)

        # Apply dropout mask
        mask = x.new_tensor(np.random.binomial(n=1,
                                               p=self.p_word_dropout,
                                               size=tuple(x.shape)),
                            dtype=torch.bool)
        x_drop = x.clone()
        x_drop[mask] = self.unk

        # Output check
        assert_check(x_drop, [x.size(0), self.n_len], torch.long, x.device)

        return x_drop
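A self-contained sketch of the same word-dropout idea using a uniform random mask instead of NumPy (the UNK id and dropout probability are example values):

import torch

unk, p_word_dropout = 1, 0.3      # hypothetical values
x = torch.randint(2, 10, (2, 6))  # (n_batch, n_len) token ids

# Independently replace each token with UNK with probability p_word_dropout
mask = torch.rand(x.shape) < p_word_dropout
x_drop = x.clone()
x_drop[mask] = unk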
Example #9
    def forward_encoder(self, x, do_emb=True):
        """Encoder step, emulating z ~ E(x) = q_E(z|x)

        :param x: (n_batch, n_len) of longs or (n_batch, n_len, d_emb) of
        floats, input sentence x
        :param do_emb: whether to embed x first (True) or treat x as
        embeddings (False)
        :return: (n_batch, d_z) of floats, sample of latent vector z
        :return: float, kl term component of loss
        """

        # Input check
        if do_emb:
            assert_check(x, [-1, self.n_len], torch.long)
        else:
            assert_check(x, [-1, self.n_len, self.d_emb], torch.float)

        # Emb (n_batch, n_len, d_emb)
        if do_emb:
            x_emb = self.x_emb(x)
        else:
            x_emb = x

        # RNN
        x_emb = F.dropout(x_emb, training=self.training)
        h, _ = self.encoder_rnn(x_emb.transpose(0, 1),
                                None)  # (n_len, n_batch, d_h)

        # Forward to latent
        h = h[-1]  # (n_batch, d_h)
        mu, logvar = self.q_mu(h), self.q_logvar(h)  # (n_batch, d_z)

        # Reparameterization trick: z = mu + std * eps; eps ~ N(0, I)
        eps = torch.randn_like(mu)
        z = mu + torch.exp(logvar / 2) * eps

        # KL term loss
        kl_loss = 0.5 * (logvar.exp() + mu**2 - 1 - logvar).sum(1).mean()  # 0

        # Output check
        assert_check(z, [x.size(0), self.d_z], torch.float, x.device)
        assert_check(kl_loss, [], torch.float, z.device)

        return z, kl_loss
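The last few lines are the reparameterization trick plus the closed-form KL divergence between N(mu, sigma^2) and N(0, I). A standalone sketch of just that part (mu and logvar are made-up values):

import torch

mu = torch.randn(3, 8)      # (n_batch, d_z), hypothetical values
logvar = torch.randn(3, 8)  # log sigma^2

# Reparameterization: z = mu + sigma * eps, with eps ~ N(0, I)
eps = torch.randn_like(mu)
z = mu + torch.exp(logvar / 2) * eps

# KL(N(mu, sigma^2) || N(0, I)) = 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2)
kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()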
Example #10
    def sample_sentence(self,
                        n_batch=1,
                        z=None,
                        c=None,
                        temp=1.0,
                        pad=True,
                        n_beam=5,
                        coverage_penalty=True):
        """Generating n_batch sentences in eval mode with values (could be
        not on same device)

        :param n_batch: number of sentences to generate
        :param z: (n_batch, d_z) of floats, latent vector z or None
        :param c: (n_batch, d_c) of floats, code c or None
        :param temp: temperature of softmax
        :param pad: if do padding to n_len
        :param n_beam: size of beam search
        :param coverage_penalty:
        :return: tuple of four:
            1. (n_batch, d_z) of floats, latent vector z
            2. (n_batch, d_c) of floats, code c
            3. (n_batch, n_len) of longs if pad and list of n_batch len of
            tensors of longs: generated sents word ids if not
            4. (n_batch, n_len, n_len) of floats, attention probs
        """

        # Input check
        assert isinstance(n_batch, int) and n_batch > 0
        if z is not None:
            assert_check(z, [n_batch, self.d_z], torch.float)
        if c is not None:
            assert_check(c, [n_batch, self.d_c], torch.float)
        assert isinstance(temp, float)
        assert isinstance(n_beam, int) and n_beam > 0
        assert isinstance(coverage_penalty, bool)

        # Enable eval mode
        self.eval()

        # Params
        device = self.x_emb.weight.device
        n = n_batch * n_beam

        # `z` and `c`, and then `h0`
        if z is None:
            z = self.sample_z_prior(n_batch)  # (n_batch, d_z)
        if c is None:
            c = self.sample_c_prior(n_batch)  # (n_batch, d_c)
        z, c = z.to(device), c.to(device)  # device change
        h0 = torch.cat([z.unsqueeze(0), c.unsqueeze(0)],
                       dim=2)  # (1, n_batch, d_z + d_c)
        h0 = h0.unsqueeze(2).repeat(1, 1, n_beam, 1).view(
            1, n, -1)  # (1, n, d_z + d_c)

        # Initial values
        w = torch.tensor(self.bos, device=device).repeat(n)  # n
        h = h0.repeat(self.decoder_n_layers, 1, 1)  # (n_layers, n, d_z + d_c)
        # Previous context vectors
        context = torch.empty(n,
                              self.n_len,
                              self.d_z + self.d_c,
                              device=device)
        # Attention matrix
        a = torch.zeros(n, self.n_len, self.n_len, device=device)
        # Candidates score for beam search
        H = torch.zeros(n, device=device)
        # X values
        x = torch.tensor(self.pad, device=device).repeat(n, self.n_len)
        x[:, 0] = self.bos
        eos_mask = torch.zeros(n, dtype=torch.bool, device=device)
        end_pads = torch.tensor(self.n_len, device=device).repeat(n)

        # Cycle, word by word
        for i in range(1, self.n_len):
            # Init
            x_emb = self.x_emb(w)  # (n, d_emb)
            x_emb = x_emb.unsqueeze(0)  # (1, n, d_emb)
            x_emb = torch.cat([x_emb, h0], 2)  # (1, n, d_emb + d_z + d_c)

            # Step
            o, h = self.decoder_rnn(x_emb, h)  # o: (1, n, d_z + d_c)
            o = o.squeeze(0)
            if self.attention:
                context[:, i - 1, :] = o
                # o: (n, d_z + d_c), aw: (n, i)
                o, aw = self.decoder_a.forward_inference(
                    o, None if i == 1 else context[:, :i - 1, :])
                a[~eos_mask, i, :i] = aw[~eos_mask]
            y = F.softmax(self.decoder_fc(o) / temp, dim=-1)  # (n, n_vocab)

            # Generating
            nw = torch.multinomial(y, n_beam)  # (n, n_beam)
            pc = y.gather(1, nw)  # (n, n_beam)
            # (n_batch, n_beam, n_beam)
            pc = pc.view(n_batch, n_beam, -1).log()
            pc = H.view(n_batch, -1, 1) + pc  # (n_batch, n_beam, n_beam)
            aH, u = pc.view(n_batch, -1).topk(n_beam, 1)  # (n_batch, n_beam)
            w = nw.view(n_batch, -1).gather(1, u).view(-1)  # n

            # Masking new candidates
            parents = u // n_beam  # (n_batch, n_beam), parent beam indices
            base_mask = torch.arange(n_batch, dtype=torch.long, device=device)
            base_mask *= n_beam
            base_mask = base_mask.unsqueeze(1).repeat(1, n_beam).view(-1)
            mask = base_mask + parents.view(-1)
            h = h[:, mask, :]
            context = context[mask]
            a = a[mask]
            H = H[mask]
            H[~eos_mask] += aH.view(-1)[~eos_mask]
            x = x[mask]
            eos_mask = eos_mask[mask]
            end_pads = end_pads[mask]
            x[~eos_mask, i] = w[~eos_mask]

            # Eos masks
            i_eos_mask = (w == self.eos)
            end_pads[i_eos_mask] = i
            eos_mask |= i_eos_mask

        # Choosing best candidate
        if coverage_penalty:
            # GNMT-style coverage penalty: sum_j log(min(sum_i a_ij, 1))
            H += a.sum(1).clamp(min=1e-8, max=1.0).log().sum(1)
        u = H.view(n_batch, -1).argmax(1).unsqueeze(-1).unsqueeze(-1)
        ux = u.repeat(1, 1, self.n_len)
        x = x.view(n_batch, -1, self.n_len)
        x = x.gather(1, ux).squeeze(1)
        ua = u.unsqueeze(-1).repeat(1, 1, self.n_len, self.n_len)
        a = a.view(n_batch, -1, self.n_len, self.n_len)
        a = a.gather(1, ua).squeeze(1)

        # Pad
        if not pad:
            # Keep the end positions of the chosen best candidates
            end_pads = end_pads.view(n_batch, -1).gather(
                1, u.view(n_batch, 1)).squeeze(1)
            new_x = []
            for i in range(x.size(0)):
                new_x.append(x[i, :end_pads[i]])
            x = new_x

        # Back to train
        self.train()

        # Output check
        assert_check(z, [n_batch, self.d_z], torch.float, device)
        assert_check(c, [n_batch, self.d_c], torch.float, device)
        if pad:
            assert_check(x, [n_batch, self.n_len], torch.long, device)
        else:
            assert len(x) == n_batch
            for i_x in x:
                assert_check(i_x, [-1], torch.long, device)
                assert len(i_x) <= self.n_len
        assert_check(a, [n_batch, self.n_len, self.n_len], torch.float, device)

        return z, c, x, a
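A hedged usage sketch: sampling sentences from the prior and turning the returned ids back into tokens. The `model` instance and the `id_to_token` mapping are assumptions for illustration, not part of this code:

import torch

def generate(model, id_to_token, n_batch=4, temp=0.8):
    """Sample n_batch sentences and decode word ids to strings.

    `model` is assumed to be an instance of the class above; `id_to_token`
    is an assumed id -> string mapping.
    """
    z, c, x, a = model.sample_sentence(n_batch=n_batch, temp=temp, pad=False)
    return [' '.join(id_to_token[int(i)] for i in sent) for sent in x]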