def create_model(hdim=128,
                 dropout=0.,
                 numlayers: int = 1,
                 numheads: int = 4,
                 sentence_encoder: SequenceEncoder = None,
                 query_encoder: SequenceEncoder = None,
                 feedatt=False,
                 maxtime=100):
    # input embedding table sized for the input vocabulary plus `maxtime` extra ids,
    # with special handling of rare tokens
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids() + maxtime, hdim, padding_idx=0)
    inpemb = TokenEmb(inpemb, rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)

    # transformer encoder whose word-embedding table is replaced by the one above
    tm_config = TransformerConfig(vocab_size=inpemb.emb.num_embeddings,
                                  num_attention_heads=numheads,
                                  num_hidden_layers=numlayers,
                                  hidden_size=hdim,
                                  intermediate_size=hdim * 4,
                                  hidden_dropout_prob=dropout)
    tm = Transformer(tm_config)
    tm.embeddings.word_embeddings = inpemb

    # output layer over the transformer states and the non-autoregressive model wrapper
    decoder_out = BasicGenOutput(hdim, query_encoder.vocab)
    model = NARTMModel(tm, decoder_out,
                       maxinplen=maxtime, maxoutlen=maxtime,
                       numinpids=sentence_encoder.vocab.number_of_ids())
    return model
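# Hedged sketch (not from the source): the embedding above reserves `maxtime`
# extra rows beyond the input vocabulary, and NARTMModel receives
# `numinpids=sentence_encoder.vocab.number_of_ids()`, which suggests that
# decoding-step ids are looked up in the same table, offset by the vocab size.
# The hypothetical helper below only illustrates that offsetting; the real
# model may construct its step ids differently.
def make_step_ids(batch_size: int, maxtime: int, numinpids: int) -> torch.Tensor:
    # step t of every example maps to embedding row numinpids + t
    steps = torch.arange(maxtime).unsqueeze(0).expand(batch_size, -1)
    return steps + numinpids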
def __init__(self, embdim, hdim, numlayers: int = 1, dropout=0., zdim=None,
             sentence_encoder: SequenceEncoder = None,
             query_encoder: SequenceEncoder = None,
             feedatt=False, store_attn=True, minkl=0.05, **kw):
    super(BasicGenModel, self).__init__(**kw)
    self.minkl = minkl
    self.embdim, self.hdim, self.numlayers, self.dropout = embdim, hdim, numlayers, dropout
    self.zdim = embdim if zdim is None else zdim

    # input-side embedding with special handling of rare tokens
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    inpemb = TokenEmb(inpemb, rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)
    # _, covered_word_ids = load_pretrained_embeddings(inpemb.emb, sentence_encoder.vocab.D,
    #                                                  p="../../data/glove/glove300uncased")  # load glove embeddings where possible into the inner embedding class
    # inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
    self.inp_emb = inpemb

    # bidirectional LSTM encoder over the input sentence (output dim = hdim)
    encoder_dim = hdim
    encoder = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
    # encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
    self.inp_enc = encoder

    # decoder: embedding + LSTM that consumes the token embedding, the latent z,
    # and optionally the attention summary
    self.out_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    dec_rnn_in_dim = embdim + self.zdim + (encoder_dim if feedatt else 0)
    decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
    self.out_rnn = decoder_rnn

    # posterior network: encodes the target sequence and produces mu/logvar for the latent z
    self.out_emb_vae = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    self.out_enc = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
    # self.out_mu = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
    # self.out_logvar = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
    self.out_mu = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))
    self.out_logvar = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))

    # output layer over the decoder state concatenated with the attention summary
    decoder_out = BasicGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
    # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
    self.out_lin = decoder_out

    self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

    # map final encoder states to initial decoder states, one projection per layer
    self.enc_to_dec = torch.nn.ModuleList([torch.nn.Sequential(
        torch.nn.Linear(encoder_dim, hdim),
        torch.nn.Tanh()
    ) for _ in range(numlayers)])

    self.feedatt = feedatt
    self.nocopy = True
    self.store_attn = store_attn

    self.reset_parameters()
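# Hedged sketch (an assumption, not taken from the source): the `out_mu` /
# `out_logvar` heads and the `minkl` floor above are the usual ingredients of a
# sequence VAE.  The standalone function below shows the standard
# reparameterisation and free-bits KL they would typically feed into; the
# actual forward pass of BasicGenModel may differ.
def reparameterize_with_free_bits(mu: torch.Tensor,
                                  logvar: torch.Tensor,
                                  minkl: float = 0.05):
    # sample z = mu + sigma * eps with eps ~ N(0, I)
    std = torch.exp(0.5 * logvar)
    z = mu + std * torch.randn_like(std)
    # KL(q(z|x) || N(0, I)) per dimension, clamped from below ("free bits")
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    kl = torch.clamp(kl, min=minkl)
    return z, kl.sum(-1)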
def __init__(self, embdim, hdim, numlayers: int = 1, dropout=0.,
             sentence_encoder: SequenceEncoder = None,
             query_encoder: SequenceEncoder = None,
             feedatt=False, store_attn=True,
             vib_init=False, vib_enc=False, **kw):
    super(BasicGenModel_VIB, self).__init__(**kw)

    # input-side embedding
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    # _, covered_word_ids = load_pretrained_embeddings(inpemb.emb, sentence_encoder.vocab.D,
    #                                                  p="../../data/glove/glove300uncased")  # load glove embeddings where possible into the inner embedding class
    # inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
    self.inp_emb = inpemb

    # bidirectional GRU encoder over the input sentence (output dim = 2 * hdim)
    encoder_dim = hdim * 2
    encoder = GRUEncoder(embdim, hdim, num_layers=numlayers, dropout=dropout, bidirectional=True)
    # encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
    self.inp_enc = encoder

    # decoder: embedding + GRU that optionally receives the attention summary as extra input
    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    self.out_emb = decoder_emb
    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = GRUTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
    self.out_rnn = decoder_rnn

    # output layer over the decoder state concatenated with the attention summary
    decoder_out = BasicGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
    # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
    self.out_lin = decoder_out

    self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

    # map final encoder states to initial decoder states, one projection per layer
    self.enc_to_dec = torch.nn.ModuleList([
        torch.nn.Sequential(torch.nn.Linear(encoder_dim, hdim), torch.nn.Tanh())
        for _ in range(numlayers)
    ])

    self.feedatt = feedatt
    self.nocopy = True
    self.store_attn = store_attn

    # VIBs: optional variational information bottlenecks on the decoder-init states
    # and/or on the encoded sequence
    self.vib_init = torch.nn.ModuleList(
        [VIB(encoder_dim) for _ in range(numlayers)]) if vib_init else None
    self.vib_enc = VIB_seq(encoder_dim) if vib_enc else None

    self.reset_parameters()
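# Hedged sketch (assumption): the `VIB` / `VIB_seq` modules used above are not
# shown here.  A minimal variational-information-bottleneck layer with a
# matching constructor signature would look roughly like the class below:
# project to mu/logvar, sample via the reparameterisation trick, and return the
# sample together with a KL penalty.  The real implementation in this codebase
# may differ.
class MinimalVIB(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.mu = torch.nn.Linear(dim, dim)
        self.logvar = torch.nn.Linear(dim, dim)

    def forward(self, x):
        mu, logvar = self.mu(x), self.logvar(x)
        # stochastic bottleneck sample
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        # KL(q(z|x) || N(0, I)), summed over the feature dimension
        kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(-1)
        return z, kl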
def __init__(self, embdim, hdim, numlayers: int = 1, dropout=0.,
             sentence_encoder: SequenceEncoder = None,
             query_encoder: SequenceEncoder = None,
             feedatt=False, store_attn=True, **kw):
    super(BasicGenModel, self).__init__(**kw)

    # input-side embedding: 300-d vectors adapted to embdim by TokenEmb,
    # with special handling of rare tokens
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), 300, padding_idx=0)
    inpemb = TokenEmb(inpemb, adapt_dims=(300, embdim),
                      rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)
    _, covered_word_ids = load_pretrained_embeddings(
        inpemb.emb, sentence_encoder.vocab.D,
        p="../../data/glove/glove300uncased"
    )   # load glove embeddings where possible into the inner embedding class
    # only rare tokens that are not covered by the pretrained embeddings are treated as rare
    inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
    self.inp_emb = inpemb

    # bidirectional LSTM encoder over the input sentence (output dim = hdim)
    encoder_dim = hdim
    encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
    self.inp_enc = encoder

    # decoder: embedding + stacked LSTM cells that optionally receive the attention summary as extra input
    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb, rare_token_ids=query_encoder.vocab.rare_ids, rare_id=1)
    self.out_emb = decoder_emb

    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = [torch.nn.LSTMCell(dec_rnn_in_dim, hdim)]
    for i in range(numlayers - 1):
        decoder_rnn.append(torch.nn.LSTMCell(hdim, hdim))
    decoder_rnn = LSTMCellTransition(*decoder_rnn, dropout=dropout)
    self.out_rnn = decoder_rnn

    # output layer over the decoder state concatenated with the attention summary
    decoder_out = BasicGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
    # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
    self.out_lin = decoder_out

    self.att = q.Attention(q.MatMulDotAttComp(hdim, encoder_dim))

    # map final encoder states to initial decoder states
    self.enc_to_dec = torch.nn.Sequential(
        torch.nn.Linear(encoder_dim, hdim),
        torch.nn.Tanh())

    self.feedatt = feedatt
    self.nocopy = True
    self.store_attn = store_attn
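# Hedged sketch (assumption): `enc_to_dec` above projects encoder states into
# the decoder's hidden size.  A typical use, not necessarily the one in this
# codebase, is to initialise every decoder layer from the projected final
# encoder state, as the hypothetical helper below illustrates.
def init_decoder_state(final_enc_state: torch.Tensor,
                       enc_to_dec: torch.nn.Module,
                       numlayers: int):
    # final_enc_state: (batch, encoder_dim) -> one (batch, hdim) initial state per decoder layer
    h0 = enc_to_dec(final_enc_state)
    return [h0 for _ in range(numlayers)]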