import numpy as np
import torch

# `q` (q.Attention, q.SimpleFwdAttComp, ...) is assumed to be the qelos library;
# TokenEmb, LSTMEncoder, LSTMTransition, SequenceEncoder, PtrGenOutput, BasicGenOutput,
# TransitionModel and State are assumed to come from the surrounding project code.
# BasicGenModel.__init__ variant that takes a pretrained xlmr model as the input encoder:
def __init__(self, xlmr, embdim, hdim, numlayers:int=1, dropout=0.,
             sentence_encoder:SequenceEncoder=None, query_encoder:SequenceEncoder=None,
             feedatt=False, store_attn=True, **kw):
    super(BasicGenModel, self).__init__(**kw)
    self.xlmr = xlmr
    encoder_dim = self.xlmr.args.encoder_embed_dim

    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb, rare_token_ids=query_encoder.vocab.rare_ids, rare_id=1)
    self.out_emb = decoder_emb

    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
    self.out_rnn = decoder_rnn

    decoder_out = PtrGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
    decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab, str_action_re=None)
    self.out_lin = decoder_out

    self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

    # one Linear+Tanh mapping per decoder layer, to turn encoder states into decoder init states
    self.enc_to_dec = torch.nn.ModuleList([torch.nn.Sequential(
        torch.nn.Linear(encoder_dim, hdim),
        torch.nn.Tanh()
    ) for _ in range(numlayers)])

    self.feedatt = feedatt
    self.store_attn = store_attn

    self.reset_parameters()
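# A minimal sketch (not part of the original source) of how the per-layer `enc_to_dec`
# mappings built above are typically applied: each decoder layer gets its initial hidden
# state by projecting the encoder's final summary vector. `final_enc` and
# `_init_decoder_state_sketch` are hypothetical names used only for illustration.
def _init_decoder_state_sketch(self, final_enc):
    # final_enc: (batsize, encoder_dim) final summary of the encoded input
    # one Linear + Tanh projection per decoder layer, as constructed in __init__ above
    return [m(final_enc) for m in self.enc_to_dec]   # list of (batsize, hdim) init states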
# BasicGenModel.__init__ variant with a VAE-style latent z over the output sequence
# (out_mu / out_logvar heads, per-dimension KL floored at minkl):
def __init__(self, embdim, hdim, numlayers:int=1, dropout=0., zdim=None,
             sentence_encoder:SequenceEncoder=None, query_encoder:SequenceEncoder=None,
             feedatt=False, store_attn=True, minkl=0.05, **kw):
    super(BasicGenModel, self).__init__(**kw)
    self.minkl = minkl
    self.embdim, self.hdim, self.numlayers, self.dropout = embdim, hdim, numlayers, dropout
    self.zdim = embdim if zdim is None else zdim

    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    inpemb = TokenEmb(inpemb, rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)
    # load glove embeddings where possible into the inner embedding class:
    # _, covered_word_ids = load_pretrained_embeddings(inpemb.emb, sentence_encoder.vocab.D,
    #                                                  p="../../data/glove/glove300uncased")
    # inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
    self.inp_emb = inpemb

    encoder_dim = hdim
    encoder = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
    # encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
    self.inp_enc = encoder

    self.out_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)

    dec_rnn_in_dim = embdim + self.zdim + (encoder_dim if feedatt else 0)
    decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
    self.out_rnn = decoder_rnn

    # separate embedding and encoder for the VAE posterior over the output sequence
    self.out_emb_vae = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    self.out_enc = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
    # self.out_mu = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
    # self.out_logvar = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
    self.out_mu = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))
    self.out_logvar = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))

    decoder_out = BasicGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
    # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
    self.out_lin = decoder_out

    self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

    self.enc_to_dec = torch.nn.ModuleList([torch.nn.Sequential(
        torch.nn.Linear(encoder_dim, hdim),
        torch.nn.Tanh()
    ) for _ in range(numlayers)])

    self.feedatt = feedatt
    self.nocopy = True
    self.store_attn = store_attn

    self.reset_parameters()
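# A sketch (assumption, not from the original source) of how `out_mu`, `out_logvar` and
# `minkl` would typically combine: sample the latent z with the reparameterization trick
# and floor the per-dimension KL at `minkl` (free-bits style) to counter posterior
# collapse. `out_summary` is a hypothetical (batsize, hdim) summary from `self.out_enc`;
# `_sample_z_sketch` is a hypothetical name.
def _sample_z_sketch(self, out_summary):
    mu = self.out_mu(out_summary)                                # (batsize, zdim)
    logvar = self.out_logvar(out_summary)                        # (batsize, zdim)
    z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)     # reparameterization trick
    kl = -0.5 * (1 + logvar - mu ** 2 - logvar.exp())           # per-dim KL to N(0, I)
    kl = torch.clamp(kl, min=self.minkl)                        # free-bits floor per dimension
    return z, kl.sum(-1)                                        # z: (batsize, zdim), kl: (batsize,)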
def create_model(embdim=100, hdim=100, dropout=0., numlayers: int = 1,
                 sentence_encoder: SequenceEncoder = None, query_encoder: SequenceEncoder = None,
                 feedatt=False, nocopy=False):
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    inpemb = TokenEmb(inpemb, rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)

    encoder_dim = hdim
    encoder = LSTMEncoder(embdim, hdim // 2, numlayers, bidirectional=True, dropout=dropout)
    # encoder = PytorchSeq2SeqWrapper(
    #     torch.nn.LSTM(embdim, hdim, num_layers=numlayers, bidirectional=True, batch_first=True,
    #                   dropout=dropout))

    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb, rare_token_ids=query_encoder.vocab.rare_ids, rare_id=1)

    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    # pass numlayers so the decoder depth matches the per-layer enctodec mappings below
    decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)

    # decoder_out = BasicGenOutput(hdim + encoder_dim, query_encoder.vocab)
    decoder_out = PtrGenOutput(hdim + encoder_dim, out_vocab=query_encoder.vocab)
    decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)

    # min(0.0, dropout) keeps attention dropout at zero here
    attention = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.0, dropout))
    # attention = q.Attention(q.DotAttComp(), dropout=min(0.0, dropout))

    enctodec = torch.nn.ModuleList([
        torch.nn.Sequential(torch.nn.Linear(encoder_dim, hdim), torch.nn.Tanh())
        for _ in range(numlayers)
    ])

    model = BasicGenModel(inpemb, encoder, decoder_emb, decoder_rnn, decoder_out, attention,
                          enc_to_dec=enctodec, feedatt=feedatt, nocopy=nocopy)
    return model
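# Usage sketch, assuming `sentence_encoder` and `query_encoder` are already-built
# SequenceEncoder instances with populated vocabs (their construction is not shown here):
#
#   model = create_model(embdim=100, hdim=100, dropout=0.1, numlayers=1,
#                        sentence_encoder=sentence_encoder, query_encoder=query_encoder,
#                        feedatt=True, nocopy=False)
#
# With feedatt=True the decoder LSTM input is embdim + encoder_dim wide, since the
# previous step's attention summary is concatenated to the token embedding.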
class StackLSTMTransition(TransitionModel):
    def __init__(self, embdim, ctxdim, hdim, num_layers=1, dropout:float=0., redhdim=None, **kw):
        super(StackLSTMTransition, self).__init__(**kw)
        self.redhdim = redhdim if redhdim is not None else hdim
        indim = embdim + ctxdim
        self.embdim, self.ctxdim = embdim, ctxdim
        self.indim, self.hdim, self.numlayers, self.dropoutp = indim, hdim, num_layers, dropout
        self.main_lstm = LSTMTransition(indim, hdim, num_layers=num_layers, dropout=dropout)
        self.reduce_lstm = LSTMTransition(embdim, hdim, num_layers=num_layers, dropout=dropout)
        self.reduce_lin = torch.nn.Linear(hdim, embdim)
        self.dropout = torch.nn.Dropout(dropout)

    def get_init_state(self, batsize, device=torch.device("cpu")):
        main_state = self.main_lstm.get_init_state(batsize, device)
        reduce_state = self.reduce_lstm.get_init_state(batsize, device)
        state = State()
        state.h = main_state.h
        state.c = main_state.c
        # one stack per example; every frame is a (main_state, reduce_state) pair
        state.stack = np.array(range(batsize), dtype="object")
        for i in range(batsize):
            state.stack[i] = []
            state.stack[i].append((main_state[i:i+1], reduce_state[i:i+1]))
        return state

    def forward(self, x, ctx, stack_actions, state):
        """
        :param x:               (batsize, embdim)
        :param ctx:             (batsize, ctxdim)
        :param stack_actions:   (batsize,) -1 for pop (")" token), 0 for nothing, +1 for going deeper ("(" token)
        :param state:
        :return:
        """
        # gather main lstm states and do stack management
        # if action is push (+1),
        #     push current state onto stack, with zero state for reducer
        #     do main update with current state, do reducer update with zero state and input embedding
        # if action is pop (-1),
        #     replace embedding with reducer state
        #     set main state to last main stack state
        #     pop stack frame
        #     do main update
        # if action is zero (0), replace reducer state with updated reducer state

        # if stack has only one element, prevent pop actions
        stacklens = [len(stack_i) for stack_i in state.stack]
        stacklens = torch.tensor(stacklens, device=stack_actions.device)
        stackmask = (stacklens <= 1).long() - 1
        stack_actions = torch.max(stack_actions, stackmask)

        # input vector is embedding if action is zero or push, else it is the reducer state
        if torch.any(stack_actions == -1).item():
            red_states = [stack_i[-1][1] for stack_i in state.stack]
            red_states = red_states[0].merge(red_states)
            red_encodings = self.reduce_lin(red_states.h[:, -1])
            mask = (stack_actions == -1).float()[:, None]
            _x = x * (1 - mask) + red_encodings * mask
        else:
            _x = x

        # main state stays the same if action is zero or push; if pop, it's set to last main state on stack
        if torch.any(stack_actions == -1).item():
            main_states = [stack_i[-1][0] for stack_i in state.stack]
            main_states = main_states[0].merge(main_states)
            mask = (stack_actions == -1).float()[:, None, None]
            state.h = state.h * (1 - mask) + main_states.h * mask
            state.c = state.c * (1 - mask) + main_states.c * mask

        # stack management
        for i, action in enumerate(list(stack_actions.cpu().numpy())):
            if action == 1:
                # push current main state onto stack, with zero state for reducer
                main_state = self.main_lstm.get_init_state(1, device=state.h.device)
                main_state.h = state.h[i:i+1]
                main_state.c = state.c[i:i+1]
                reduce_state = self.reduce_lstm.get_init_state(1, device=state.h.device)
                state.stack[i].append((main_state, reduce_state))
            elif action == -1:
                # pop stack frame
                state.stack[i].pop(-1)
            else:
                pass

        # reducer state is always the last reducer state on the stack
        reduce_states = [stack_i[-1][1] for stack_i in state.stack]
        reduce_states = reduce_states[0].merge(reduce_states)

        # main update
        y, main_state = self.main_lstm(torch.cat([_x, ctx], -1), state)
        state.h, state.c = main_state.h, main_state.c

        # reducer update
        reducer_y, reducer_state = self.reduce_lstm(_x, reduce_states)
        for i in range(len(state)):
            state.stack[i][-1] = (state.stack[i][-1][0], reducer_state[i])

        return y, state
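# Sketch (assumption, not part of the original source) of how `stack_actions` can be
# derived from token ids before calling the transition: +1 when the token is "(",
# -1 when it is ")", 0 otherwise. `open_id` and `close_id` are the hypothetical vocab
# ids of those two tokens.
def stack_actions_from_tokens(token_ids, open_id, close_id):
    # token_ids: (batsize,) long tensor of current decoder tokens
    return (token_ids == open_id).long() - (token_ids == close_id).long()

# Hypothetical single step:
#   transition = StackLSTMTransition(embdim=64, ctxdim=32, hdim=128)
#   state = transition.get_init_state(batsize=4)
#   actions = stack_actions_from_tokens(token_ids, open_id, close_id)
#   y, state = transition(x, ctx, actions, state)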