Code example #1
    def __init__(self, embdim, ctxdim, hdim, num_layers=1, dropout:float=0., redhdim=None, **kw):
        super(StackLSTMTransition, self).__init__(**kw)
        self.redhdim = redhdim if redhdim is not None else hdim
        indim = embdim + ctxdim
        self.embdim, self.ctxdim = embdim, ctxdim
        self.indim, self.hdim, self.numlayers, self.dropoutp = indim, hdim, num_layers, dropout
        self.main_lstm = LSTMTransition(indim, hdim, num_layers=num_layers, dropout=dropout)
        self.reduce_lstm = LSTMTransition(embdim, hdim, num_layers=num_layers, dropout=dropout)
        self.reduce_lin = torch.nn.Linear(hdim, embdim)
        self.dropout = torch.nn.Dropout(dropout)
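
For orientation, here is a minimal instantiation sketch of the constructor above. It assumes StackLSTMTransition (defined in full in code example #5) and its LSTMTransition/TransitionModel dependencies are importable from the surrounding codebase, so the import path in the comment is hypothetical.

# hypothetical import; the full class definition appears in code example #5
# from mymodel.transitions import StackLSTMTransition

cell = StackLSTMTransition(embdim=64, ctxdim=128, hdim=256, num_layers=1, dropout=0.1)
# the main LSTM consumes embdim + ctxdim = 192 input features per step,
# while the reducer LSTM consumes embdim = 64 input features per step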
Code example #2
    def __init__(self, xlmr, embdim, hdim, numlayers:int=1, dropout=0.,
                 sentence_encoder:SequenceEncoder=None,
                 query_encoder:SequenceEncoder=None,
                 feedatt=False, store_attn=True, **kw):
        super(BasicGenModel, self).__init__(**kw)

        self.xlmr = xlmr
        encoder_dim = self.xlmr.args.encoder_embed_dim

        decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
        decoder_emb = TokenEmb(decoder_emb, rare_token_ids=query_encoder.vocab.rare_ids, rare_id=1)
        self.out_emb = decoder_emb

        dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
        decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
        self.out_rnn = decoder_rnn

        decoder_out = PtrGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
        decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab, str_action_re=None)
        self.out_lin = decoder_out

        self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

        self.enc_to_dec = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Linear(encoder_dim, hdim),
            torch.nn.Tanh()
        ) for _ in range(numlayers)])

        self.feedatt = feedatt

        self.store_attn = store_attn

        self.reset_parameters()
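
A hedged construction sketch for this XLM-R-based variant. Here xlmr_model, sentence_enc and query_enc are assumed objects: a fairseq-style XLM-R model exposing args.encoder_embed_dim, and two SequenceEncoder instances whose vocab provides number_of_ids(), rare_ids and the copy-map interface used above.

model = BasicGenModel(xlmr_model,                  # pretrained XLM-R encoder (assumed)
                      embdim=300, hdim=512, numlayers=2, dropout=0.1,
                      sentence_encoder=sentence_enc,
                      query_encoder=query_enc,
                      feedatt=True,                # feed the attention summary back into the decoder RNN input
                      store_attn=True)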
Code example #3
    def __init__(self, embdim, hdim, numlayers:int=1, dropout=0., zdim=None,
                 sentence_encoder:SequenceEncoder=None,
                 query_encoder:SequenceEncoder=None,
                 feedatt=False, store_attn=True,
                 minkl=0.05, **kw):
        super(BasicGenModel, self).__init__(**kw)

        self.minkl = minkl

        self.embdim, self.hdim, self.numlayers, self.dropout = embdim, hdim, numlayers, dropout
        self.zdim = embdim if zdim is None else zdim

        inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
        inpemb = TokenEmb(inpemb, rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)
        # _, covered_word_ids = load_pretrained_embeddings(inpemb.emb, sentence_encoder.vocab.D,
        #                                                  p="../../data/glove/glove300uncased")  # load glove embeddings where possible into the inner embedding class
        # inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
        self.inp_emb = inpemb

        encoder_dim = hdim
        encoder = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
        # encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
        self.inp_enc = encoder

        self.out_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)

        dec_rnn_in_dim = embdim + self.zdim + (encoder_dim if feedatt else 0)
        decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
        self.out_rnn = decoder_rnn
        self.out_emb_vae = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
        self.out_enc = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
        # self.out_mu = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
        # self.out_logvar = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
        self.out_mu = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))
        self.out_logvar = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))

        decoder_out = BasicGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
        # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
        self.out_lin = decoder_out

        self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

        self.enc_to_dec = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Linear(encoder_dim, hdim),
            torch.nn.Tanh()
        ) for _ in range(numlayers)])

        self.feedatt = feedatt
        self.nocopy = True

        self.store_attn = store_attn

        self.reset_parameters()
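
A construction sketch for the latent-variable variant above, with illustrative values; sentence_enc and query_enc are again assumed SequenceEncoder instances providing the vocab interface the constructor reads.

model = BasicGenModel(embdim=300, hdim=400, numlayers=1, dropout=0.2,
                      zdim=128,                        # latent size; falls back to embdim when None
                      sentence_encoder=sentence_enc,
                      query_encoder=query_enc,
                      feedatt=True,
                      minkl=0.05)                      # minimum-KL hyperparameter stored on the model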
Code example #4
def create_model(embdim=100,
                 hdim=100,
                 dropout=0.,
                 numlayers: int = 1,
                 sentence_encoder: SequenceEncoder = None,
                 query_encoder: SequenceEncoder = None,
                 feedatt=False,
                 nocopy=False):
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(),
                                embdim,
                                padding_idx=0)
    inpemb = TokenEmb(inpemb,
                      rare_token_ids=sentence_encoder.vocab.rare_ids,
                      rare_id=1)
    encoder_dim = hdim
    encoder = LSTMEncoder(embdim,
                          hdim // 2,
                          numlayers,
                          bidirectional=True,
                          dropout=dropout)
    # encoder = PytorchSeq2SeqWrapper(
    #     torch.nn.LSTM(embdim, hdim, num_layers=numlayers, bidirectional=True, batch_first=True,
    #                   dropout=dropout))
    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(),
                                     embdim,
                                     padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb,
                           rare_token_ids=query_encoder.vocab.rare_ids,
                           rare_id=1)
    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
    # decoder_out = BasicGenOutput(hdim + encoder_dim, query_encoder.vocab)
    decoder_out = PtrGenOutput(hdim + encoder_dim,
                               out_vocab=query_encoder.vocab)
    decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
    attention = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim),
                            dropout=min(0.0, dropout))
    # attention = q.Attention(q.DotAttComp(), dropout=min(0.0, dropout))
    enctodec = torch.nn.ModuleList([
        torch.nn.Sequential(torch.nn.Linear(encoder_dim, hdim),
                            torch.nn.Tanh()) for _ in range(numlayers)
    ])
    model = BasicGenModel(inpemb,
                          encoder,
                          decoder_emb,
                          decoder_rnn,
                          decoder_out,
                          attention,
                          enc_to_dec=enctodec,
                          feedatt=feedatt,
                          nocopy=nocopy)
    return model
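
The factory above can be exercised roughly as follows; this is a sketch that assumes sentence_enc and query_enc are SequenceEncoder objects with fully built vocabularies (needed for the embeddings and the copy maps).

model = create_model(embdim=100, hdim=200, dropout=0.1, numlayers=2,
                     sentence_encoder=sentence_enc,    # assumed prebuilt SequenceEncoder
                     query_encoder=query_enc,
                     feedatt=True,
                     nocopy=False)                     # keep the pointer-generator copy path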
Code example #5
import numpy as np
import torch

# TransitionModel, LSTMTransition and State are project-level classes assumed to be
# importable from the surrounding codebase.
class StackLSTMTransition(TransitionModel):
    def __init__(self, embdim, ctxdim, hdim, num_layers=1, dropout:float=0., redhdim=None, **kw):
        super(StackLSTMTransition, self).__init__(**kw)
        self.redhdim = redhdim if redhdim is not None else hdim
        indim = embdim + ctxdim
        self.embdim, self.ctxdim = embdim, ctxdim
        self.indim, self.hdim, self.numlayers, self.dropoutp = indim, hdim, num_layers, dropout
        self.main_lstm = LSTMTransition(indim, hdim, num_layers=num_layers, dropout=dropout)
        self.reduce_lstm = LSTMTransition(embdim, hdim, num_layers=num_layers, dropout=dropout)
        self.reduce_lin = torch.nn.Linear(hdim, embdim)
        self.dropout = torch.nn.Dropout(dropout)

    def get_init_state(self, batsize, device=torch.device("cpu")):
        main_state = self.main_lstm.get_init_state(batsize, device)
        reduce_state = self.reduce_lstm.get_init_state(batsize, device)
        state = State()
        state.h = main_state.h
        state.c = main_state.c
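        # one Python list of (main_state, reduce_state) frames per batch element, held in an object array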
        state.stack = np.array(range(batsize), dtype="object")
        for i in range(batsize):
            state.stack[i] = []
            state.stack[i].append((main_state[i:i+1], reduce_state[i:i+1]))
        return state

    def forward(self, x, ctx, stack_actions, state):
        """
        :param x:               (batsize, embdim)
        :param ctx:             (batsize, ctxdim)
        :param stack_actions:   (batsize,) -1 for pop (")" token), 0 for nothing, +1 for going deeper ("(" token)
        :param state:
        :return:
        """
        # gather main lstm states and do stack management
        # if action is push (+1),
        #       push current state onto stack, with zero state for reducer
        #       do main update with current state, do reducer update with zero state and input embedding
        # if action is pop (-1),
        #       replace embedding with reducer state
        #       set main state to last main stack state
        #       pop stack frame
        #       do main update
        # if action is zero (0), replace reducer state with updated reducer state

        # if stack has only one element, prevent pop actions
        stacklens = [len(stack_i) for stack_i in state.stack]
        stacklens = torch.tensor(stacklens, device=stack_actions.device)
        stackmask = (stacklens <= 1).long() - 1
        stack_actions = torch.max(stack_actions, stackmask)

        # input vector is embedding if action is zero or push, else it is the reducer state
        if torch.any(stack_actions == -1).item():
            red_states = [stack_i[-1][1] for stack_i in state.stack]
            red_states = red_states[0].merge(red_states)
            red_encodings = self.reduce_lin(red_states.h[:, -1])
            mask = (stack_actions == -1).float()[:, None]
            _x = x * (1 - mask) + red_encodings * mask
        else:
            _x = x

        # main state stays the same if action is zero or push, if pop, it's set to last main state on stack
        if torch.any(stack_actions == -1).item():
            main_states = [stack_i[-1][0] for stack_i in state.stack]
            main_states = main_states[0].merge(main_states)
            mask = (stack_actions == -1).float()[:, None, None]
            state.h = state.h * (1 - mask) + main_states.h * mask
            state.c = state.c * (1 - mask) + main_states.c * mask

        # stack management
        for i, action in enumerate(list(stack_actions.cpu().numpy())):
            if action == 1: # push current main state onto stack, with zero state for reducer
                main_state = self.main_lstm.get_init_state(1, device=state.h.device)
                main_state.h = state.h[i:i+1]
                main_state.c = state.c[i:i+1]
                reduce_state = self.reduce_lstm.get_init_state(1, device=state.h.device)
                state.stack[i].append((main_state, reduce_state))
            elif action == -1:  # pop stack frame
                state.stack[i].pop(-1)
            else:
                pass

        # reducer state is always the last reducer state on the stack
        reduce_states = [stack_i[-1][1] for stack_i in state.stack]
        reduce_states = reduce_states[0].merge(reduce_states)

        # main update
        y, main_state = self.main_lstm(torch.cat([_x, ctx], -1), state)
        state.h, state.c = main_state.h, main_state.c

        # reducer update
        reducer_y, reducer_state = self.reduce_lstm(_x, reduce_states)
        for i in range(len(state)):
            state.stack[i][-1] = (state.stack[i][-1][0], reducer_state[i])

        return y, state
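
A toy forward-step sketch for the class above. Shapes follow the forward docstring; it assumes TransitionModel is a torch.nn.Module subclass (so the cell is callable) and that LSTMTransition and State behave as in the surrounding codebase.

import torch

batsize, embdim, ctxdim, hdim = 4, 64, 128, 256
cell = StackLSTMTransition(embdim, ctxdim, hdim, num_layers=1, dropout=0.)
state = cell.get_init_state(batsize)

x = torch.randn(batsize, embdim)            # current token embeddings
ctx = torch.randn(batsize, ctxdim)          # e.g. attention/context vectors
# +1 = "(" pushes a stack frame, 0 = ordinary token, -1 = ")" pops a frame
# (pops are clamped away by the stackmask logic when only one frame is left)
stack_actions = torch.tensor([1, 0, -1, 0])
y, state = cell(x, ctx, stack_actions, state)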