Example #1
    def forward(self, inputs, hidden=None):
        if isinstance(inputs, PackedSequence):
            bsizes = inputs.batch_sizes
            max_batch = int(bsizes[0])
            emb = PackedSequence(self.embedding_dropout(
                self.embedder(inputs.data)), bsizes)
            # Get padding mask
            time_dim = 1 if self.batch_first else 0
            range_batch = torch.arange(0, max_batch,
                                       dtype=bsizes.dtype,
                                       device=bsizes.device)
            range_batch = range_batch.unsqueeze(time_dim)
            bsizes = bsizes.unsqueeze(1 - time_dim)
            padding_mask = (bsizes - range_batch).le(0)
        else:
            padding_mask = inputs.eq(PAD)
            emb = self.embedding_dropout(self.embedder(inputs))
        outputs, hidden_t = self.rnn(emb, hidden)
        if isinstance(inputs, PackedSequence):
            outputs = unpack(outputs)[0]
        outputs = self.dropout(outputs)
        if hasattr(self, 'context_transform'):
            context = self.context_transform(outputs)
        else:
            context = None

        if hasattr(self, 'hidden_transform'):
            hidden_t = self.hidden_transform(hidden_t)

        state = State(outputs=outputs, hidden=hidden_t, context=context,
                      mask=padding_mask, batch_first=self.batch_first)
        return state
Example #2
 def forward(self, inputs, lengths, hidden=None):
     lens, indices = torch.sort(inputs.data.new(lengths).long(), 0, True)
     inputs = inputs[indices] if self.batch_first else inputs[:, indices] 
     outputs, (h, c) = self.rnn(pack(inputs, lens.tolist(), 
         batch_first=self.batch_first), hidden)
     outputs = unpack(outputs, batch_first=self.batch_first)[0]
     _, _indices = torch.sort(indices, 0)
     outputs = outputs[_indices] if self.batch_first else outputs[:, _indices]
     h, c = h[:, _indices, :], c[:, _indices, :]
     return outputs, (h, c)
Example #3
 def forward(self, enc_input, hidden=None):
     if isinstance(enc_input, tuple):
         # Lengths data is wrapped inside a Variable.
         lengths = enc_input[1].data.view(-1).tolist()
         emb = pack(self.embedding(enc_input[0]), lengths)
     else:
         emb = self.embedding(enc_input)
     outputs, hidden_t = self.rnn(emb, hidden)
     if isinstance(enc_input, tuple):
         outputs = unpack(outputs)[0]
     return outputs, hidden_t
Example #4
    def forward(self, inputs, state):
        context, hidden = state.context, state.hidden
        if isinstance(inputs, PackedSequence):
            emb = PackedSequence(self.embedding_dropout(
                self.embedder(inputs.data)), inputs.batch_sizes)
        else:
            emb = self.embedding_dropout(self.embedder(inputs))
        x, hidden_t = self.rnn(emb, hidden)
        if isinstance(inputs, PackedSequence):
            x = unpack(x)[0]
        x = self.dropout(x)

        x = self.classifier(x)
        return x, State(hidden=hidden_t, context=context, batch_first=self.batch_first)
Example #5
    def forward(self, src, lengths=None):
        """See :func:`onmt.encoders.encoder.EncoderBase.forward()`"""
        batch_size, _, nfft, t = src.size()
        src = src.transpose(0, 1).transpose(0, 3).contiguous() \
                 .view(t, batch_size, nfft)
        orig_lengths = lengths
        lengths = lengths.view(-1).tolist()

        for l in range(self.enc_layers):
            rnn = getattr(self, 'rnn_%d' % l)
            pool = getattr(self, 'pool_%d' % l)
            batchnorm = getattr(self, 'batchnorm_%d' % l)
            stride = self.enc_pooling[l]
            packed_emb = pack(src, lengths)
            memory_bank, tmp = rnn(packed_emb)
            memory_bank = unpack(memory_bank)[0]
            t, _, _ = memory_bank.size()
            memory_bank = memory_bank.transpose(0, 2)
            memory_bank = pool(memory_bank)
            lengths = [int(math.floor((length - stride) / stride + 1))
                       for length in lengths]
            memory_bank = memory_bank.transpose(0, 2)
            src = memory_bank
            t, _, num_feat = src.size()
            src = batchnorm(src.contiguous().view(-1, num_feat))
            src = src.view(t, -1, num_feat)
            if self.dropout and l + 1 != self.enc_layers:
                src = self.dropout(src)

        memory_bank = memory_bank.contiguous().view(-1, memory_bank.size(2))
        memory_bank = self.W(memory_bank).view(-1, batch_size,
                                               self.dec_rnn_size)

        state = memory_bank.new_full((self.dec_layers * self.num_directions,
                                      batch_size, self.dec_rnn_size_real), 0)
        if self.rnn_type == 'LSTM':
            # The encoder hidden is  (layers*directions) x batch x dim.
            encoder_final = (state, state)
        else:
            encoder_final = state
        return encoder_final, memory_bank, orig_lengths.new_tensor(lengths)
Example #6
    def forward(self, src, lengths=None, encoder_state=None):
        "See :obj:`EncoderBase.forward()`"
        self._check_args(src, lengths, encoder_state)

        emb = self.embeddings(src)
        s_len, batch, emb_dim = emb.size()

        packed_emb = emb
        if lengths is not None and not self.no_pack_padded_seq:
            # Lengths data is wrapped inside a Variable.
            lengths = lengths.view(-1).tolist()
            packed_emb = pack(emb, lengths)

        memory_bank, encoder_final = self.rnn(packed_emb, encoder_state)

        if lengths is not None and not self.no_pack_padded_seq:
            memory_bank = unpack(memory_bank)[0]

        if self.use_bridge:
            encoder_final = self._bridge(encoder_final)
        return encoder_final, memory_bank
Example #7
    def forward(self, lang, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings[lang](src)
        # s_len, batch, emb_dim = emb.size()

        packed_emb = emb
        if lengths is not None and not self.no_pack_padded_seq:
            # Lengths data is wrapped inside a Tensor.
            lengths_list = lengths.view(-1).tolist()
            packed_emb = pack(emb, lengths_list)

        memory_bank, encoder_final = self.rnn(packed_emb)

        if lengths is not None and not self.no_pack_padded_seq:
            memory_bank = unpack(memory_bank)[0]

        if self.use_bridge:
            encoder_final = self._bridge(encoder_final)
        return encoder_final, memory_bank, lengths
Example #8
    def forward(self, src, lengths=None, encoder_state=None):
        "See :obj:`EncoderBase.forward()`"
        self._check_args(src, lengths, encoder_state)

        emb = self.embeddings(src)
        s_len, batch, emb_dim = emb.size()

        packed_emb = emb
        if lengths is not None and not self.no_pack_padded_seq:
            # Lengths data is wrapped inside a Variable.
            lengths = lengths.view(-1).tolist()
            packed_emb = pack(emb, lengths)

        memory_bank, encoder_final = self.rnn(packed_emb, encoder_state)

        if lengths is not None and not self.no_pack_padded_seq:
            memory_bank = unpack(memory_bank)[0]

        if self.use_bridge:
            encoder_final = self._bridge(encoder_final)
        return encoder_final, memory_bank
Example #9
    def forward(self, input, lengths=None, hidden=None, FLAG=True):
        "See :obj:`EncoderBase.forward()`"
        self._check_args(input, lengths, hidden)

        if FLAG:
            emb = self.embeddings(input)
            s_len, batch, emb_dim = emb.size()
        else:
            emb = input
        packed_emb = emb
        if lengths is not None and not self.no_pack_padded_seq:
            # Lengths data is wrapped inside a Variable.
            lengths = lengths.view(-1).tolist()
            packed_emb = pack(emb, lengths)

        outputs, hidden_t = self.rnn(packed_emb, hidden)

        if lengths is not None and not self.no_pack_padded_seq:
            outputs = unpack(outputs)[0]

        return hidden_t, outputs
Example #10
    def forward(self, emb):
        '''
        Encodes the source sentence.
        - No need to handle each time-step separately: the whole sequence is passed in and returned at once.
        - Sequences containing PAD are processed with pack/unpack so padding can be handled efficiently in parallel.
        '''
        if isinstance(emb, tuple):
            # x = (batch_size, length, word_vec_size)
            # lengths = (batch_size) - length of each sentence (PAD excluded)
            x, lengths = emb
            x = pack(x, lengths.tolist(), batch_first=True)
        else:
            x = emb

        # y = (batch_size, length, hidden_size)
        # h[0] = (num_layers * 2, batch_size, hidden_size / 2)
        y, h = self.rnn(x)
        if isinstance(emb, tuple):
            y, _ = unpack(y, batch_first=True)

        return y, h
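As a reference for the pattern in the example above, here is a minimal self-contained sketch of pack/unpack around a batch-first LSTM. It assumes the usual aliases pack = pack_padded_sequence and unpack = pad_packed_sequence used throughout these examples; the toy shapes are illustrative only.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

# Toy batch: 3 sequences of word vectors, padded to length 5.
batch_size, max_len, word_vec_size, hidden_size = 3, 5, 8, 6
x = torch.randn(batch_size, max_len, word_vec_size)
lengths = torch.tensor([5, 3, 2])  # true lengths, PAD excluded

rnn = nn.LSTM(word_vec_size, hidden_size // 2,
              batch_first=True, bidirectional=True)

packed = pack(x, lengths.tolist(), batch_first=True)  # PAD positions are skipped
y_packed, h = rnn(packed)
y, out_lengths = unpack(y_packed, batch_first=True)   # back to a padded tensor
# y: (batch_size, max_len, hidden_size); positions beyond each length are zero.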
Example #11
    def _get_lstm_features(self, token_ids, lengths):
        # |token_ids| = [batch_size, token_length]
        # |lengths| = [batch_size]

        embeds = self.word_embeds(token_ids)
        # |embeds| = [batch_size, token_length, hidden_dim]
        packed_embeds = pack(embeds,
                             lengths=lengths.tolist(),
                             batch_first=True,
                             enforce_sorted=False)
        # |packed_embeds| = PackedSequence over the embeddings above

        # Apply the RNN and get the hidden states of each word
        last_hiddens, _ = self.rnn(packed_embeds)

        # Unpack the output of the RNN
        last_hiddens, _ = unpack(last_hiddens, batch_first=True)
        # |last_hiddens| = [batch_size, max(token_length), hidden_size]
        lstm_feats = self.hidden2tag(self.tanh(last_hiddens))

        return lstm_feats
Example #12
    def forward(self, input, lengths, cnn=True):
        embs = pack(self.embedding(input), lengths)
        outputs, state = self.rnn(embs)
        outputs = unpack(outputs)[0]
        if not self.config.bidirec:
            return outputs, state
        else:
            outputs = (outputs[:, :, :self.hidden_size]
                       + outputs[:, :, self.hidden_size:])
            state = (state[0][1::2], state[1][1::2])
            o_ = outputs
            conv = None  # make sure conv is defined when cnn is False
            if cnn:
                outputs = outputs.transpose(0, 1).transpose(1, 2)
                outputs = self.selu1(self.conv1(outputs))
                outputs = self.selu2(self.conv2(outputs))
                outputs = self.selu3(self.conv3(outputs))
                conv = outputs.transpose(1, 2).transpose(0, 1)
                # outputs = self.sigmoid(outputs) * o_
                outputs = o_

            return outputs, conv, state
Example #13
    def forward(self, inputs, lengths, hidden=None):
        """A pretrained MT-LSTM (McCann et. al. 2017).
        This LSTM was trained with 300d 840B GloVe on the WMT 2017 machine translation dataset.

        Arguments:
            inputs (Tensor): If MTLSTM handles embedding, a Long Tensor of size (batch_size, timesteps).
                             Otherwise, a Float Tensor of size (batch_size, timesteps, features).
            lengths (Long Tensor): (batch_size) lengths of each sequence, for handling padding
            hidden (Float Tensor): initial hidden state of the LSTM
        """
        if self.embed:
            inputs = self.vectors(inputs)
        lens, indices = torch.sort(lengths, 0, True)
        outputs, hidden_t = self.rnn(
            pack(inputs[indices], lens.tolist(), batch_first=True), hidden)
        outputs = unpack(outputs, batch_first=True)[0]
        _, _indices = torch.sort(indices, 0)
        outputs = outputs[_indices]
        if self.residual_embeddings:
            outputs = torch.cat([inputs, outputs], 2)
        return outputs
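The sort/pack/unpack/unsort dance above is independent of the MT-LSTM itself. A minimal sketch of just that pattern follows; run_sorted and its toy shapes are illustrative names, not part of the original code.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

def run_sorted(rnn, inputs, lengths):
    # Sort the batch by decreasing length, as pack_padded_sequence expects.
    lens, indices = torch.sort(lengths, 0, descending=True)
    outputs, _ = rnn(pack(inputs[indices], lens.tolist(), batch_first=True))
    outputs = unpack(outputs, batch_first=True)[0]
    # Undo the sort so the outputs line up with the original batch order.
    _, rev = torch.sort(indices, 0)
    return outputs[rev]

rnn = nn.LSTM(4, 7, batch_first=True)
out = run_sorted(rnn, torch.randn(3, 6, 4), torch.tensor([4, 6, 2]))
# out: (3, 6, 7), padded positions zeroed, rows in the original batch order.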
Example #14
    def forward(self, tokens_emb, length):
        batch_size = tokens_emb.size(0)

        tokens_emb = pack(tokens_emb, length, batch_first=True)

        outputs, states_t = self.rnn(tokens_emb)
        reps, _ = unpack(outputs, batch_first=True)
        # print 'reps', reps

        size = reps.size()
        compressed_reps = reps.contiguous().view(
            -1, size[2])  # (batch_size x seq_len) * mem_size

        hbar = self.tanh(
            self.ws1(compressed_reps))  # (batch_size x seq_len) * attn_size
        alphas = self.ws2(hbar).view(size[0], size[1],
                                     -1)  # batch_size * seq_len * hops
        alphas = torch.transpose(alphas, 1,
                                 2).contiguous()  # batch_size * hops * seq_len

        mask = self.get_mask(length)
        # print 'mask', mask
        multi_mask = [mask.unsqueeze(1) for i in range(self.hops)]
        multi_mask = torch.cat(multi_mask, 1)
        # print 'multi_mask', multi_mask

        penalized_alphas = alphas + -1e7 * (1 - multi_mask)
        alphas = self.sm(penalized_alphas.view(
            -1, size[1]))  # (batch_size x hops) * seq_len
        alphas = alphas.view(size[0], self.hops,
                             size[1])  # batch_size * hops * seq_len
        # print 'alphas', alphas

        reps = torch.bmm(alphas, reps)  # batch_size * hops * hidden_size
        # here we use mean pooling of all hops
        rep = reps.mean(1)
        assert len(rep.size()) == 2

        # batch_size * classes, batch_size * hops * seq_len
        return self.dropout(rep), alphas
Example #15
    def forward(self,
                enc_outs,
                target,
                source_rep_mask=None,
                target_length=None):
        lstm_in, _, lstm_states = self._prepare(enc_outs, target)

        # including EOE token
        # target_length += 1

        ### LSTM
        target_length, indices = torch.sort(target_length, 0, True)

        lstm_in_sorted = self.reorder_sequence(lstm_in, indices)
        lstm_in_packed = pack(lstm_in_sorted,
                              target_length.tolist(),
                              batch_first=True)

        # lstm_out: batch_size x num_target x hidden_units
        lstm_out_packed, _ = self.lstm(lstm_in_packed, lstm_states)
        lstm_out, _ = unpack(lstm_out_packed, batch_first=True)

        _, reverse_indices = torch.sort(indices, 0)
        lstm_out = self.reorder_sequence(lstm_out, reverse_indices)

        # lstm_out, _ = self.lstm(lstm_in, lstm_states)

        ### glimpse attention
        # glimpse: batch_size x num_target x hidden_units
        glimpse, _ = self.glimpse_attn(lstm_out,
                                       enc_outs,
                                       rep_mask=source_rep_mask)

        ### point attention
        # probs: batch_size x num_target x num_sentence
        _, logits = self.point_attn(glimpse,
                                    enc_outs,
                                    rep_mask=source_rep_mask)

        return logits
Example #16
    def posteriorIndictedEmb(self,embs,posterior):
        #real alignment is sent in as list of index
        #variational relaxed posterior is sent in as MyPackedSequence

        #out   (batch x amr_len) x src_len x (dim+1)
        embs,src_len = unpack(embs)

        if isinstance(posterior,MyPackedSequence):
       #     print ("posterior is packed")
            posterior = myunpack(*posterior)
            embs = embs.transpose(0,1)
            out = []
            lengths = []
            amr_len = [len(p) for p in posterior]
            for i,emb in enumerate(embs):
                expanded_emb = emb.unsqueeze(0).expand([amr_len[i]] + list(emb.size()))  # amr_len x src_len x dim
                indicator = posterior[i].unsqueeze(2)  # amr_len x src_len x 1
                out.append(torch.cat([expanded_emb,indicator],2))  # amr_len x src_len x (dim+1)
                lengths = lengths + [src_len[i]]*amr_len[i]
            data = torch.cat(out,dim=0)

            return pack(data, lengths, batch_first=True), amr_len

        elif isinstance(posterior,list):
            embs = embs.transpose(0,1)
            src_l = embs.size(1)
            amr_len = [len(i) for i in posterior]
            out = []
            lengths = []
            for i,emb in enumerate(embs):
                amr_l = len(posterior[i])
                expanded_emb = emb.unsqueeze(0).expand([amr_l] + list(emb.size()))  # amr_len x src_len x dim
                indicator = emb.data.new(amr_l,src_l).zero_()
                indicator.scatter_(1, posterior[i].data.unsqueeze(1), 1.0) # amr_len x src_len x 1
                indicator = Variable(indicator.unsqueeze(2))
                out.append(torch.cat([expanded_emb,indicator],2))  # amr_len x src_len x (dim+1)
                lengths = lengths + [src_len[i]]*amr_l
            data = torch.cat(out,dim=0)

            return pack(data,lengths,batch_first=True),amr_len
Example #17
    def enc(self, fbank, len=None):
        res = self.enc_fnn_lyr(fbank)
        for rnn, skip, drop in zip(self.enc_rnn_lyr, cf.enc_rnn_skip,
                                   cf.enc_rnn_drop):
            if len is not None:
                res = pack(res, len, batch_first=True)
            res, (h, c) = rnn(res)
            if len is not None:
                res, _ = unpack(res, batch_first=True)
                len = [x // skip for x in len]
            res = F.dropout(res[:, ::skip], drop, self.training)
            #res=F.normalize(res,dim=-1)

        res = self.enc_dec_conection(res)
        if rnn.bidirectional:
            c = self.enc_dec_conection(torch.cat((c[0], c[1]), -1))
            h = self.enc_dec_conection(torch.cat((h[0], h[1]), -1))
        else:
            c = self.enc_dec_conection(c[0])
            h = self.enc_dec_conection(h[0])

        return res, (h, c)
Example #18
    def get_vectors(self, input, lengths=None):
        embed_input = self.embed(input)

        packed_emb = embed_input
        if lengths is not None:
            lengths = lengths.view(-1).tolist()
            packed_emb = nn.utils.rnn.pack_padded_sequence(
                embed_input, lengths)

        output, hidden = self.encoder(packed_emb)  # embed_input

        if lengths is not None:
            output = unpack(output)[0]

        # MUST apply negative mapping, so max pooling will not take padding elements
        batch_mask = self.create_mask(lengths)  # (time, batch_size)
        batch_mask = batch_mask.view(
            -1, len(lengths), 1)  # not sure if here broadcasting is right
        output = self.exp_mask(output,
                               batch_mask)  # now pads will never be chosen...

        return output
Example #19
    def get_hypothesis_logits(self, words_hyp):
        mask_hyp = torch.ne(words_hyp, constants.PAD_ID)
        h_hyp = self.word_emb(words_hyp)
        lengths = mask_hyp.int().sum(dim=-1)
        h_hyp = pack(h_hyp, lengths, batch_first=True, enforce_sorted=False)
        h_hyp, hidden_hyp = self.rnn_hyp(h_hyp)
        h_hyp, _ = unpack(h_hyp, batch_first=True)

        if self.rnn_type == 'lstm':
            hidden_hyp = hidden_hyp[0]
        if self.is_bidir:
            hidden_states = [hidden_hyp[0], hidden_hyp[1]]
        else:
            hidden_states = [hidden_hyp[0]]
        hyp_logits = torch.cat(hidden_states, dim=-1).unsqueeze(1)

        # get last valid outputs instead of the last hidden state:
        # last_valid_idx = mask_hyp.int().sum(dim=-1) - 1
        # arange_vector = torch.arange(h_hyp.shape[0]).to(h_hyp.device)
        # hyp_logits = h_hyp[arange_vector, last_valid_idx].unsqueeze(1)

        return hyp_logits
Example #20
    def encode_batch(self, inputs, trans, lengths):

        bsz, max_len = inputs.size()
        in_embs = self.word_embs(inputs)
        lens, indices = torch.sort(lengths, 0, True)

        # concat word embs with trans hid
        if self.use_input_parse:
            in_embs = torch.cat([in_embs, trans.unsqueeze(1).expand(bsz, max_len, self.d_trans)], 2)
        
        e_hid_init = self.e_hid_init.expand(2, bsz, self.d_hid).contiguous()
        e_cell_init = self.e_cell_init.expand(2, bsz, self.d_hid).contiguous()
        all_hids, (enc_last_hid, _) = self.encoder(pack(in_embs[indices], 
                    lens.tolist(), batch_first=True), (e_hid_init, e_cell_init))
        _, _indices = torch.sort(indices, 0)
        all_hids = unpack(all_hids, batch_first=True)[0][_indices]
        all_hids = self.encoder_proj(all_hids.view(-1, self.d_hid * 2)).view(bsz, max_len, self.d_hid)
        
        enc_last_hid = torch.cat([enc_last_hid[0], enc_last_hid[1]], 1)
        enc_last_hid = self.encoder_proj(enc_last_hid)[_indices]

        return all_hids, enc_last_hid
Example #21
    def forward(self, input, heads, lengths=None, hidden=None):
        """ See EncoderBase.forward() for description of args and returns.
        inputs: [L, B, H], including the -ROOT-
        heads: [heads] * B
        """
        emb = self.dropout(input)

        packed_emb = emb
        if lengths is not None:
            # Lengths data is wrapped inside a Variable.
            packed_emb = pack(emb, lengths)

        outputs, hidden_t = self.rnn(packed_emb, hidden)

        if lengths is not None:
            outputs = unpack(outputs)[0]

        outputs = self.dropout(self.transform(outputs))
        max_length, batch_size, input_dim = outputs.size()
        trees = []
        indexes = np.full((max_length, batch_size), -1,
                          dtype=np.int32)  # a col is a sentence
        for b, head in enumerate(heads):
            root, tree = creatTree(
                head)  # head: a sentence's heads; sentence base
            root.traverse()  # traverse the tree
            for step, index in enumerate(root.order):
                indexes[step, b] = index
            trees.append(tree)

        dt_outputs, dt_hidden_ts = self.dt_tree.forward(
            outputs, indexes, trees)
        td_outputs, td_hidden_ts = self.td_tree.forward(
            outputs, indexes, trees)

        outputs = torch.cat([dt_outputs, td_outputs], dim=2).transpose(0, 1)
        output_t = torch.cat([dt_hidden_ts, td_hidden_ts], dim=1).unsqueeze(0)

        return outputs, output_t
Example #22
    def enc(self,src,len=None):
        res=F.dropout(self.enc_emb_lyr(src),enc_emb_drop,self.training)
        h,c=None,None
        #import pdb; pdb.set_trace()
        for rnn,skip,drop in zip(self.enc_rnn_lyr,enc_rnn_skip,enc_rnn_drop):
            res = res if len is None else pack(res, len, batch_first=True)
            res,(h,c) =rnn(res)
            if len is not None:
            #    import pdb; pdb.set_trace()
                res,_=unpack(res, batch_first=True)
                len  =[x // skip for x in len]

            res = F.dropout(res[:,::skip],drop,self.training)
            #res=F.normalize(,dim=-1)
        res=self.enc_dec_conection(res)
        if rnn.bidirectional:
            c=self.enc_dec_conection(torch.cat((c[0],c[1]),-1))
            h=self.enc_dec_conection(torch.cat((h[0],h[1]),-1))
        else:
            c=self.enc_dec_conection(c[0])
            h=self.enc_dec_conection(h[0])
        return res,(h,c)
Example #23
    def forward(self, x, lens, k, kx):
        # model takes as input the text, aspect, and location
        # runs BLSTM over text using embedding(location, aspect) as
        # the initial hidden state, as opposed to a different lstm for every pair???
        # output sentiment

        # DBG
        words = x

        emb = self.drop(self.lut(x))
        p_emb = pack(emb, lens, True)

        l, a = k
        N = x.shape[0]
        T = x.shape[1]
        y_idx = l * len(self.A) + a if self.L is not None else a
        s = (self.lut_la(y_idx)
            .view(N, 2, 2 * self.nlayers, self.rnn_sz)
            .permute(1, 2, 0, 3)
            .contiguous())
        state = (s[0], s[1])
        x, (h, c) = self.rnn(p_emb, state)
        # h: L * D x N x H
        x = unpack(x, True)[0]
        phi_s = self.proj_s(x)
        idxs = torch.arange(0, max(lens)).to(lens.device)
        # mask: N x R x 1
        mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))
        phi_s[:,:,-1].masked_fill_(1-mask, float("-inf"))
        phi_s[:,:,:3].masked_fill_(mask.unsqueeze(-1), float("-inf"))

        phi_y = torch.zeros(N, len(self.S)).to(self.psi_ys.device)
        psi_ys = torch.cat(
            [torch.diag(self.psi_ys), torch.zeros(len(self.S), 1).to(self.psi_ys)],
            dim=-1,
        ).expand(T, len(self.S), len(self.S)+1)
        Z, hy = ubersum("nts,tys,ny->n,ny", phi_s, psi_ys, phi_y, batch_dims="t", modulo_total=True)

        return hy
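The comment at the top of this example describes running one BLSTM whose initial hidden state is an embedding of the (location, aspect) pair instead of training a separate LSTM per pair. A rough self-contained sketch of that idea, with illustrative names and sizes:

import torch
import torch.nn as nn

num_pairs, nlayers, rnn_sz, emb_sz = 10, 2, 16, 32
lut_la = nn.Embedding(num_pairs, 2 * 2 * nlayers * rnn_sz)  # packs h and c for both directions
rnn = nn.LSTM(emb_sz, rnn_sz, num_layers=nlayers,
              batch_first=True, bidirectional=True)

y_idx = torch.tensor([3, 7])                  # one (location, aspect) id per example
N = y_idx.size(0)
s = (lut_la(y_idx)
     .view(N, 2, 2 * nlayers, rnn_sz)         # split into h and c
     .permute(1, 2, 0, 3)
     .contiguous())
state = (s[0], s[1])                          # each: (num_layers * 2, N, rnn_sz)
x = torch.randn(N, 5, emb_sz)
out, (h, c) = rnn(x, state)                   # BLSTM initialized from the pair embedding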
Example #24
    def encode(self, input, hidden):
        '''
            input :: bs, sl

            return
                output :: bs, sl, nh*directions
                hidden :: n_layers*directions,bs, nh
        '''
        mask = torch.gt(input.data,0)
        input_length = torch.sum((mask.long()),dim=1)       # batch first = True, (batch, sl)
        lengths, indices = torch.sort(input_length, dim=0, descending=True)
        _, ind = torch.sort(indices, dim=0)
        input_length = torch.unbind(lengths, dim=0)
        embedded = self.embedding(torch.index_select(input,dim=0,index=Variable(indices)))
        output, hidden = self.gru(pack(embedded, input_length, batch_first=True), hidden)
        output = torch.index_select(unpack(output, batch_first=True)[0], dim=0,index=Variable(ind))*Variable(torch.unsqueeze(mask.float(),-1))
        hidden = torch.index_select(hidden[-1], dim=0, index=Variable(ind))
        #hidden = torch.unbind(hidden, dim=0)
        #hidden = torch.cat(hidden, 1)
        direction = 2 if self.bidirectional else 1
        assert hidden.size() == (input.size()[0],self.hidden_size) and output.size() == (input.size()[0], input.size()[1],self.hidden_size*direction)
        return output, hidden
Example #25
    def forward(self, input, hidden=None):
        if isinstance(input, tuple):
            emb_ = self.word_lut(input[0])
            if self.src_fix_emb:
                emb = pack(self.emb2input(emb_), list(input[1]))
            else:
                emb = pack(emb_, list(input[1]))
        else:
            emb_ = self.word_lut(input)
            if self.src_fix_emb:
                emb = self.emb2input(emb_)
            else:
                emb = emb_

        # if isinstance(input, tuple):
        #     emb = pack(self.word_lut(input[0]), list(input[1]))
        # else:
        #     emb = self.word_lut(input)
        outputs, hidden_t = self.rnn(emb, hidden)
        if isinstance(input, tuple):
            outputs = unpack(outputs)[0]
        return hidden_t, outputs
Example #26
    def forward(self, word, context, pool_type="max"):
        word_emb = self.embedding(word)
        context_emb = self.embedding(context)
        lengths = (context != constants.PAD_IDX).sum(dim=0).detach().cpu()

        # Sort by length (keep idx)
        context_len_sorted, idx_sort = np.sort(lengths.numpy())[::-1], np.argsort(-lengths.numpy())
        context_len_sorted = torch.from_numpy(context_len_sorted.copy())
        idx_unsort = np.argsort(idx_sort)
        context_emb = context_emb.index_select(1, torch.from_numpy(idx_sort).to(self.device))
        context_emb = pack(context_emb, context_len_sorted, batch_first=False)
        context_vec, _ = self.rnn(context_emb, None)
        context_vec = unpack(context_vec, batch_first=False)[0]

        # Un-sort by length
        context_vec = context_vec.index_select(1, torch.from_numpy(idx_unsort).to(self.device))
        # Pooling
        if pool_type == "mean":
            lengths = torch.FloatTensor(lengths.numpy().copy()).unsqueeze(1)
            emb = torch.sum(context_vec, 0).squeeze(0)
            if emb.ndimension() == 1:
                emb = emb.unsqueeze(0)
            emb = emb / lengths.expand_as(emb).to(self.device)
        elif pool_type == "max":
            emb = torch.max(context_vec, 0)[0]
            if emb.ndimension() == 3:
                emb = emb.squeeze(0)
                assert emb.ndimension() == 2
        V = F.relu(self.linear_q(word_emb.unsqueeze(1)))
        C = F.relu(self.linear_k(torch.transpose(context_vec, 0, 1)))
        T = torch.transpose(context_vec, 0, 1)

        scale = (C.size(-1)) ** -0.5
        att = torch.bmm(V, C.transpose(1, 2)) * scale
        att = self.softmax(att)
        att = self.dropout(att)
        c = torch.bmm(att, T)
        c = c.squeeze(1)
        return c, emb, att
Example #27
    def forward(self, x, y, x_len=None, softmax=True):
        """
        Args:
            x(tensor): batch, frame, dim
            y(tensor): batch, frame
            x_len(tensor): batch
        """

        if self.pack_seq and x_len is not None:
            packed_x = pack(x, x_len, batch_first=True,
                            enforce_sorted=True)
            packed_x, _ = self.encoder(packed_x)
            x = unpack(packed_x, batch_first=True)[0]
        else:
            x = self.encoder(x)
        #prepend SOS, blk=0
        SOS = Variable(torch.zeros(y.shape[0], 1).long())
        SOS = SOS.cuda()
        y = torch.cat((SOS, y), dim=1)
        if self.decoder_type == 'rnn':
            y = self.embed(y)
            y, _ = self.decoder(y)
        else:
            y = self.decoder(y)
        T = x.size()[1]
        U = y.size()[1]
        #x: batch, T, U, dim
        #y: batch, T, U, dim
        x = x.unsqueeze(2).expand(-1, -1, U, -1)
        y = y.unsqueeze(1).expand(-1, T, -1, -1)
        #x_gate = F.glu(torch.cat((x, y), dim=-1), dim=-1)
        #y_gate = F.glu(torch.cat((y, x), dim=-1), dim=-1)
        #out = torch.cat((x_gate, y_gate), dim=-1)
        out = torch.cat((x, y), dim=-1)
        out = self.fc2(F.tanh(self.fc1(out)) * F.sigmoid(self.fc_gate(out)))
        #out = self.fc2(F.selu(self.fc1(out)))
        if softmax:
            out = F.log_softmax(out, dim=-1)
        return out
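A minimal sketch of just the encoder/decoder joint expansion used above: every encoder frame is paired with every decoder step before the concatenation (shapes only, no model weights; the sizes are illustrative).

import torch

B, T, U, D = 2, 7, 4, 8
x = torch.randn(B, T, D)                   # encoder output: one vector per frame
y = torch.randn(B, U, D)                   # decoder output: one vector per label step

x = x.unsqueeze(2).expand(-1, -1, U, -1)   # (B, T, U, D)
y = y.unsqueeze(1).expand(-1, T, -1, -1)   # (B, T, U, D)
joint = torch.cat((x, y), dim=-1)          # (B, T, U, 2 * D), fed to the joint network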
Example #28
    def forward(self, input, lengths=None, hidden=None):
        packed_emb = input
        packed_emb = pack(input, lengths, enforce_sorted=False)

        outputs, hidden_t = self.rnn(packed_emb, hidden)
        outputs = unpack(outputs)[0]

        # consider both direction
        if self.bidirectional:
            if self.rnn_type == 'LSTM':
                h_n, c_n = hidden_t
                h_n = torch.cat([h_n[0:h_n.size(0):2], h_n[1:h_n.size(0):2]],
                                2)
                c_n = torch.cat([c_n[0:c_n.size(0):2], c_n[1:c_n.size(0):2]],
                                2)
                hidden_t = (h_n, c_n)
            else:
                hidden_t = torch.cat([
                    hidden_t[0:hidden_t.size(0):2],
                    hidden_t[1:hidden_t.size(0):2]
                ], 2)
        return outputs, hidden_t
Example #29
def sinkhorn_score_regularizor(score):
    '''probBatch: tuple (src_len x batch x n_out, lengths),
       tgtBatch:  amr_len x batch x n_feature, lengths
       score:     packed (amr_len x batch x src_len, lengths)

       total_loss, total_data
    '''

    scores, lengths = unpack(score)
    S = 0
    r = opt.prior_t / opt.sink_t
    gamma_r = math.gamma(1 + r)
    for i, l in enumerate(lengths):
        # scores[:l, i, :l].data = torch.clamp(scores[:l, i, :l].data, 0, torch.max(scores[:l, i, :l].data))
        # print("scores", torch.max(scores[:l, i, :l].data), torch.min(scores[:l, i, :l].data))
        # aa = Variable(torch.randn(3, 5) * 50)
        # print(scores[:l, i, :l])
        scores[:l, i, :l] = torch.clamp(scores[:l, i, :l], min=-1)
        S = S + r / scores[:l, i, :l].sum() + gamma_r * torch.exp(
            -scores[:l, i, :l] * r).sum()

    return S  #+activation_loss
Example #30
def apply_packed_sequence(rnn, embedding, lengths):
    """ Runs a forward pass of embeddings through an rnn using packed sequence.
    Args:
       rnn: The RNN that that we want to compute a forward pass with.
       embedding (FloatTensor b x seq x dim): A batch of sequence embeddings.
       lengths (LongTensor batch): The length of each sequence in the batch.
    Returns:
       output: The output of the RNN `rnn` with input `embedding`
    """
    # Sort Batch by sequence length
    lengths_sorted, permutation = torch.sort(lengths, descending=True)
    embedding_sorted = embedding[permutation]

    # Use Packed Sequence
    embedding_packed = pack(embedding_sorted, lengths_sorted, batch_first=True)
    outputs_packed, (hidden, cell) = rnn(embedding_packed)
    outputs_sorted, _ = unpack(outputs_packed, batch_first=True)
    # Restore original order
    _, permutation_rev = torch.sort(permutation, descending=False)
    outputs = outputs_sorted[permutation_rev]
    hidden, cell = hidden[:, permutation_rev], cell[:, permutation_rev]
    return outputs, (hidden, cell)
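A possible way to call apply_packed_sequence, assuming pack/unpack alias pack_padded_sequence/pad_packed_sequence as in the rest of these examples; the sizes are illustrative.

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=16, hidden_size=32, batch_first=True)
embedding = torch.randn(4, 10, 16)        # batch of 4 sequences, max length 10
lengths = torch.tensor([10, 7, 7, 3])

outputs, (hidden, cell) = apply_packed_sequence(rnn, embedding, lengths)
# outputs: (4, 10, 32) in the original batch order;
# hidden, cell: (1, 4, 32), also restored to the original order.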
Example #31
    def forward(self, input, hidden=None, is_fert=True):
        if isinstance(input, tuple):
            emb_ = self.word_lut(input[0])
            if self.src_fix_emb:
                emb_ = self.emb2input(emb_)
                emb = pack(emb_, list(input[1]))
            else:
                emb = pack(emb_, list(input[1]))
        else:
            emb_ = self.word_lut(input)
            if self.src_fix_emb:
                emb = self.emb2input(emb_)
            else:
                emb = emb_

        # if isinstance(input, tuple):
        #     emb = pack(self.word_lut(input[0]), list(input[1]))
        # else:
        #     emb = self.word_lut(input)
        outputs, hidden_t = self.rnn(emb, hidden)
        if isinstance(input, tuple):
            outputs = unpack(outputs)[0]

        cov = None
        cov_inp = None  # ensure cov_inp is defined when self.use_fert is False
        if self.use_fert:
            if self.fert_mode == "emb":
                cov_inp = emb_
            elif self.fert_mode == "emh":
                cov_inp = torch.cat([emb_, outputs], -1)
            else:
                cov_inp = outputs
            if is_fert:
                cov = self.forward_cov(cov_inp)

        hidden_t = (self._fix_enc_hidden(hidden_t[0]),
                    self._fix_enc_hidden(hidden_t[1]))

        return hidden_t, outputs, cov, cov_inp
Example #32
    def forward(self, inputs, hidden=None):
        if isinstance(inputs, PackedSequence):
            emb = PackedSequence(
                self.embedding_dropout(self.embedder(inputs.data)),
                inputs.batch_sizes)
            bsizes = inputs.batch_sizes.to(device=inputs.data.device)
            max_batch = int(bsizes[0])
            # Get padding mask
            time_dim = 1 if self.batch_first else 0
            range_batch = torch.arange(0,
                                       max_batch,
                                       dtype=bsizes.dtype,
                                       device=bsizes.device)
            range_batch = range_batch.unsqueeze(time_dim)
            bsizes = bsizes.unsqueeze(1 - time_dim)
            padding_mask = (bsizes - range_batch).le(0)
        else:
            padding_mask = inputs.eq(PAD)
            emb = self.embedding_dropout(self.embedder(inputs))
        outputs, hidden_t = self.rnn(emb, hidden)

        if isinstance(inputs, PackedSequence):
            outputs = unpack(outputs)[0]
        outputs = self.dropout(outputs)
        if hasattr(self, 'context_transform'):
            context = self.context_transform(outputs)
        else:
            context = None

        if hasattr(self, 'hidden_transform'):
            hidden_t = self.hidden_transform(hidden_t)

        state = State(outputs=outputs,
                      hidden=hidden_t,
                      context=context,
                      mask=padding_mask,
                      batch_first=self.batch_first)
        return state
Example #33
    def forward(self, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src)
        # s_len, batch, emb_dim = emb.size()

        packed_emb = emb
        if lengths is not None and not self.no_pack_padded_seq:
            # Lengths data is wrapped inside a Tensor.
            lengths_list = lengths.view(-1).tolist()
            packed_emb = pack(emb, lengths_list)

        memory_bank, encoder_final = self.layers[0](packed_emb)

        if lengths is not None and not self.no_pack_padded_seq:
            memory_bank = unpack(memory_bank)[0]

        bottom_layers = 1
        if self.gnmt:
            memory_bank = self.dropout(memory_bank)
            memory_bank, enc_final = self.layers[1](memory_bank)
            #print(encoder_final[0].size())
            encoder_final0 = torch.cat((encoder_final[0], enc_final[0]), 0)
            encoder_final1 = torch.cat((encoder_final[1], enc_final[1]), 0)
            encoder_final = (encoder_final0, encoder_final1)
            bottom_layers = 2
        for i in range(bottom_layers, self.num_layers):
            residual = memory_bank
            memory_bank = self.dropout(memory_bank)
            memory_bank, enc_final = self.layers[i](memory_bank)
            encoder_final0 = torch.cat((encoder_final[0], enc_final[0]), 0)
            encoder_final1 = torch.cat((encoder_final[1], enc_final[1]), 0)
            encoder_final = (encoder_final0, encoder_final1)
            if self.num_layers >= 4:
                memory_bank = memory_bank + residual

        return encoder_final, memory_bank, lengths
Example #34
    def forward(self, inputs, lengths):
        inputs = inputs.t()
        lengths = lengths.tolist()
        embs = pack(self.embedding(inputs), lengths)
        self.rnn.flatten_parameters()
        outputs, state = self.rnn(embs)
        outputs = unpack(outputs)[0]
        if self.bidirectional:
            outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]  # length, batch_size, hidden_size

        if self.inception:
            outputs = outputs.transpose(0, 1).transpose(1, 2)

            conv1 = self.sw1(outputs)
            conv3 = self.sw3(outputs)
            conv33 = self.sw33(outputs)
            conv = torch.cat((conv1, conv3, conv33), 1)

            conv = self.filter_linear(conv.transpose(1, 2))
            #conv = self.sw3(outputs).transpose(1, 2)
            outputs = outputs.transpose(1, 2).transpose(0, 1)   #seq_len, batch, dim
            outputs = outputs.transpose(0, 1)
            if self.encoding_gate:
                if self.gtu:
                    # conv =   "weight norm"
                    # outputs "weight norm"
                    gate = self.sigmoid(conv)
                    tan_conv = torch.tanh(outputs)
                    gtu_out = tan_conv * gate
                    return self.layer_normalization(gtu_out + outputs)
                else:
                    gate = self.sigmoid(conv)
                    outputs = outputs * gate
                    return outputs
            else:
                return conv
        else:
            return outputs.transpose(0, 1)
Example #35
    def forward(self, inputs, lengths):
        embs = pack(self.embedding(inputs), lengths)
        outputs, state = self.rnn(embs)
        outputs = unpack(outputs)[0]
        if self.config.bidirectional:
            if self.config.swish:
                outputs = self.linear(outputs)
            else:
                outputs = (outputs[:, :, :self.config.hidden_size]
                           + outputs[:, :, self.config.hidden_size:])
        if self.config.swish:
            outputs = outputs.transpose(0, 1).transpose(1, 2)
            conv1 = self.sw1(outputs)
            conv3 = self.sw3(outputs)
            conv33 = self.sw33(outputs)
            conv = torch.cat((conv1, conv3, conv33), 1)
            conv = self.filter_linear(conv.transpose(1, 2))
            if self.config.selfatt:
                conv = conv.transpose(0, 1)
                outputs = outputs.transpose(1, 2).transpose(0, 1)
            else:
                gate = self.sigmoid(conv)
                outputs = outputs * gate.transpose(1, 2)
                outputs = outputs.transpose(1, 2).transpose(0, 1)

        if self.config.selfatt:
            self.attention.init_context(context=conv)
            out_attn, weights = self.attention(conv, selfatt=True)
            gate = self.sigmoid(out_attn)
            outputs = outputs * gate

        if self.config.cell == 'gru':
            state = state[:self.config.dec_num_layers]
        else:
            state = (state[0][::2], state[1][::2])

        return outputs, state
Example #36
    def encode(self, inputs, lengths, fr=0):
        bsz, max_len = inputs.size()
        e_hidden_init = self.e_hidden_init.expand(
            2, bsz, self.hidden_dim).contiguous()
        e_cell_init = self.e_cell_init.expand(2, bsz,
                                              self.hidden_dim).contiguous()
        lens, indices = torch.sort(lengths, 0, True)

        if fr and not self.share_vocab:
            in_embs = self.embedding_fr(inputs)
        else:
            in_embs = self.embedding(inputs)

        if fr and not self.share_encoder:
            if self.dropout > 0:
                in_embs = F.dropout(in_embs,
                                    p=self.dropout,
                                    training=self.training)
            all_hids, (enc_last_hid, _) = self.lstm_fr(
                pack(in_embs[indices], lens.tolist(), batch_first=True),
                (e_hidden_init, e_cell_init))
        else:
            if self.dropout > 0:
                in_embs = F.dropout(in_embs,
                                    p=self.dropout,
                                    training=self.training)
            all_hids, (enc_last_hid, _) = self.lstm(
                pack(in_embs[indices], lens.tolist(), batch_first=True),
                (e_hidden_init, e_cell_init))

        _, _indices = torch.sort(indices, 0)
        all_hids = unpack(all_hids, batch_first=True)[0][_indices]

        if self.pool == "max":
            embs = utils.max_pool(all_hids, lengths, self.gpu)
        elif self.pool == "mean":
            embs = utils.mean_pool(all_hids, lengths, self.gpu)
        return embs