Example #1
    def forward_sample(self, sampleSoft, sampleHard, z, c, h):
        if sampleSoft is not None:
            # sampleSoft (mbsize x n_vocab) is a soft distribution over the
            # vocabulary, so gradients pass through the embedding.
            emb = soft_embed(self.emb, sampleSoft)
        else:
            # sampleHard (mbsize) is an index tensor, so gradients don't
            # pass through the embedding lookup.
            emb = self.emb(sampleHard)
        # mbsize x (emb_dim + z_dim + c_dim)
        emb = torch.cat([emb, z, c], 1)
        # insert a seq_len dimension of 1: mbsize x 1 x ezc_dim
        emb = emb.unsqueeze(1)
        # output: mbsize x 1 x h_dim (the RNN runs batch-first here)
        output, h = self.rnn(emb, h)
        # mbsize x h_dim
        output = output.squeeze(1)

        # apply skip connection from the latent code to the output
        if self.skip_connections:
            latent_code = torch.cat([z, c], 1)
            output = self.skip_weight_x(output) + self.skip_weight_z(
                latent_code)

        # mbsize x n_vocab
        logits = self.fc(output)
        return logits, h

    def forward_encoder(self, inputs):
        """
        Inputs is a batch of sentences: seq_len x mbsize,
        or a batch of soft sentences: seq_len x mbsize x n_vocab.
        """
        if inputs.dim() == 2:
            inputs = self.word_emb(inputs)
        else:  # dim == 3: soft one-hot input.
            inputs = soft_embed(self.word_emb, inputs)
        return self.encoder(inputs)

    def forward_classifier(self, inputs):
        """
        Inputs is a batch of sentences: mbsize x seq_len,
        or a batch of soft sentences: mbsize x seq_len x n_vocab.
        """
        if inputs.dim() == 2:
            inputs = self.word_emb(inputs)
        else:  # dim == 3: soft one-hot input.
            inputs = soft_embed(self.word_emb, inputs)
        return self.classifier(inputs)
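
The soft_embed helper used throughout is not part of this excerpt. A minimal sketch consistent with how it is called here, assuming emb is a torch.nn.Embedding and the last dimension of the soft input ranges over the vocabulary:

import torch
import torch.nn as nn

def soft_embed(emb: nn.Embedding, soft_ix: torch.Tensor) -> torch.Tensor:
    """Differentiable embedding lookup for a relaxed (soft) input.

    soft_ix: ... x n_vocab tensor of distributions over the vocabulary.
    Returns: ... x emb_dim, the expected embedding under each distribution.
    Unlike emb(hard_ix), gradients flow back into soft_ix.
    """
    # (... x n_vocab) @ (n_vocab x emb_dim) -> ... x emb_dim
    return soft_ix @ emb.weight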
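
As a usage sketch (all sizes and tensors below are hypothetical, not taken from the excerpt), the encoder path accepts either hard token indices or soft vocabulary distributions:

# Hypothetical sizes; the real model config is not shown in the excerpt.
n_vocab, emb_dim, seq_len, mbsize = 1000, 64, 16, 8
word_emb = nn.Embedding(n_vocab, emb_dim)

hard = torch.randint(n_vocab, (seq_len, mbsize))             # index tensor
soft = torch.softmax(torch.randn(seq_len, mbsize, n_vocab), dim=-1)

emb_hard = word_emb(hard)              # seq_len x mbsize x emb_dim, no grad to hard
emb_soft = soft_embed(word_emb, soft)  # same shape, differentiable w.r.t. soft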