예제 #1
0
    def __init__(self, corpus, config):
        """Build the HRED sub-modules from the corpus vocabulary and ``config``.

        Stores ``corpus.vocab`` / ``corpus.rev_vocab``, the vocabulary size and
        the ids of the ``BOS``, ``EOS`` and ``PAD`` tokens, then constructs (in
        this order) the utterance encoder, the context encoder, the
        encoder-to-decoder connector, the decoder and the NLL criterion.

        Args:
            corpus: object exposing ``vocab`` and ``rev_vocab``; ``rev_vocab``
                is indexed by token (presumably a token -> id dict; verify).
            config: hyper-parameter container read via attribute access.

        Raises:
            Whatever ``corpus.rev_vocab[...]`` raises when BOS/EOS/PAD is
            absent (KeyError if it is a plain dict -- TODO confirm type).
        """
        super(HRED, self).__init__(config)

        vocab = corpus.vocab
        rev_vocab = corpus.rev_vocab
        self.vocab = vocab
        self.rev_vocab = rev_vocab
        self.vocab_size = len(vocab)
        self.go_id = rev_vocab[BOS]
        self.eos_id = rev_vocab[EOS]
        self.pad_id = rev_vocab[PAD]

        # NOTE: keep the module construction order below unchanged -- it fixes
        # the order in which sub-modules/parameters are registered (and,
        # presumably, the order of random weight initialisation).
        wants_utt_attn = config.utt_type == 'attn_rnn'
        self.utt_encoder = ElmoUttEncoder(config,
                                          config.utt_cell_size,
                                          config.dropout,
                                          use_attn=wants_utt_attn,
                                          feat_size=1)

        # The context RNN consumes one utterance encoding per step. The 0.0
        # positional argument is presumably its input dropout -- verify
        # against EncoderRNN's signature.
        utt_out_size = self.utt_encoder.output_size
        self.ctx_encoder = EncoderRNN(utt_out_size, config.ctx_cell_size,
                                      0.0, config.dropout,
                                      config.num_layer, config.rnn_cell,
                                      variable_lengths=False,
                                      bidirection=config.bi_ctx_cell)

        # A bidirectional or multi-layer context encoder is bridged with a
        # Bi2UniConnector (built from ctx_cell_size and dec_cell_size,
        # presumably mapping the encoder state to the decoder's size);
        # otherwise an IdentityConnector is used.
        needs_bridge = config.bi_ctx_cell or config.num_layer > 1
        self.connector = (
            Bi2UniConnector(config.rnn_cell, config.num_layer,
                            config.ctx_cell_size, config.dec_cell_size)
            if needs_bridge else IdentityConnector())

        # The decoder reuses the utterance encoder's embedding module, and its
        # attention is sized by the context encoder's output size.
        self.decoder = ElmoDecoderRNN(
            self.vocab_size, config.max_dec_len,
            self.utt_encoder.embed_size, config.dec_cell_size,
            self.go_id, self.eos_id, self.vocab,
            n_layers=1,
            rnn_cell=config.rnn_cell,
            input_dropout_p=config.dropout,
            dropout_p=config.dropout,
            use_attention=config.use_attn,
            embedding=self.utt_encoder.embedding,
            attn_size=self.ctx_encoder.output_size,
            attn_mode=config.attn_type,
            use_gpu=config.use_gpu)

        # Criterion built with pad_id (presumably so padding is ignored).
        self.nll = criterions.NLLEntropy(self.pad_id, config)
    def __init__(self, corpus, config):
        """Build the pointer-generator HRED variant.

        The parent initializer is relied on to set ``utt_encoder``,
        ``vocab_size``, ``go_id``, ``eos_id`` and ``pad_id``: this method reads
        them but never assigns them. It then (re)builds the context encoder
        and connector, creates a plain ``nn.Embedding`` for the decoder, builds
        a ``DecoderPointerGen`` as ``self.decoder`` and an NLL criterion as
        ``self.nll_loss``.

        Args:
            corpus: vocabulary source, forwarded to the parent initializer.
            config: hyper-parameter container read via attribute access.
        """
        # Pass by keyword. This method's signature is (corpus, config), but the
        # arguments used to be forwarded positionally as (config, corpus). That
        # swaps them whenever the parent shares this signature, as the HRED
        # constructor does. Keywords bind by name, so the call is correct
        # regardless of the parent's positional order.
        super(PtrHRED, self).__init__(corpus=corpus, config=config)

        # Context RNN over utterance encodings. The 0.0 positional argument is
        # presumably the input dropout -- verify against EncoderRNN.
        self.ctx_encoder = EncoderRNN(self.utt_encoder.output_size,
                                      config.ctx_cell_size,
                                      0.0,
                                      config.dropout,
                                      config.num_layer,
                                      config.rnn_cell,
                                      variable_lengths=False,
                                      bidirection=config.bi_ctx_cell)

        # Bidirectional / multi-layer context encoders get a Bi2UniConnector
        # (presumably mapping their state to dec_cell_size); otherwise an
        # IdentityConnector is used.
        if config.bi_ctx_cell or config.num_layer > 1:
            self.connector = Bi2UniConnector(config.rnn_cell, config.num_layer,
                                             config.ctx_cell_size,
                                             config.dec_cell_size)
        else:
            self.connector = IdentityConnector()

        # The decoder's attention is sized by the context encoder's output.
        self.attn_size = self.ctx_encoder.output_size

        # The pointer-generator decoder gets its own vocab_size x embed_size
        # embedding table instead of the utterance encoder's embedding.
        self.plain_embedding = nn.Embedding(self.vocab_size, config.embed_size)
        self.decoder = DecoderPointerGen(self.vocab_size,
                                         config.max_dec_len,
                                         config.embed_size,
                                         config.dec_cell_size,
                                         self.go_id,
                                         self.eos_id,
                                         n_layers=1,
                                         rnn_cell=config.rnn_cell,
                                         input_dropout_p=config.dropout,
                                         dropout_p=config.dropout,
                                         attn_size=self.attn_size,
                                         attn_mode=config.attn_type,
                                         use_gpu=config.use_gpu,
                                         embedding=self.plain_embedding)

        # Criterion built with pad_id (presumably so padding is ignored).
        self.nll_loss = criterions.NLLEntropy(self.pad_id, config)