Example #1
    def forward(self, inp, trg):
        """
        Parameters:
        -----------
        inp: torch.Tensor (seq_len x batch), Training data for a single batch.
        trg: torch.Tensor (seq_len x batch), Target output for a single batch.

        Returns:
        --------
        outs: torch.Tensor (seq_len x batch x hid_dim),
            the stacked decoder outputs, one per target step.
        """
        # encoder
        inp = word_dropout(
            inp, self.target_code, reserved_codes=self.reserved_codes,
            p=self.word_dropout, training=self.training)
        emb_inp = self.src_embeddings(inp)
        enc_outs, enc_hidden = self.encoder(emb_inp)
        # decoder
        dec_hidden = self.decoder.init_hidden_for(enc_hidden)
        dec_outs, dec_out, enc_att = [], None, None
        if self.decoder.att_type == 'Bahdanau':
            # cache encoder att projection for bahdanau
            enc_att = self.decoder.attn.project_enc_outs(enc_outs)
        for prev in trg.chunk(trg.size(0)):
            emb_prev = self.trg_embeddings(prev).squeeze(0)
            dec_out, dec_hidden, att_weight = self.decoder(
                emb_prev, dec_hidden, enc_outs, out=dec_out, enc_att=enc_att)
            dec_outs.append(dec_out)
        return torch.stack(dec_outs)
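Every forward pass in this listing starts by calling `word_dropout` before embedding the input. The snippets do not show that helper, so the following is a minimal, hypothetical sketch of the idea: during training, replace each non-reserved token with the unknown code (`target_code`) with probability `p`; at evaluation time the input is returned unchanged.

    import torch

    def word_dropout_sketch(inp, target_code, p=0.0, reserved_codes=(), training=True):
        """Hypothetical re-implementation for illustration; not seqmod's actual helper."""
        if not training or p == 0.0:
            return inp
        # positions eligible for dropout: everything except reserved symbols (pad, bos, ...)
        eligible = torch.ones_like(inp, dtype=torch.bool)
        for code in reserved_codes:
            eligible &= inp != code
        # drop each eligible token independently with probability p
        drop = (torch.rand_like(inp, dtype=torch.float) < p) & eligible
        return inp.masked_fill(drop, target_code)

    # usage: map roughly 20% of non-pad tokens to a hypothetical <unk> code 1
    inp = torch.randint(2, 100, (7, 4))                 # (seq_len x batch)
    noisy = word_dropout_sketch(inp, target_code=1, p=0.2, reserved_codes=(0,))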
Example #2
    def forward(self, inp, trg):
        """
        Parameters:
        -----------
        inp: torch.Tensor (seq_len x batch), Training data for a single batch.
        trg: torch.Tensor (seq_len x batch), Target output for a single batch.

        Returns:
        --------
        outs: torch.Tensor (seq_len x batch x hid_dim),
            the stacked decoder outputs, one per target step.
        """
        # encoder
        inp = word_dropout(inp,
                           self.target_code,
                           reserved_codes=self.reserved_codes,
                           p=self.word_dropout,
                           training=self.training)
        emb_inp = self.src_embeddings(inp)
        enc_outs, enc_hidden = self.encoder(emb_inp)
        # decoder
        dec_hidden = self.decoder.init_hidden_for(enc_hidden)
        dec_outs, dec_out, enc_att = [], None, None
        if self.decoder.att_type == 'Bahdanau':
            # cache encoder att projection for bahdanau
            enc_att = self.decoder.attn.project_enc_outs(enc_outs)
        for prev in trg.chunk(trg.size(0)):
            emb_prev = self.trg_embeddings(prev).squeeze(0)
            dec_out, dec_hidden, att_weight = self.decoder(emb_prev,
                                                           dec_hidden,
                                                           enc_outs,
                                                           out=dec_out,
                                                           enc_att=enc_att)
            dec_outs.append(dec_out)
        return torch.stack(dec_outs)
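When `att_type == 'Bahdanau'`, the encoder-side projection is computed once per batch with `project_enc_outs` and then reused at every decoding step. The module below is a self-contained sketch of that caching pattern; the class name, `att_dim`, and the scoring details are assumptions for illustration and may differ from seqmod's actual attention implementation.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class BahdanauAttentionSketch(nn.Module):
        def __init__(self, hid_dim, att_dim):
            super().__init__()
            self.enc_proj = nn.Linear(hid_dim, att_dim, bias=False)
            self.dec_proj = nn.Linear(hid_dim, att_dim, bias=False)
            self.scorer = nn.Linear(att_dim, 1, bias=False)

        def project_enc_outs(self, enc_outs):
            # (seq_len x batch x hid_dim) -> (seq_len x batch x att_dim), done once per batch
            return self.enc_proj(enc_outs)

        def forward(self, dec_out, enc_outs, enc_att):
            # dec_out: (batch x hid_dim); enc_att: cached projection of enc_outs
            dec_att = self.dec_proj(dec_out).unsqueeze(0)        # (1 x batch x att_dim)
            scores = self.scorer(torch.tanh(enc_att + dec_att))  # (seq_len x batch x 1)
            weights = F.softmax(scores, dim=0)                   # normalize over source positions
            context = (weights * enc_outs).sum(0)                # (batch x hid_dim)
            return context, weights.squeeze(2).t()               # weights: (batch x seq_len)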
Example #3
File: lm.py Project: burtenshaw/seqmod
    def forward(self, inp, hidden=None, schedule=None, **kwargs):
        """
        Parameters:
        -----------
        inp: torch.Tensor (seq_len x batch_size)

        Returns:
        --------
        outs: torch.Tensor (seq_len * batch_size x vocab)
        hidden: see output of RNN, GRU, LSTM in torch.nn
        weights: None or list of weights (batch_size x seq_len),
            non-None only when the model uses attention.
        """
        inp = word_dropout(
            inp, self.target_code, p=self.word_dropout,
            reserved_codes=self.reserved_codes, training=self.training)
        emb = self.embeddings(inp)
        if self.has_dropout:
            emb = F.dropout(emb, p=self.dropout, training=self.training)
        hidden = hidden if hidden is not None else self.init_hidden_for(emb)
        outs, hidden = self.rnn(emb, hidden)
        if self.has_dropout:
            outs = F.dropout(outs, p=self.dropout, training=self.training)
        weights = None
        if self.add_attn:
            outs, weights = self.attn(outs, emb)
        seq_len, batch, hid_dim = outs.size()
        outs = outs.view(seq_len * batch, hid_dim)
        if self.add_deepout:
            outs = self.deepout(outs)
        outs = F.log_softmax(self.project(outs), dim=-1)
        return outs, hidden, weights
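`init_hidden_for(emb)` supplies a zero initial state when `hidden` is not passed in. Below is a hedged sketch of what such a helper typically does; `num_layers`, `hid_dim`, and `cell` are assumed attributes mirroring the usage above, not seqmod's exact signature.

    import torch

    def init_hidden_for_sketch(emb, num_layers, hid_dim, cell='LSTM'):
        """Zero initial state sized after the incoming batch (hypothetical helper)."""
        batch = emb.size(1)                                        # emb: (seq_len x batch x emb_dim)
        h0 = emb.new_zeros(num_layers, batch, hid_dim)
        if cell == 'LSTM':
            return h0, emb.new_zeros(num_layers, batch, hid_dim)   # LSTM expects (h0, c0)
        return h0                                                  # RNN / GRU take a single tensor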
Example #4
File: vae.py Project: fbkarsdorp/seqmod
    def forward(self, src, trg, labels=None):
        """
        Parameters:
        -----------

        src: (seq_len x batch). Input batch of sentences to be encoded.
            It is assumed that src has <bos> and <eos> symbols.
        trg: (seq_len x batch). Target sentences fed to the decoder with
            teacher forcing; word dropout is applied to them.
        labels: None or (batch x num_labels). To be used by conditional VAEs.

        Returns:
        --------
        preds: (batch x vocab_size * seq_len)
        mu: (batch x z_dim)
        logvar: (batch x z_dim)
        """
        # - encoder
        emb = self.embeddings(src)
        mu, logvar = self.encoder(emb)
        z = self.encoder.reparametrize(mu, logvar)
        # - decoder
        hidden = self.decoder.init_hidden_for(z)
        dec_outs, z_cond = [], z if self.add_z else None
        # apply word dropout on the conditioning targets
        trg = word_dropout(trg,
                           self.target_code,
                           p=self.word_dropout,
                           reserved_codes=self.reserved_codes,
                           training=self.training)
        for emb_t in self.embeddings(trg).chunk(trg.size(0)):
            # rnn
            dec_out, hidden = self.decoder(emb_t.squeeze(0), hidden, z=z_cond)
            dec_outs.append(dec_out)
        dec_outs = torch.stack(dec_outs)
        return self.project(dec_outs), mu, logvar
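`self.encoder.reparametrize(mu, logvar)` is the standard VAE reparameterization trick, which keeps sampling differentiable with respect to the encoder outputs. A minimal stand-alone sketch (not necessarily seqmod's exact code):

    import torch

    def reparametrize_sketch(mu, logvar):
        """z = mu + eps * sigma with eps ~ N(0, I); gradients flow through mu and logvar."""
        std = (0.5 * logvar).exp()     # logvar = log(sigma^2), so sigma = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std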
Example #5
File: vae.py Project: mikekestemont/seqmod
    def forward(self, src, trg, labels=None):
        """
        Parameters:
        -----------

        src: (seq_len x batch). Input batch of sentences to be encoded.
            It is assumed that src has <bos> and <eos> symbols.
        trg: (seq_len x batch). Target sentences fed to the decoder with
            teacher forcing; word dropout is applied to them.
        labels: None or (batch x num_labels). To be used by conditional VAEs.

        Returns:
        --------
        preds: (batch x vocab_size * seq_len)
        mu: (batch x z_dim)
        logvar: (batch x z_dim)
        """
        # - encoder
        emb = self.embeddings(src)
        mu, logvar = self.encoder(emb)
        z = self.encoder.reparametrize(mu, logvar)
        # - decoder
        hidden = self.decoder.init_hidden_for(z)
        dec_outs, z_cond = [], z if self.add_z else None
        # apply word dropout on the conditioning targets
        trg = word_dropout(
            trg, self.target_code, p=self.word_dropout,
            reserved_codes=self.reserved_codes, training=self.training)
        for emb_t in self.embeddings(trg).chunk(trg.size(0)):
            # rnn
            dec_out, hidden = self.decoder(emb_t.squeeze(0), hidden, z=z_cond)
            dec_outs.append(dec_out)
        dec_outs = torch.stack(dec_outs)
        return self.project(dec_outs), mu, logvar
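The returned `mu` and `logvar` are normally consumed by the KL term of the VAE objective, which is not shown in these examples. For reference, the closed-form KL divergence between N(mu, sigma^2) and a standard normal prior can be computed as follows (illustrative sketch only):

    import torch

    def kl_divergence_sketch(mu, logvar):
        """KL( N(mu, sigma^2) || N(0, I) ), summed over z_dim and averaged over the batch."""
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        return kld.mean()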
Example #6
File: lm.py Project: mikekestemont/seqmod
    def forward(self, inp, hidden=None, schedule=None, **kwargs):
        """
        Parameters:
        -----------
        inp: torch.Tensor (seq_len x batch_size)

        Returns:
        --------
        outs: torch.Tensor (seq_len * batch_size x vocab)
        hidden: see output of RNN, GRU, LSTM in torch.nn
        weights: None or list of weights (batch_size x seq_len),
            non-None only when the model uses attention.
        """
        inp = word_dropout(
            inp, self.target_code, p=self.word_dropout,
            reserved_codes=self.reserved_codes, training=self.training)
        emb = self.embeddings(inp)
        if self.has_dropout:
            emb = F.dropout(emb, p=self.dropout, training=self.training)
        hidden = hidden if hidden is not None else self.init_hidden_for(emb)
        outs, hidden = self.rnn(emb, hidden)
        if self.has_dropout:
            outs = F.dropout(outs, p=self.dropout, training=self.training)
        weights = None
        if self.add_attn:
            outs, weights = self.attn(outs, emb)
        seq_len, batch, hid_dim = outs.size()
        outs = outs.view(seq_len * batch, hid_dim)
        if self.add_deepout:
            outs = self.deepout(outs)
        outs = F.log_softmax(self.project(outs), dim=-1)
        return outs, hidden, weights
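Because the model returns log-probabilities already flattened to (seq_len * batch_size x vocab), they pair directly with `F.nll_loss` over equally flattened targets. A short usage sketch with made-up sizes:

    import torch
    import torch.nn.functional as F

    seq_len, batch, vocab = 7, 4, 100                    # hypothetical sizes
    outs = F.log_softmax(torch.randn(seq_len * batch, vocab), dim=-1)
    targets = torch.randint(vocab, (seq_len, batch))     # (seq_len x batch)
    loss = F.nll_loss(outs, targets.view(-1))            # flatten targets to (seq_len * batch,)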
Example #7
    def forward(self, inp, trg, conds=None):
        """
        Parameters:
        -----------
        inp: torch.Tensor (seq_len x batch), Training data for a single batch.
        trg: torch.Tensor (seq_len x batch), Target output for a single batch.

        Returns:
        --------
        dec_outs: torch.Tensor (seq_len x batch x hid_dim)
        cond_outs: tuple of log-probability tensors, one per condition
            (empty if the model has no conditions)
        """
        if self.cond_dim is not None:
            if conds is None:
                raise ValueError("Conditional decoder needs conds")
            conds = [emb(cond) for cond, emb in zip(conds, self.cond_embs)]
            # (batch_size x total emb dim)
            conds = torch.cat(conds, 1)

        # encoder
        inp = word_dropout(
            inp, self.target_code, reserved_codes=self.reserved_codes,
            p=self.word_dropout, training=self.training)

        enc_outs, enc_hidden = self.encoder(self.src_embeddings(inp))
        cond_out = []
        if self.cond_dim is not None:
            # use last step as summary vector
            # enc_out = grad_reverse(enc_outs[-1]) # keep this for experiments
            # use average step as summary vector
            enc_out = grad_reverse(enc_outs.mean(dim=0))
            for grl in self.grls:
                cond_out.append(F.log_softmax(grl(enc_out), dim=-1))

        # decoder
        dec_hidden = self.decoder.init_hidden_for(enc_hidden)
        dec_outs, dec_out, enc_att = [], None, None

        if self.decoder.att_type == 'Bahdanau':
            # cache encoder att projection for bahdanau
            enc_att = self.decoder.attn.project_enc_outs(enc_outs)

        for prev in trg.chunk(trg.size(0)):
            # (1 x batch x emb_dim)
            prev_emb = self.trg_embeddings(prev)
            # (batch x emb_dim)
            prev_emb = prev_emb.squeeze(0)
            dec_out, dec_hidden, att_weight = self.decoder(
                prev_emb, dec_hidden, enc_outs, enc_att=enc_att,
                prev_out=dec_out, conds=conds)
            dec_outs.append(dec_out)

        return torch.stack(dec_outs), tuple(cond_out)
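`grad_reverse` makes the encoder summary adversarial with respect to the condition classifiers in `self.grls`: the forward pass is the identity, but gradients from the condition losses are flipped before reaching the encoder. Below is a sketch of such a gradient reversal layer in the style of Ganin & Lempitsky; seqmod's actual implementation may differ.

    import torch

    class GradReverseSketch(torch.autograd.Function):
        """Identity in the forward pass; negated (and optionally scaled) gradient backwards."""

        @staticmethod
        def forward(ctx, x, scale):
            ctx.scale = scale
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            return -ctx.scale * grad_output, None    # no gradient for `scale`

    def grad_reverse_sketch(x, scale=1.0):
        return GradReverseSketch.apply(x, scale)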
Example #8
    def forward(self, inp, hidden=None, conds=None, **kwargs):
        """
        Parameters:
        -----------
        inp: torch.Tensor (seq_len x batch_size)
        hidden: None or torch.Tensor (num_layers x batch_size x hid_dim)
        conds: None or tuple of torch.Tensor (seq_len x batch_size), of length
            equal to the number of model conditions. Required when the model
            is a conditional LM (CLM).

        Returns:
        --------
        outs: torch.Tensor (seq_len * batch_size x vocab)
        hidden: see output of RNN, GRU, LSTM in torch.nn
        weights: None or list of weights (batch_size x seq_len),
            non-None only when the model uses attention.
        """
        if hasattr(self, 'conds') and self.conds is not None and conds is None:
            raise ValueError("Conditional model expects conditions as input")
        inp = word_dropout(inp,
                           self.target_code,
                           p=self.word_dropout,
                           reserved_codes=self.reserved_codes,
                           training=self.training)
        emb = self.embeddings(inp)
        if conds is not None:
            conds = torch.cat(
                [c_emb(inp_c) for c_emb, inp_c in zip(self.conds, conds)], 2)
            emb = torch.cat([emb, conds], 2)
        if self.has_dropout and not self.cell.startswith('RHN'):
            emb = F.dropout(emb, p=self.dropout, training=self.training)
        hidden = hidden if hidden is not None else self.init_hidden_for(emb)
        outs, hidden = self.rnn(emb, hidden)
        if self.has_dropout:
            outs = F.dropout(outs, p=self.dropout, training=self.training)
        weights = None
        if self.add_attn:
            outs, weights = self.attn(outs, emb)
        seq_len, batch, hid_dim = outs.size()
        outs = outs.view(seq_len * batch, hid_dim)
        if self.add_deepout:
            outs = self.deepout(outs)
        outs = F.log_softmax(self.project(outs), dim=-1)
        return outs, hidden, weights
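For the conditional branch, each condition sequence is embedded separately and the results are concatenated to the token embeddings along the feature dimension before entering the RNN. A self-contained sketch with hypothetical vocabulary and embedding sizes:

    import torch
    import torch.nn as nn

    seq_len, batch, emb_dim = 7, 4, 32                     # hypothetical sizes
    cond_vocabs, cond_dims = [5, 3], [8, 4]                # e.g. two conditions (author, genre)

    tok_emb = torch.randn(seq_len, batch, emb_dim)         # stand-in for self.embeddings(inp)
    cond_embs = nn.ModuleList(nn.Embedding(v, d) for v, d in zip(cond_vocabs, cond_dims))
    conds = [torch.randint(v, (seq_len, batch)) for v in cond_vocabs]

    cond_emb = torch.cat([emb(c) for emb, c in zip(cond_embs, conds)], 2)
    rnn_inp = torch.cat([tok_emb, cond_emb], 2)            # (seq_len x batch x emb_dim + sum(cond_dims))
    print(rnn_inp.size())                                  # torch.Size([7, 4, 44])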