Example #1
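Builds a Transformer-based model: the input embedding reserves `maxtime` extra rows (for time/position ids), the Transformer's word embeddings are swapped for a rare-token-aware TokenEmb, and a BasicGenOutput head over the query vocabulary is wrapped together with the encoder into an NARTMModel.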
# Only `torch` is an external import here; `SequenceEncoder`, `TokenEmb`,
# `TransformerConfig`, `Transformer`, `BasicGenOutput`, and `NARTMModel`
# come from the surrounding project.
import torch


def create_model(hdim=128,
                 dropout=0.,
                 numlayers: int = 1,
                 numheads: int = 4,
                 sentence_encoder: SequenceEncoder = None,
                 query_encoder: SequenceEncoder = None,
                 feedatt=False,
                 maxtime=100):
    # Input embedding: reserve `maxtime` extra rows beyond the input vocabulary
    # (used for time/position ids); id 0 is the padding id.
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids() + maxtime,
                                hdim,
                                padding_idx=0)
    # Collapse rare tokens onto a shared rare id (1).
    inpemb = TokenEmb(inpemb,
                      rare_token_ids=sentence_encoder.vocab.rare_ids,
                      rare_id=1)
    # Transformer sized to the extended embedding matrix.
    tm_config = TransformerConfig(vocab_size=inpemb.emb.num_embeddings,
                                  num_attention_heads=numheads,
                                  num_hidden_layers=numlayers,
                                  hidden_size=hdim,
                                  intermediate_size=hdim * 4,
                                  hidden_dropout_prob=dropout)
    tm = Transformer(tm_config)
    # Swap in the rare-aware embedding built above.
    tm.embeddings.word_embeddings = inpemb
    # Generation head over the query vocabulary, wrapped together with the
    # Transformer into the non-autoregressive model.
    decoder_out = BasicGenOutput(hdim, query_encoder.vocab)
    model = NARTMModel(tm,
                       decoder_out,
                       maxinplen=maxtime,
                       maxoutlen=maxtime,
                       numinpids=sentence_encoder.vocab.number_of_ids())
    return model
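A minimal call sketch (hypothetical values; `sent_enc` and `query_enc` stand in for SequenceEncoder instances whose construction this excerpt does not show):

# Hypothetical usage; sent_enc / query_enc are built elsewhere in the project.
model = create_model(hdim=256, numlayers=2, numheads=8, dropout=0.1,
                     sentence_encoder=sent_enc, query_encoder=query_enc,
                     maxtime=100)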
Example #2
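A VAE-style LSTM sequence-to-sequence constructor (an excerpt from BasicGenModel): a bidirectional LSTM encodes the input, a second bidirectional LSTM encodes the target sequence and feeds the `out_mu`/`out_logvar` heads that parameterize a `zdim`-dimensional latent code, and an attentive LSTM decoder consumes the token embedding, the latent code, and (optionally) the attention summary; `minkl` is presumably a floor on the KL term.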
    def __init__(self, embdim, hdim, numlayers: int = 1, dropout=0., zdim=None,
                 sentence_encoder: SequenceEncoder = None,
                 query_encoder: SequenceEncoder = None,
                 feedatt=False, store_attn=True,
                 minkl=0.05, **kw):
        super(BasicGenModel, self).__init__(**kw)

        self.minkl = minkl

        self.embdim, self.hdim, self.numlayers, self.dropout = embdim, hdim, numlayers, dropout
        self.zdim = embdim if zdim is None else zdim  # latent size defaults to embdim

        # Input embedding; rare tokens are collapsed onto a shared rare id (1).
        inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
        inpemb = TokenEmb(inpemb, rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)
        # _, covered_word_ids = load_pretrained_embeddings(inpemb.emb, sentence_encoder.vocab.D,
        #                                                  p="../../data/glove/glove300uncased")  # load glove embeddings where possible into the inner embedding class
        # inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
        self.inp_emb = inpemb

        # Bidirectional encoder: hdim // 2 per direction, so encodings are hdim wide.
        encoder_dim = hdim
        encoder = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
        # encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
        self.inp_enc = encoder

        self.out_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)

        # Decoder RNN input: token embedding + latent code (+ attention summary if fed back).
        dec_rnn_in_dim = embdim + self.zdim + (encoder_dim if feedatt else 0)
        decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
        self.out_rnn = decoder_rnn
        # Posterior network: encode the target sequence, then project to mean and log-variance.
        self.out_emb_vae = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
        self.out_enc = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
        # self.out_mu = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
        # self.out_logvar = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
        self.out_mu = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))
        self.out_logvar = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))

        # Output layer over the decoder state concatenated with the attention summary.
        decoder_out = BasicGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
        # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
        self.out_lin = decoder_out

        self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

        # Per-layer bridge from the encoder's final state to the decoder's initial state.
        self.enc_to_dec = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Linear(encoder_dim, hdim),
            torch.nn.Tanh()
        ) for _ in range(numlayers)])

        self.feedatt = feedatt
        self.nocopy = True

        self.store_attn = store_attn

        self.reset_parameters()
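A hypothetical instantiation, assuming the full BasicGenModel class and two prebuilt SequenceEncoder objects (`sent_enc`, `query_enc`) are available:

# Hypothetical usage; argument values are illustrative only.
model = BasicGenModel(embdim=128, hdim=256, numlayers=2, dropout=0.1,
                      zdim=64, minkl=0.05,
                      sentence_encoder=sent_enc, query_encoder=query_enc,
                      feedatt=True)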
Example #3
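The same encoder-decoder recipe with GRUs and optional Variational Information Bottleneck modules (an excerpt from BasicGenModel_VIB): `vib_init` places one VIB per layer on the states that initialize the decoder, and `vib_enc` places a sequence-level VIB on the encoding.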
    def __init__(self,
                 embdim,
                 hdim,
                 numlayers: int = 1,
                 dropout=0.,
                 sentence_encoder: SequenceEncoder = None,
                 query_encoder: SequenceEncoder = None,
                 feedatt=False,
                 store_attn=True,
                 vib_init=False,
                 vib_enc=False,
                 **kw):
        super(BasicGenModel_VIB, self).__init__(**kw)

        # Plain input embedding (no TokenEmb wrapper in this variant); id 0 is padding.
        inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(),
                                    embdim,
                                    padding_idx=0)

        # _, covered_word_ids = load_pretrained_embeddings(inpemb.emb, sentence_encoder.vocab.D,
        #                                                  p="../../data/glove/glove300uncased")  # load glove embeddings where possible into the inner embedding class
        # inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
        self.inp_emb = inpemb

        # Bidirectional GRU with hdim per direction, so encodings are 2 * hdim wide.
        encoder_dim = hdim * 2
        encoder = GRUEncoder(embdim,
                             hdim,
                             num_layers=numlayers,
                             dropout=dropout,
                             bidirectional=True)
        # encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
        self.inp_enc = encoder

        decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(),
                                         embdim,
                                         padding_idx=0)
        self.out_emb = decoder_emb

        # Decoder input: token embedding (+ attention summary if fed back); no latent code here.
        dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
        decoder_rnn = GRUTransition(dec_rnn_in_dim,
                                    hdim,
                                    numlayers,
                                    dropout=dropout)
        self.out_rnn = decoder_rnn

        decoder_out = BasicGenOutput(hdim + encoder_dim,
                                     vocab=query_encoder.vocab)
        # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
        self.out_lin = decoder_out

        self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim),
                               dropout=min(0.1, dropout))

        self.enc_to_dec = torch.nn.ModuleList([
            torch.nn.Sequential(torch.nn.Linear(encoder_dim, hdim),
                                torch.nn.Tanh()) for _ in range(numlayers)
        ])

        self.feedatt = feedatt
        self.nocopy = True

        self.store_attn = store_attn

        # VIBs: per-layer bottlenecks on the decoder-init states, and/or a
        # sequence-level bottleneck on the encoding.
        self.vib_init = torch.nn.ModuleList(
            [VIB(encoder_dim) for _ in range(numlayers)]) if vib_init else None
        self.vib_enc = VIB_seq(encoder_dim) if vib_enc else None

        self.reset_parameters()
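A hypothetical instantiation with both bottlenecks enabled (`sent_enc` and `query_enc` are prebuilt SequenceEncoder objects):

# Hypothetical usage; argument values are illustrative only.
model = BasicGenModel_VIB(embdim=128, hdim=256, numlayers=2, dropout=0.1,
                          sentence_encoder=sent_enc, query_encoder=query_enc,
                          vib_init=True, vib_enc=True)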
Example #4
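A GloVe-initialised variant (an excerpt from BasicGenModel): the input embedding is fixed at 300 dimensions to match the pretrained vectors and adapted to `embdim` via TokenEmb's `adapt_dims`; tokens not covered by GloVe fall back to the shared rare id, and the decoder is a hand-built stack of LSTMCells.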
    def __init__(self,
                 embdim,
                 hdim,
                 numlayers: int = 1,
                 dropout=0.,
                 sentence_encoder: SequenceEncoder = None,
                 query_encoder: SequenceEncoder = None,
                 feedatt=False,
                 store_attn=True,
                 **kw):
        super(BasicGenModel, self).__init__(**kw)

        # 300-dim embedding to match the pretrained GloVe vectors; TokenEmb adapts
        # 300 -> embdim and handles rare tokens.
        inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(),
                                    300,
                                    padding_idx=0)
        inpemb = TokenEmb(inpemb,
                          adapt_dims=(300, embdim),
                          rare_token_ids=sentence_encoder.vocab.rare_ids,
                          rare_id=1)
        # Load GloVe embeddings where possible into the inner embedding class.
        _, covered_word_ids = load_pretrained_embeddings(
            inpemb.emb,
            sentence_encoder.vocab.D,
            p="../../data/glove/glove300uncased")
        # Only tokens that GloVe does not cover keep the rare-token treatment.
        inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
        self.inp_emb = inpemb

        # qelos bidirectional LSTM encoder: encoder_dim // 2 per direction per layer.
        encoder_dim = hdim
        encoder = q.LSTMEncoder(embdim,
                                *([encoder_dim // 2] * numlayers),
                                bidir=True,
                                dropout_in=dropout)
        self.inp_enc = encoder

        decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(),
                                         embdim,
                                         padding_idx=0)
        decoder_emb = TokenEmb(decoder_emb,
                               rare_token_ids=query_encoder.vocab.rare_ids,
                               rare_id=1)
        self.out_emb = decoder_emb

        # Hand-built decoder: one LSTMCell per layer, wrapped in LSTMCellTransition.
        dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
        decoder_rnn = [torch.nn.LSTMCell(dec_rnn_in_dim, hdim)]
        for i in range(numlayers - 1):
            decoder_rnn.append(torch.nn.LSTMCell(hdim, hdim))
        decoder_rnn = LSTMCellTransition(*decoder_rnn, dropout=dropout)
        self.out_rnn = decoder_rnn

        decoder_out = BasicGenOutput(hdim + encoder_dim,
                                     vocab=query_encoder.vocab)
        # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
        self.out_lin = decoder_out

        self.att = q.Attention(q.MatMulDotAttComp(hdim, encoder_dim))

        # Single encoder-to-decoder bridge (not per-layer as in the other variants).
        self.enc_to_dec = torch.nn.Sequential(
            torch.nn.Linear(encoder_dim, hdim), torch.nn.Tanh())

        self.feedatt = feedatt
        self.nocopy = True

        self.store_attn = store_attn
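A hypothetical instantiation; besides the prebuilt SequenceEncoder objects, this variant also needs the GloVe file at the hard-coded relative path:

# Hypothetical usage; argument values are illustrative only.
model = BasicGenModel(embdim=200, hdim=256, numlayers=2, dropout=0.1,
                      sentence_encoder=sent_enc, query_encoder=query_enc,
                      feedatt=True)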