Example #1
    def __init__(self, corpus, config):
        super(ContEncoder, self).__init__(config)
        self.vocab = corpus.vocab
        self.rev_vocab = corpus.rev_vocab
        self.vocab_size = len(self.vocab)
        self.go_id = self.rev_vocab[BOS]
        self.eos_id = self.rev_vocab[EOS]
        self.config = config

        self.embedding = nn.Embedding(self.vocab_size, config.embed_size,
                                      padding_idx=self.rev_vocab[PAD])

        self.utt_encoder = RnnUttEncoder(config.utt_cell_size, 0.0,
                                         bidirection=False,
                                         # bidirection=True in the original code
                                         use_attn=config.utt_type == 'attn_rnn',
                                         vocab_size=self.vocab_size,
                                         embedding=self.embedding)

        self.ctx_encoder = EncoderRNN(self.utt_encoder.output_size,
                                      config.ctx_cell_size,
                                      0.0,  # input_dropout_p
                                      0.0,  # dropout_p
                                      config.num_layer,
                                      config.rnn_cell,
                                      variable_lengths=self.config.fix_batch)
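
ContEncoder wires up a two-level (hierarchical) encoder: RnnUttEncoder compresses each utterance into a single vector, and EncoderRNN then summarizes the resulting sequence of utterance vectors over the whole dialog. The sketch below illustrates that data flow with plain PyTorch stand-ins; the class, sizes, and shapes here are illustrative assumptions, not the repo's actual RnnUttEncoder/EncoderRNN API.

# A rough, self-contained sketch of the utterance -> context data flow,
# using plain PyTorch GRUs as stand-ins for RnnUttEncoder / EncoderRNN.
import torch
import torch.nn as nn

class ToyHierEncoder(nn.Module):
    def __init__(self, vocab_size, embed_size=32, utt_cell_size=64,
                 ctx_cell_size=128, pad_id=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=pad_id)
        # utterance-level RNN: one summary vector per utterance
        self.utt_rnn = nn.GRU(embed_size, utt_cell_size, batch_first=True)
        # context-level RNN: runs over the sequence of utterance vectors
        self.ctx_rnn = nn.GRU(utt_cell_size, ctx_cell_size, batch_first=True)

    def forward(self, dialog):
        # dialog: (batch, n_utts, n_words) token ids
        b, n_utts, n_words = dialog.size()
        flat = dialog.view(b * n_utts, n_words)
        _, utt_h = self.utt_rnn(self.embedding(flat))   # utt_h: (1, b*n_utts, utt)
        utt_vecs = utt_h[-1].view(b, n_utts, -1)        # (b, n_utts, utt)
        return self.ctx_rnn(utt_vecs)                   # (outputs, final context state)

enc = ToyHierEncoder(vocab_size=100)
ctx_out, ctx_h = enc(torch.randint(1, 100, (2, 4, 7)))  # 2 dialogs, 4 utts of 7 tokens
print(ctx_out.shape)  # torch.Size([2, 4, 128])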
Example #2
    def __init__(self, corpus, config):
        super(StED, self).__init__(config)
        self.vocab = corpus.vocab
        self.rev_vocab = corpus.rev_vocab
        self.vocab_size = len(self.vocab)
        self.go_id = self.rev_vocab[BOS]
        self.eos_id = self.rev_vocab[EOS]
        if not hasattr(config, "freeze_step"):
            config.freeze_step = 6000

        # build model here
        # word embeddings
        self.x_embedding = nn.Embedding(self.vocab_size, config.embed_size)

        # encoder that recognizes the latent action from the input x
        self.x_encoder = EncoderRNN(config.embed_size, config.dec_cell_size,
                                    dropout_p=config.dropout,
                                    rnn_cell=config.rnn_cell,
                                    variable_lengths=False)

        # recognition logits q(y|x): y_size independent k-way categorical latents
        self.q_y = nn.Linear(config.dec_cell_size, config.y_size * config.k)
        self.x_init_connector = nn_lib.LinearConnector(config.y_size * config.k,
                                                       config.dec_cell_size,
                                                       config.rnn_cell == 'lstm')
        # decoder
        self.prev_decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                       config.embed_size, config.dec_cell_size,
                                       self.go_id, self.eos_id,
                                       n_layers=1, rnn_cell=config.rnn_cell,
                                       input_dropout_p=config.dropout,
                                       dropout_p=config.dropout,
                                       use_attention=False,
                                       use_gpu=config.use_gpu,
                                       embedding=self.x_embedding)

        self.next_decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                       config.embed_size, config.dec_cell_size,
                                       self.go_id, self.eos_id,
                                       n_layers=1, rnn_cell=config.rnn_cell,
                                       input_dropout_p=config.dropout,
                                       dropout_p=config.dropout,
                                       use_attention=False,
                                       use_gpu=config.use_gpu,
                                       embedding=self.x_embedding)

        # Encoder-Decoder STARTS here
        self.embedding = nn.Embedding(self.vocab_size, config.embed_size,
                                      padding_idx=self.rev_vocab[PAD])

        self.utt_encoder = RnnUttEncoder(config.utt_cell_size, config.dropout,
                                         bidirection=False,
                                         #  bidirection=True in the original code
                                         use_attn=config.utt_type == 'attn_rnn',
                                         vocab_size=self.vocab_size,
                                         embedding=self.embedding)

        self.ctx_encoder = EncoderRNN(self.utt_encoder.output_size,
                                      config.ctx_cell_size,
                                      0.0,             # input_dropout_p
                                      config.dropout,  # dropout_p
                                      config.num_layer,
                                      config.rnn_cell,
                                      variable_lengths=config.fix_batch)
        # feed-forward network producing the prior logits p(y|c) from the context
        self.p_fc1 = nn.Linear(config.ctx_cell_size, config.ctx_cell_size)
        self.p_y = nn.Linear(config.ctx_cell_size, config.y_size * config.k)

        # connector
        self.c_init_connector = nn_lib.LinearConnector(config.y_size * config.k,
                                                       config.dec_cell_size,
                                                       config.rnn_cell == 'lstm')
        # decoder
        self.decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                  config.embed_size, config.dec_cell_size,
                                  self.go_id, self.eos_id,
                                  n_layers=1, rnn_cell=config.rnn_cell,
                                  input_dropout_p=config.dropout,
                                  dropout_p=config.dropout,
                                  use_attention=config.use_attn,
                                  attn_size=config.dec_cell_size,
                                  attn_mode=config.attn_type,
                                  use_gpu=config.use_gpu,
                                  embedding=self.embedding)

        # force the generator G(z, c) to actually make use of z
        if config.use_attribute:
            # -100 is a padding id that matches no real token, so nothing is masked
            self.attribute_loss = criterions.NLLEntropy(-100, config)

        self.cat_connector = nn_lib.GumbelConnector(config.use_gpu)
        self.greedy_cat_connector = nn_lib.GreedyConnector()
        self.nll_loss = criterions.NLLEntropy(self.rev_vocab[PAD], self.config)
        self.cat_kl_loss = criterions.CatKLLoss()
        self.log_uniform_y = Variable(torch.log(torch.ones(1) / config.k))  # log p(y) for a uniform prior over k classes
        self.entropy_loss = criterions.Entropy()

        if self.use_gpu:
            self.log_uniform_y = self.log_uniform_y.cuda()
        self.kl_w = 0.0
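
The final block wires up the training criteria: cat_connector presumably draws differentiable Gumbel-Softmax samples of y, while cat_kl_loss measures the divergence between the recognized q(y|x) and the uniform prior whose log-probability is stored in log_uniform_y. Below is a minimal, self-contained sketch of that KL term, assuming CatKLLoss computes KL(q || p) from log-probabilities; the real class in criterions may differ in signature and reduction.

# Sketch of the categorical-KL term between q(y|x) and a uniform prior,
# standing in for criterions.CatKLLoss (exact interface assumed, not verified).
import torch
import torch.nn.functional as F

def cat_kl(log_q, log_p):
    # KL(q || p) = sum_y q(y) * (log q(y) - log p(y)), averaged over the batch
    q = log_q.exp()
    return (q * (log_q - log_p)).sum(dim=-1).mean()

k = 10
logits = torch.randn(4, k)                    # e.g. q_y output for one latent variable
log_q = F.log_softmax(logits, dim=-1)
log_uniform_y = torch.log(torch.ones(1) / k)  # same construction as in __init__
print(cat_kl(log_q, log_uniform_y))           # >= 0; exactly 0 when q is uniform

Since KL(q || uniform) = log k - H(q), pushing this KL down is the same as pushing the entropy of q(y|x) up.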