# Assumed context for these snippets (not shown here): torch, torch.nn as nn,
# numpy as np, torch.autograd.Variable, and the repo's own modules
# (nn_lib, criterions, EncoderRNN, DecoderRNN, RnnUttEncoder) plus the
# BOS/EOS/PAD vocabulary constants.
def __init__(self, corpus, config):
        super(DiVST, self).__init__(config)
        self.vocab = corpus.vocab
        self.rev_vocab = corpus.rev_vocab
        self.vocab_size = len(self.vocab)
        self.embed_size = config.embed_size
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab[BOS]
        self.eos_id = self.rev_vocab[EOS]
        self.num_layer = config.num_layer
        self.dropout = config.dropout
        self.enc_cell_size = config.enc_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.rnn_cell = config.rnn_cell
        self.max_dec_len = config.max_dec_len
        self.use_attn = config.use_attn
        self.beam_size = config.beam_size
        self.utt_type = config.utt_type
        self.bi_enc_cell = config.bi_enc_cell
        self.attn_type = config.attn_type
        self.enc_out_size = self.enc_cell_size * 2 if self.bi_enc_cell else self.enc_cell_size

        # build model here
        self.embedding = nn.Embedding(self.vocab_size, self.embed_size,
                                      padding_idx=self.rev_vocab[PAD])

        self.x_encoder = EncoderRNN(self.embed_size, self.enc_cell_size,
                                    bidirection=self.bi_enc_cell,
                                    dropout_p=self.dropout,
                                    rnn_cell=self.rnn_cell,
                                    variable_lengths=False)

        # posterior over the discrete latent: y_size independent K-way categoricals
        self.q_y = nn.Linear(self.enc_out_size, config.y_size * config.k)
        self.cat_connector = nn_lib.GumbelConnector()
        # map the (flattened) sampled latent code to the decoder's initial state
        self.dec_init_connector = nn_lib.LinearConnector(config.y_size * config.k,
                                                         self.dec_cell_size,
                                                         self.rnn_cell == 'lstm')

        # skip-thought style decoders: reconstruct the previous and the next
        # utterance from the latent code
        self.prev_decoder = DecoderRNN(self.vocab_size, self.max_dec_len,
                                       self.embed_size, self.dec_cell_size,
                                       self.go_id, self.eos_id,
                                       n_layers=1, rnn_cell=self.rnn_cell,
                                       input_dropout_p=self.dropout,
                                       dropout_p=self.dropout,
                                       use_attention=self.use_attn,
                                       attn_size=self.enc_cell_size,
                                       attn_mode=self.attn_type,
                                       use_gpu=self.use_gpu,
                                       embedding=self.embedding)

        self.next_decoder = DecoderRNN(self.vocab_size, self.max_dec_len,
                                       self.embed_size, self.dec_cell_size,
                                       self.go_id, self.eos_id,
                                       n_layers=1, rnn_cell=self.rnn_cell,
                                       input_dropout_p=self.dropout,
                                       dropout_p=self.dropout,
                                       use_attention=self.use_attn,
                                       attn_size=self.enc_cell_size,
                                       attn_mode=self.attn_type,
                                       use_gpu=self.use_gpu,
                                       embedding=self.embedding)

        self.nll_loss = criterions.NLLEntropy(self.rev_vocab[PAD], self.config)
        self.cat_kl_loss = criterions.CatKLLoss()
        self.cross_ent_loss = criterions.CrossEntropyLoss()
        self.entropy_loss = criterions.Entropy()
        # log of a uniform prior over the K categories, i.e. log(1/K)
        self.log_uniform_y = Variable(torch.log(torch.ones(1) / config.k))
        if self.use_gpu:
            self.log_uniform_y = self.log_uniform_y.cuda()
        self.kl_w = 1.0  # weight on the KL term (fixed at 1 here)
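
# For orientation: a minimal sketch of the forward path the constructor above
# wires up, with torch.nn.functional.gumbel_softmax standing in for the repo's
# GumbelConnector. The names enc_last, temperature, and sample_latent_code are
# illustrative, not taken from the source.
import torch.nn.functional as F

def sample_latent_code(self, enc_last, temperature=1.0, hard=False):
    # enc_last: final encoder state, shape (batch, enc_out_size)
    logits = self.q_y(enc_last)                         # (batch, y_size * k)
    logits = logits.view(-1, self.config.k)             # one K-way softmax per latent slot
    y = F.gumbel_softmax(logits, tau=temperature, hard=hard)
    y = y.view(-1, self.config.y_size * self.config.k)  # flatten back
    dec_init_state = self.dec_init_connector(y)         # init state for prev/next decoders
    return y, dec_init_state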
# Example 2
    def __init__(self, corpus, config):
        super(StED, self).__init__(config)
        self.vocab = corpus.vocab
        self.rev_vocab = corpus.rev_vocab
        self.vocab_size = len(self.vocab)
        self.go_id = self.rev_vocab[BOS]
        self.eos_id = self.rev_vocab[EOS]
        # supply a default for freeze_step (the training step at which the
        # pretrained latent-action components stop being updated)
        if not hasattr(config, "freeze_step"):
            config.freeze_step = 6000

        # build model here
        # word embeddings
        self.x_embedding = nn.Embedding(self.vocab_size, config.embed_size)

        # recognition encoder that learns the discrete latent action
        self.x_encoder = EncoderRNN(config.embed_size, config.dec_cell_size,
                                    dropout_p=config.dropout,
                                    rnn_cell=config.rnn_cell,
                                    variable_lengths=False)

        # posterior over the discrete latent action: y_size independent K-way categoricals
        self.q_y = nn.Linear(config.dec_cell_size, config.y_size * config.k)
        self.x_init_connector = nn_lib.LinearConnector(config.y_size * config.k,
                                                       config.dec_cell_size,
                                                       config.rnn_cell == 'lstm')
        # skip-thought style decoders for the previous and next utterances
        self.prev_decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                        config.embed_size, config.dec_cell_size,
                                        self.go_id, self.eos_id,
                                        n_layers=1, rnn_cell=config.rnn_cell,
                                        input_dropout_p=config.dropout,
                                        dropout_p=config.dropout,
                                        use_attention=False,
                                        use_gpu=config.use_gpu,
                                        embedding=self.x_embedding)

        self.next_decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                        config.embed_size, config.dec_cell_size,
                                        self.go_id, self.eos_id,
                                        n_layers=1, rnn_cell=config.rnn_cell,
                                        input_dropout_p=config.dropout,
                                        dropout_p=config.dropout,
                                        use_attention=False,
                                        use_gpu=config.use_gpu,
                                        embedding=self.x_embedding)


        # Encoder-Decoder STARTS here
        self.embedding = nn.Embedding(self.vocab_size, config.embed_size,
                                      padding_idx=self.rev_vocab[PAD])

        self.utt_encoder = RnnUttEncoder(config.utt_cell_size, config.dropout,
                                         bidirection=False,  # bidirection=True in the original code
                                         use_attn=config.utt_type == 'attn_rnn',
                                         vocab_size=self.vocab_size,
                                         embedding=self.embedding)

        self.ctx_encoder = EncoderRNN(self.utt_encoder.output_size,
                                      config.ctx_cell_size,
                                      0.0,
                                      config.dropout,
                                      config.num_layer,
                                      config.rnn_cell,
                                      variable_lengths=config.fix_batch)
        # feed-forward network that predicts the latent code y from the dialog context
        self.p_fc1 = nn.Linear(config.ctx_cell_size, config.ctx_cell_size)
        self.p_y = nn.Linear(config.ctx_cell_size, config.y_size * config.k)

        # connector
        self.c_init_connector = nn_lib.LinearConnector(config.y_size * config.k,
                                                       config.dec_cell_size,
                                                       config.rnn_cell == 'lstm')
        # decoder
        self.decoder = DecoderRNN(self.vocab_size, config.max_dec_len,
                                  config.embed_size, config.dec_cell_size,
                                  self.go_id, self.eos_id,
                                  n_layers=1, rnn_cell=config.rnn_cell,
                                  input_dropout_p=config.dropout,
                                  dropout_p=config.dropout,
                                  use_attention=config.use_attn,
                                  attn_size=config.dec_cell_size,
                                  attn_mode=config.attn_type,
                                  use_gpu=config.use_gpu,
                                  embedding=self.embedding)

        # force the generator G(z, c) to actually make use of z
        # (padding index -100 matches no real token id, so nothing is masked)
        if config.use_attribute:
            self.attribute_loss = criterions.NLLEntropy(-100, config)

        self.cat_connector = nn_lib.GumbelConnector(config.use_gpu)
        self.greedy_cat_connector = nn_lib.GreedyConnector()
        self.nll_loss = criterions.NLLEntropy(self.rev_vocab[PAD], self.config)
        self.cat_kl_loss = criterions.CatKLLoss()
        # log of a uniform prior over the K categories, i.e. log(1/K)
        self.log_uniform_y = Variable(torch.log(torch.ones(1) / config.k))
        self.entropy_loss = criterions.Entropy()

        if self.use_gpu:
            self.log_uniform_y = self.log_uniform_y.cuda()
        self.kl_w = 0.0  # KL annealing weight, ramped up from 0 during training
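
# A minimal sketch of how the pieces above typically combine into the
# categorical KL term, with kl_w annealed linearly. categorical_kl_to_uniform
# and anneal_kl_weight are illustrative names, and the linear schedule is an
# assumption, not taken from the source.
import torch

def categorical_kl_to_uniform(log_qy, log_uniform_y):
    # log_qy: (batch, y_size, k) log-probabilities for each K-way latent slot.
    # KL(q || Uniform) = sum_k q * (log q - log(1/K)), summed over the slots
    # and averaged over the batch.
    qy = torch.exp(log_qy)
    kl = torch.sum(qy * (log_qy - log_uniform_y), dim=-1)
    return kl.sum(dim=-1).mean()

def anneal_kl_weight(step, full_step):
    # linear ramp of the KL weight from 0 to 1 over the first full_step updates
    return min(1.0, step / float(full_step))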
# Example 3
    def __init__(self, corpus, config):
        super(DirVAE, self).__init__(config)
        self.vocab = corpus.vocab
        self.rev_vocab = corpus.rev_vocab
        self.vocab_size = len(self.vocab)
        self.embed_size = config.embed_size
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab[BOS]
        self.eos_id = self.rev_vocab[EOS]
        self.num_layer = config.num_layer
        self.dropout = config.dropout
        self.enc_cell_size = config.enc_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.rnn_cell = config.rnn_cell
        self.max_dec_len = config.max_dec_len
        self.use_attn = config.use_attn
        self.beam_size = config.beam_size
        self.utt_type = config.utt_type
        self.bi_enc_cell = config.bi_enc_cell
        self.attn_type = config.attn_type
        self.enc_out_size = self.enc_cell_size * 2 if self.bi_enc_cell else self.enc_cell_size
        self.full_kl_step = 4200.0  # step at which the annealed KL weight kl_w reaches 1.0

        # build model here
        self.embedding = nn.Embedding(self.vocab_size,
                                      self.embed_size,
                                      padding_idx=self.rev_vocab[PAD])

        self.x_encoder = EncoderRNN(self.embed_size,
                                    self.enc_cell_size,
                                    dropout_p=self.dropout,
                                    rnn_cell=self.rnn_cell,
                                    variable_lengths=self.config.fix_batch,
                                    bidirection=self.bi_enc_cell)

        # Dirichlet topic-model prior via the Laplace approximation in softmax
        # space (as in ProdLDA, Srivastava & Sutton 2017): a symmetric
        # Dirichlet(alpha=1) is approximated by the Gaussian defined below.
        self.h_dim = config.latent_size
        self.a = 1. * np.ones((1, self.h_dim)).astype(np.float32)
        prior_mean = torch.from_numpy(
            (np.log(self.a).T - np.mean(np.log(self.a), 1)).T)
        prior_var = torch.from_numpy(
            (((1.0 / self.a) * (1 - (2.0 / self.h_dim))).T +
             (1.0 / (self.h_dim * self.h_dim)) * np.sum(1.0 / self.a, 1)).T)
        prior_logvar = prior_var.log()

        self.register_buffer('prior_mean', prior_mean)
        self.register_buffer('prior_var', prior_var)
        self.register_buffer('prior_logvar', prior_logvar)

        # inference networks: map the final encoder state to the posterior
        # mean and log-variance of the latent topic vector
        self.logvar_fc = nn.Sequential(
            nn.Linear(self.enc_cell_size,
                      np.maximum(config.latent_size * 2, 100)), nn.Tanh(),
            nn.Linear(np.maximum(config.latent_size * 2, 100),
                      config.latent_size))
        self.mean_fc = nn.Sequential(
            nn.Linear(self.enc_cell_size,
                      np.maximum(config.latent_size * 2, 100)), nn.Tanh(),
            nn.Linear(np.maximum(config.latent_size * 2, 100),
                      config.latent_size))
        self.mean_bn = nn.BatchNorm1d(self.h_dim)  # bn for mean
        self.logvar_bn = nn.BatchNorm1d(self.h_dim)  # bn for logvar
        self.decoder_bn = nn.BatchNorm1d(self.vocab_size)

        # freeze the BatchNorm scale (gamma) at 1 so only the shift is learned,
        # the usual ProdLDA trick for stabilizing training
        self.logvar_bn.weight.requires_grad = False
        self.mean_bn.weight.requires_grad = False
        self.decoder_bn.weight.requires_grad = False

        self.logvar_bn.weight.fill_(1)
        self.mean_bn.weight.fill_(1)
        self.decoder_bn.weight.fill_(1)

        self.dec_init_connector = nn_lib.LinearConnector(
            config.latent_size,
            self.dec_cell_size,
            self.rnn_cell == 'lstm',
            has_bias=False)

        self.decoder = DecoderRNN(self.vocab_size,
                                  self.max_dec_len,
                                  self.embed_size,
                                  self.dec_cell_size,
                                  self.go_id,
                                  self.eos_id,
                                  n_layers=1,
                                  rnn_cell=self.rnn_cell,
                                  input_dropout_p=self.dropout,
                                  dropout_p=self.dropout,
                                  use_attention=self.use_attn,
                                  attn_size=self.enc_cell_size,
                                  attn_mode=self.attn_type,
                                  use_gpu=self.use_gpu,
                                  embedding=self.embedding)

        self.nll_loss = criterions.NLLEntropy(self.rev_vocab[PAD], self.config)
        self.entropy_loss = criterions.Entropy()

        # bag-of-words auxiliary loss: predict the response's words directly
        # from the latent sample, encouraging the latent code to carry content
        self.bow_project = nn.Sequential(
            nn.Linear(config.latent_size, self.vocab_size),
            self.decoder_bn)

        self.kl_w = 0.0  # annealing weight for the KL term, ramped up to 1.0
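
# For reference: a minimal sketch of the KL term implied by the prior buffers
# registered above, following the ProdLDA formulation (Srivastava & Sutton,
# 2017). dirichlet_laplace_kl and sample_z are illustrative names, not taken
# from the source.
import torch

def dirichlet_laplace_kl(mu, logvar, prior_mean, prior_var, prior_logvar):
    # Gaussian KL against the Laplace approximation of the Dirichlet prior:
    # 0.5 * sum( var/prior_var + (mu - prior_mean)^2/prior_var
    #            + prior_logvar - logvar - 1 )
    var = logvar.exp()
    kld = 0.5 * (var / prior_var
                 + (mu - prior_mean).pow(2) / prior_var
                 + prior_logvar - logvar - 1.0).sum(dim=1)
    return kld.mean()

def sample_z(mu, logvar):
    # reparameterized sample; ProdLDA additionally applies a softmax so the
    # sample lies on the probability simplex
    eps = torch.randn_like(mu)
    return torch.softmax(mu + (0.5 * logvar).exp() * eps, dim=1)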