Example #1
class RevEncoder(nn.Module):
    def __init__(self,
                 rnn_type,
                 h_size,
                 nlayers,
                 embedding,
                 slice_dim=100,
                 max_forget=0.875,
                 use_buffers=True,
                 dropouti=0,
                 dropouth=0,
                 wdrop=0,
                 context_type='slice'):

        super(RevEncoder, self).__init__()
        if rnn_type == 'revgru':
            rnn = RevGRU
            module_names = [
                'ih2_to_zr1', 'irh2_to_g1', 'ih1_to_zr2', 'irh1_to_g2'
            ]
        elif rnn_type == 'revlstm':
            rnn = RevLSTM
            module_names = ['ih2_to_zgfop1', 'ih1_to_zgfop2']

        self.rnns = [rnn(h_size, h_size, max_forget) for l in range(nlayers)]
        self.rnns = nn.ModuleList(self.rnns)

        self.lockdropi = RevLockedDropout(dropouti, h_size)
        self.lockdroph = RevLockedDropout(dropouth, h_size)

        self.rnn_type = rnn_type
        self.h_size = h_size
        self.nlayers = nlayers
        self.encoder = embedding

        self.slice_dim = slice_dim
        self.use_buffers = use_buffers
        self.dropouti = dropouti
        self.dropouth = dropouth
        self.wdrop = wdrop
        self.context_type = context_type

        if slice_dim == h_size:
            self.rev_forward = self.rev_forward_allh
            self.forward = self.rev_forward_allh
            self.reverse = self.reverse_allh
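The constructor above only selects the reversible cell class and registers the per-layer cells. A minimal self-contained sketch of that same dispatch-and-stack pattern, using ordinary PyTorch cells in place of the RevGRU/RevLSTM modules (which are not shown on this page), could look like:

import torch.nn as nn

class TinyStackedRNN(nn.Module):
    # Sketch only: mirrors the rnn_type dispatch and nn.ModuleList stacking above,
    # with standard GRU/LSTM cells standing in for the reversible cells.
    def __init__(self, rnn_type, h_size, nlayers):
        super().__init__()
        if rnn_type == 'gru':
            cell = nn.GRUCell
        elif rnn_type == 'lstm':
            cell = nn.LSTMCell
        else:
            raise ValueError('unknown rnn_type: %s' % rnn_type)
        self.rnns = nn.ModuleList([cell(h_size, h_size) for _ in range(nlayers)])

Wrapping the Python list in nn.ModuleList is what registers the cells' parameters, so that calls such as next(self.parameters()) in init_hiddens and .to(device) can see them.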
Example #2
class RevDecoder(nn.Module):
    def __init__(self, rnn_type, h_size, nlayers, embedding, attn_type='general',
        context_size=None, dropouti=0, dropouth=0, dropouto=0, wdrop=0, dropouts=0,
        max_forget=0.875, context_type='slice', slice_dim=0, use_buffers=True):

        super(RevDecoder, self).__init__()
        if rnn_type == 'revgru':
            rnn = RevGRU
            module_names = ['ih2_to_zr1', 'irh2_to_g1', 'ih1_to_zr2', 'irh1_to_g2']
        elif rnn_type == 'revlstm':
            rnn = RevLSTM
            module_names = ['ih2_to_zgfop1', 'ih1_to_zgfop2']

        self.rnns = [rnn(h_size, h_size, max_forget) for l in range(nlayers)]
        self.rnns = [RevWeightDrop(rnn, module_names, wdrop) for rnn in self.rnns]
        self.rnns = nn.ModuleList(self.rnns)

        self.lockdropi = RevLockedDropout(dropouti, h_size)
        self.lockdroph = RevLockedDropout(dropouth, h_size)
        self.lockdropo = RevLockedDropout(dropouto, h_size)
        self.lockdrops = RevLockedDropout(dropouts, slice_dim)

        # Basic attributes.
        self.rnn_type = rnn_type
        self.decoder_type = 'rnn'
        self.context_type = context_type
        self.context_size = context_size
        self.nlayers = nlayers
        self.h_size = h_size
        self.embedding = embedding
        self.wdrop = wdrop
        self.use_buffers = use_buffers
        self.generator = nn.Linear(h_size, embedding.num_embeddings)
        self.slice_dim = slice_dim

        # Set up the standard attention.
        self.attn_type = attn_type
        if attn_type != 'none':
            self.attn = onmt.modules.MultiSizeAttention(h_size, context_size=context_size,
                attn_type=attn_type)
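MultiSizeAttention comes from the onmt package and is not reproduced on this page. As a rough, self-contained sketch (names and shapes are assumptions, not the onmt implementation), 'general' Luong-style attention over a context whose feature size differs from the decoder state could be written as:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GeneralAttention(nn.Module):
    # Sketch of 'general' attention: score(h, c) = h^T W c, followed by a softmax
    # over source positions and a linear layer mixing the context back into h.
    def __init__(self, h_size, context_size):
        super().__init__()
        self.linear_in = nn.Linear(h_size, context_size, bias=False)
        self.linear_out = nn.Linear(h_size + context_size, h_size, bias=False)

    def forward(self, query, context):
        # query: (batch, tgt_len, h_size); context: (batch, src_len, context_size)
        scores = torch.bmm(self.linear_in(query), context.transpose(1, 2))
        align = F.softmax(scores, dim=-1)
        mix = torch.bmm(align, context)
        return torch.tanh(self.linear_out(torch.cat([mix, query], dim=-1))), align

The real module additionally masks padded source positions via context_lengths, which this sketch omits.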
Example #3
class RevEncoder(nn.Module):
    def __init__(self,
                 rnn_type,
                 h_size,
                 nlayers,
                 embedding,
                 slice_dim=100,
                 max_forget=0.875,
                 use_buffers=True,
                 dropouti=0,
                 dropouth=0,
                 wdrop=0):

        super(RevEncoder, self).__init__()
        if rnn_type == 'revgru':
            rnn = RevGRU
            module_names = [
                'ih2_to_zr1', 'irh2_to_g1', 'ih1_to_zr2', 'irh1_to_g2'
            ]
        elif rnn_type == 'revlstm':
            rnn = RevLSTM
            module_names = ['ih2_to_zgfop1', 'ih1_to_zgfop2']

        self.rnns = [rnn(h_size, h_size, max_forget) for l in range(nlayers)]
        self.rnns = [
            RevWeightDrop(rnn, module_names, wdrop) for rnn in self.rnns
        ]
        self.rnns = nn.ModuleList(self.rnns)

        self.lockdropi = RevLockedDropout(dropouti, h_size)
        self.lockdroph = RevLockedDropout(dropouth, h_size)

        self.rnn_type = rnn_type
        self.h_size = h_size
        self.nlayers = nlayers
        self.encoder = embedding

        self.slice_dim = slice_dim
        self.use_buffers = use_buffers
        self.dropouti = dropouti
        self.dropouth = dropouth
        self.wdrop = wdrop

    def forward(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(
            input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2,
                                     input_seq.device)
        slice_buf = InformationBuffer(batch_size,
                                      self.h_size // 2 - self.slice_dim,
                                      input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            buf_h1 = slice_buf if l == self.nlayers - 1 else main_buf
            buf_h2 = buf_c1 = buf_c2 = main_buf
            if self.rnn_type == 'revlstm':
                buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
            else:
                buffers.append((buf_h1, buf_h2))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum([
            32 * seq_length * batch_size * self.h_size
            for l in range(self.nlayers)
        ])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf,
                        hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    if l == self.nlayers - 1:
                        saved_hiddens.append(hidden[:, :self.slice_dim])

                    next_hidden, stats = rnn(
                        curr_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0, mask)

                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(
                            next_hidden['output_hidden'])

                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

            saved_hiddens.append(hiddens[-1][:, :self.slice_dim])

        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = main_buf.bit_usage() + slice_buf.bit_usage() +\
            32*self.slice_dim*batch_size*max_length

        return hiddens, saved_hiddens, buffers, main_buf, output_dict

    def init_hiddens(self, batch_size):
        weight = next(self.parameters())
        if self.rnn_type == 'revlstm':
            return [
                weight.new(torch.zeros(batch_size,
                                       2 * self.h_size)).zero_().int()
                for l in range(self.nlayers)
            ]
        else:
            return [
                weight.new(torch.zeros(batch_size, self.h_size)).zero_().int()
                for l in range(self.nlayers)
            ]

    def set_masks(self, batch_size, device='cuda'):
        self.lockdropi.set_mask(batch_size, device=device)
        self.lockdroph.set_mask(batch_size, device=device)
        for rnn in self.rnns:
            rnn.set_mask()

    def reverse(self, input_seq, lengths, last_hiddens, last_hidden_grads,
                saved_hiddens, saved_hidden_grads, buffers):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size, 1)
            lengths (IntTensor): size is (batch_size,)
            last_hiddens (list): list of IntTensors (each with size (batch_size, h_size)) 
                of length nlayers
            last_hidden_grads (list): list of FloatTensors (each with size (batch_size, h_size))
                of length nlayers
            saved_hiddens (list): list of IntTensors (each with size (batch_size, slice_dim))
                of length seq_length + 1
            saved_hidden_grads (list): list of FloatTensors (each with size (batch_size, slice_dim))
                of length seq_length + 1
            buffers (list): list of InformationBuffers of length nlayers        
        """
        hiddens = last_hiddens
        hidden_grads = last_hidden_grads

        # TODO(mmackay): replace saved_hidden_grads by just using the .grad attribute
        # of the hiddens in saved_hiddens
        for t in reversed(range(len(input_seq))):
            mask = None if lengths is None else (t < lengths).int()
            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]
                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l - 1]
                        drop_input = self.lockdroph(
                            ConvertToFloat.apply(curr_input[:, :self.h_size],
                                                 hidden_radix))
                    else:
                        curr_input = input_seq[t]
                        drop_input = self.lockdropi(self.encoder(curr_input))

                    prev_hidden = rnn.reverse(
                        drop_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0,
                        saved_hiddens[t] if l == self.nlayers - 1 else None,
                        mask)

                # Rerun the forward pass from the previous hidden to the hidden at time t
                # to construct the computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdroph(
                        ConvertToFloat.apply(curr_input[:, :self.h_size],
                                             hidden_radix))
                else:
                    drop_input = self.lockdropi(self.encoder(curr_input))

                curr_hidden, _ = rnn(drop_input, prev_hidden, mask=mask)
                curr_hidden_grad = hidden_grads[l]
                if l == self.nlayers - 1:
                    curr_hidden_grad[:, :self.slice_dim] += saved_hidden_grads[
                        t + 1]
                torch.autograd.backward(curr_hidden['recurrent_hidden'],
                                        grad_tensors=curr_hidden_grad)
                hiddens[l] = prev_hidden.detach()
                hidden_grads[l] = prev_hidden.grad.data
                if l != 0:
                    hidden_grads[l - 1] += curr_input.grad.data

    def test_forward(self, input_seq, lengths=None, hiddens=None):
        """
        Used for testing correctness of gradients in reverse computation.
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size, 1)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up. We don't set masks here: it is assumed forward is called before this
        # method, so the masks it set are reused unchanged and the gradients match.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(
            input_seq) if lengths is None else lengths.max().item()
        # self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2,
                                     input_seq.device)
        slice_buf = InformationBuffer(batch_size,
                                      self.h_size // 2 - self.slice_dim,
                                      input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            buf_h1 = slice_buf if l == self.nlayers - 1 else main_buf
            buf_h2 = buf_c1 = buf_c2 = main_buf
            if self.rnn_type == 'revlstm':
                buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
            else:
                buffers.append((buf_h1, buf_h2))

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        for t in range(len(input_seq)):
            mask = None if lengths is None else (t < lengths).int()
            curr_input = input_seq[t]

            for l, (rnn, buf,
                    hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                if l == self.nlayers - 1:
                    saved_hiddens.append(hidden[:, :self.slice_dim])

                next_hidden, stats = rnn(
                    curr_input, hidden, buf,
                    self.slice_dim if l == self.nlayers - 1 else 0, mask)

                if l != self.nlayers - 1:
                    curr_input = self.lockdroph(next_hidden['output_hidden'])

                hiddens[l] = next_hidden['recurrent_hidden']

        saved_hiddens.append(hiddens[-1][:, :self.slice_dim])

        return hiddens, saved_hiddens, buffers, {}
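The 'normal_bits'/'used_bits' bookkeeping in forward measures how much hidden-state storage a conventional (non-reversible) implementation would need: 32 bits per hidden unit, at every time step, in every layer, doubled for the LSTM's cell states. A small self-contained helper with illustrative (made-up) numbers makes the formula concrete:

def baseline_hidden_bits(seq_length, batch_size, h_size, nlayers, is_lstm=False):
    # Matches output_dict['normal_bits'] above: 32 bits per float for every
    # hidden unit, at every time step, in every layer (x2 for LSTM cell states).
    bits = 32 * seq_length * batch_size * h_size * nlayers
    return 2 * bits if is_lstm else bits

# E.g. a 2-layer revGRU on 64 sequences of length 50 with h_size=600:
# 32*50*64*600*2 = 122,880,000 bits, i.e. roughly 14.6 MiB of hidden states.
print(baseline_hidden_bits(50, 64, 600, 2) / 8 / 2**20)

output_dict['used_bits'] then reports what the shared InformationBuffers (plus the saved top-layer slices) actually consumed, so the ratio of the two numbers is the memory saving.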
Example #4
class RevDecoder(nn.Module):

    def __init__(self, rnn_type, h_size, nlayers, embedding, attn_type='general',
        context_size=None, dropouti=0, dropouth=0, dropouto=0, wdrop=0, dropouts=0,
        max_forget=0.875, context_type='slice', slice_dim=0, use_buffers=True):

        super(RevDecoder, self).__init__()
        if rnn_type == 'revgru':
            rnn = RevGRU
            module_names = ['ih2_to_zr1', 'irh2_to_g1', 'ih1_to_zr2', 'irh1_to_g2']
        elif rnn_type == 'revlstm':
            rnn = RevLSTM
            module_names = ['ih2_to_zgfop1', 'ih1_to_zgfop2']
        
        self.rnns = [rnn(h_size, h_size, max_forget) for l in range(nlayers)]
        self.rnns = [RevWeightDrop(rnn, module_names, wdrop) for rnn in self.rnns]
        self.rnns = nn.ModuleList(self.rnns)

        self.lockdropi = RevLockedDropout(dropouti, h_size)
        self.lockdroph = RevLockedDropout(dropouth, h_size)
        self.lockdropo = RevLockedDropout(dropouto, h_size)
        self.lockdrops = RevLockedDropout(dropouts, slice_dim)

        # Basic attributes.
        self.rnn_type = rnn_type
        self.decoder_type = 'rnn' 
        self.context_type = context_type
        self.context_size = context_size
        self.nlayers = nlayers
        self.h_size = h_size
        self.embedding = embedding
        self.wdrop = wdrop
        self.use_buffers = use_buffers
        self.generator = nn.Linear(h_size, embedding.num_embeddings)

        # Set up the standard attention.
        self.attn_type = attn_type
        if attn_type != 'none':
            self.attn = onmt.modules.MultiSizeAttention(
                h_size, context_size=context_size, attn_type=attn_type)


    def forward(self, input_seq, hiddens, main_buf, dec_lengths=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            main_buf (InformationBuffer): storage for hidden states with size
                (batch_size, h_size)
            dec_lengths (IntTensor): size is (batch_size,)
        """
        # Set up.
        seq_length, batch_size = input_seq.size()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        # This means using the buffer 'main_buf' from the encoder.      
        buffers = []
        for l in range(len(hiddens)):
            if self.rnn_type == 'revlstm':
                buffers.append((main_buf, main_buf, main_buf, main_buf))
            else:
                buffers.append((main_buf, main_buf))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum(
            [32*seq_length*batch_size*self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2


        output_dict['hid_seq'] = []
        for l in range(self.nlayers):
            output_dict['hid_seq'].append([hiddens[l].data.clone().numpy()])

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.embedding(input_seq))
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if dec_lengths is None else (t < dec_lengths).int() 
                curr_input = input_seq[t]
                
                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):                   
                    next_hidden, stats = rnn(curr_input, hidden, buf, slice_dim=0,
                        mask=mask)
                    if l != self.nlayers-1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])

                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['hid_seq'][l].append(hiddens[l].data.clone().numpy())
                    output_dict['optimal_bits'] += stats['optimal_bits']

        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = float('inf') # TODO(mmackay): figure out right way to compute bits
        return hiddens, buffers, output_dict

    def reverse(self, input_seq, target_seq, last_hiddens, saved_hiddens, buffers,
        token_weights, dec_lengths, enc_lengths, enc_input_seq, rev_enc):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            main_buf (InformationBuffer): storage for hidden states with size
                (batch_size, h_size)
            dec_lengths (IntTensor): size is (batch_size,)
        """

        hiddens = last_hiddens
        hidden_grads = [next(self.parameters()).new_zeros(h.size()) for h in hiddens]
        loss_fun = lambda output, target: F.cross_entropy(
            output, target, weight=token_weights, size_average=False)

        saved_hiddens = [hidden.requires_grad_() for hidden in saved_hiddens]
        total_loss = 0.

        output_dict = {'hid_seq': []}
        for l in range(self.nlayers):
            output_dict['hid_seq'].append([last_hiddens[l].data.clone().numpy()])
        output_dict['drop_hs'] = []

        for t in reversed(range(len(input_seq))):
            top_hidden = hiddens[-1].requires_grad_()
            top_hidden_ = ConvertToFloat.apply(top_hidden[:,:self.h_size], hidden_radix)
            top_hidden_ = self.lockdropo(top_hidden_).unsqueeze(0)

            output_dict['drop_hs'].append(top_hidden_.data.clone().numpy())

            context = self.construct_context(saved_hiddens, enc_input_seq, rev_enc)            
            attn_hidden, _ = self.attn(top_hidden_.transpose(0, 1).contiguous(),
                context.transpose(0, 1), context_lengths=enc_lengths)
            output = self.generator(attn_hidden[0])
            last_loss = loss_fun(output, target_seq[t])
            last_loss.backward()
            hidden_grads[-1] += top_hidden.grad
            total_loss += last_loss

            mask = None if dec_lengths is None else (t < dec_lengths).int() 
            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]
                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l-1]
                        drop_input = self.lockdroph(ConvertToFloat.apply(
                            curr_input[:, :self.h_size], hidden_radix))
                    else:
                        curr_input = input_seq[t].squeeze()
                        drop_input = self.lockdropi(self.embedding(curr_input))
                    
                    prev_hidden = rnn.reverse(drop_input, hidden, buf, slice_dim=0,
                        saved_hidden=None, mask=mask)

                # Rerun the forward pass from the previous hidden to the hidden at time t
                # to construct the computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdroph(ConvertToFloat.apply(
                        curr_input[:, :self.h_size], hidden_radix))
                else:
                    drop_input = self.lockdropi(self.embedding(curr_input))

                curr_hidden, _ = rnn(drop_input, prev_hidden, mask=mask)
                torch.autograd.backward(
                    curr_hidden['recurrent_hidden'], grad_tensors=hidden_grads[l])
                hiddens[l] = prev_hidden.detach()
                hidden_grads[l] = prev_hidden.grad.data
                if l != 0:
                    hidden_grads[l-1] += curr_input.grad.data

                output_dict['hid_seq'][l].append(prev_hidden.data.clone().numpy())

        return hiddens, hidden_grads, saved_hiddens, buffers, total_loss, output_dict

    def construct_context(self, saved_hiddens, enc_input_seq, rev_enc):
        if self.context_type == 'emb':
            context = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq))
        elif self.context_type == 'slice':
            context = self.lockdrops(ConvertToFloat.apply(
                torch.stack(saved_hiddens[1:]), hidden_radix))
        elif self.context_type == 'slice_emb':
            embs = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq))
            slices = self.lockdrops(ConvertToFloat.apply(
                torch.stack(saved_hiddens[1:]), hidden_radix))
            context = torch.cat([embs, slices], dim=2)

        return context

    def test_forward(self, input_seq, hiddens, main_buf, dec_lengths=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            main_buf (InformationBuffer): storage for hidden states with size
                (batch_size, h_size)
            dec_lengths (IntTensor): size is (batch_size,)
        """
        # Set up.
        seq_length, batch_size = input_seq.size()
        # self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        # This means using the buffer 'main_buf' from the encoder.      
        buffers = []
        for l in range(len(hiddens)):
            if self.rnn_type == 'revlstm':
                buffers.append((main_buf, main_buf, main_buf, main_buf))
            else:
                buffers.append((main_buf, main_buf))

        output_dict = {'hid_seq': []}
        for l in range(self.nlayers):
            output_dict['hid_seq'].append([hiddens[l].data.clone().numpy()])

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.embedding(input_seq))
        top_hiddens = []
        for t in range(len(input_seq)):
            mask = None if dec_lengths is None else (t < dec_lengths).int() 
            curr_input = input_seq[t]
            
            for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):                   
                next_hidden, stats = rnn(curr_input, hidden, buf, slice_dim=0,
                    mask=mask)
                if l != self.nlayers-1:
                    curr_input = self.lockdroph(next_hidden['output_hidden'])

                hiddens[l] = next_hidden['recurrent_hidden']
                output_dict['hid_seq'][l].append(hiddens[l].data.clone().numpy())

                if l == self.nlayers-1:
                    top_hiddens.append(next_hidden['output_hidden'])

        return top_hiddens, hiddens, output_dict

    def test_compute_loss(self, top_hiddens, target_seq, saved_hiddens, token_weights, enc_lengths,
        enc_input_seq, rev_enc):
        """
        Used to test correctness of gradients in reverse computation.
        """
        top_hiddens = self.lockdropo(torch.stack(top_hiddens))

        output_dict = {'drop_hs': []}
        for t in range(len(top_hiddens)):
            output_dict['drop_hs'].append(top_hiddens[t].data.clone().numpy())

        if self.context_type == 'emb':
            context = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq))
        elif self.context_type == 'slice':
            context = self.lockdrops(ConvertToFloat.apply(
                torch.stack(saved_hiddens[1:]), hidden_radix))
        elif self.context_type == 'slice_emb':
            embs = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq))
            slices = self.lockdrops(ConvertToFloat.apply(
                torch.stack(saved_hiddens[1:]), hidden_radix))
            context = torch.cat([embs, slices], dim=2)
        
        attn_hiddens, _ = self.attn(top_hiddens.transpose(0, 1).contiguous(), 
            context.transpose(0, 1), context_lengths=enc_lengths)

        output = self.generator(attn_hiddens.view(-1, attn_hiddens.size(2)))
        loss = F.cross_entropy(output, target_seq.view(-1), weight=token_weights,
            size_average=False)

        return loss, output_dict



    @property
    def _input_size(self):
        """
        Private helper returning the number of expected features.
        """
        return self.embedding.embedding_dim

    def set_masks(self, batch_size, device='cuda'):
        self.lockdropi.set_mask(batch_size, device=device)
        self.lockdroph.set_mask(batch_size, device=device)
        self.lockdropo.set_mask(batch_size, device=device)
        self.lockdrops.set_mask(batch_size, device=device)
        for rnn in self.rnns:
            rnn.set_mask()

    def init_hiddens(self, batch_size):
        ''' For testing purposes, should never be called in practice'''
        weight = next(self.parameters())
        if self.rnn_type == 'revlstm':
            return [weight.new(torch.zeros(batch_size, 2 * self.h_size)).zero_().int()
                for l in range(self.nlayers)]
        else:
            return [weight.new(torch.zeros(batch_size, self.h_size)).zero_().int() 
                for l in range(self.nlayers)]
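reverse() above is the core memory trick of this decoder (and of the encoders in the earlier examples): instead of storing every hidden state, each RNN step is algebraically inverted with rnn.reverse to reconstruct the previous hidden, and that single step is then rerun under autograd to obtain gradients. A minimal self-contained sketch of the same reconstruct-then-recompute pattern, with a toy additive update standing in for RevGRU/RevLSTM, is:

import torch
import torch.nn as nn

torch.manual_seed(0)
f = nn.Linear(4, 4)                          # toy transition network
xs = [torch.randn(2, 4) for _ in range(5)]   # inputs for 5 time steps, batch of 2

def step(h, x):            # h_t = h_{t-1} + f(x_t); invertible by subtraction
    return h + f(x)

def step_reverse(h, x):    # reconstruct h_{t-1} from h_t and x_t
    return h - f(x)

# Forward: keep only the final hidden state.
h = torch.zeros(2, 4)
with torch.no_grad():
    for x in xs:
        h = step(h, x)

# Backward: walk time in reverse, reconstruct the previous hidden, recompute one
# step with autograd, and propagate the gradient of loss = h_T.sum().
h_grad = torch.ones_like(h)
for x in reversed(xs):
    with torch.no_grad():
        h_prev = step_reverse(h, x)
    h_prev.requires_grad_()
    h_cur = step(h_prev, x)
    torch.autograd.backward(h_cur, grad_tensors=h_grad)
    h = h_prev.detach()
    h_grad = h_prev.grad

# f.weight.grad / f.bias.grad now equal what ordinary backprop through the stored
# sequence would give. The toy step is only reversible up to float rounding; the
# RevGRU/RevLSTM cells keep integer (fixed-point) hiddens plus an InformationBuffer
# so that the reconstruction done by rnn.reverse is exact.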
Example #5
class RevEncoder(nn.Module):
    def __init__(self,
                 rnn_type,
                 h_size,
                 nlayers,
                 embedding,
                 slice_dim=100,
                 max_forget=0.875,
                 use_buffers=True,
                 dropouti=0,
                 dropouth=0,
                 wdrop=0,
                 context_type='slice'):

        super(RevEncoder, self).__init__()
        if rnn_type == 'revgru':
            rnn = RevGRU
            module_names = [
                'ih2_to_zr1', 'irh2_to_g1', 'ih1_to_zr2', 'irh1_to_g2'
            ]
        elif rnn_type == 'revlstm':
            rnn = RevLSTM
            module_names = ['ih2_to_zgfop1', 'ih1_to_zgfop2']

        self.rnns = [rnn(h_size, h_size, max_forget) for l in range(nlayers)]
        self.rnns = nn.ModuleList(self.rnns)

        self.lockdropi = RevLockedDropout(dropouti, h_size)
        self.lockdroph = RevLockedDropout(dropouth, h_size)

        self.rnn_type = rnn_type
        self.h_size = h_size
        self.nlayers = nlayers
        self.encoder = embedding

        self.slice_dim = slice_dim
        self.use_buffers = use_buffers
        self.dropouti = dropouti
        self.dropouth = dropouth
        self.wdrop = wdrop
        self.context_type = context_type

        if slice_dim == h_size:
            self.rev_forward = self.rev_forward_allh
            self.forward = self.rev_forward_allh
            self.reverse = self.reverse_allh

    def forward_test(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(
            input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        all_hiddens = []
        with torch.set_grad_enabled(self.training):
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, hidden) in enumerate(zip(self.rnns, hiddens)):
                    if l == self.nlayers - 1:
                        all_hiddens.append(hidden)

                    next_hidden, stats = rnn(
                        curr_input,
                        hidden,
                        buf=None,
                        slice_dim=self.slice_dim if l == self.nlayers - 1 else 0,
                        mask=mask)

                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(
                            next_hidden['output_hidden'])

                    hiddens[l] = next_hidden['recurrent_hidden']

            all_hiddens.append(hiddens[-1])

            hidden_context = ConvertToFloat.apply(torch.stack(all_hiddens[1:]),
                                                  hidden_radix)

            if self.context_type == 'hidden':
                context = hidden_context
            elif self.context_type == 'emb':
                context = input_seq
            elif self.context_type == 'slice':
                context = hidden_context[:, :, :self.slice_dim]
            elif self.context_type == 'slice_emb':
                context = torch.cat(
                    [input_seq, hidden_context[:, :, :self.slice_dim]], dim=2)

        return hiddens, context

    def forward(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(
            input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2,
                                     input_seq.device)
        slice_buf = InformationBuffer(batch_size,
                                      self.h_size // 2 - self.slice_dim,
                                      input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            if self.training and self.use_buffers:
                buf_h1 = slice_buf if l == self.nlayers - 1 else main_buf
                buf_h2 = buf_c1 = buf_c2 = main_buf
            else:
                buf_h1 = buf_h2 = buf_c1 = buf_c2 = None
            if self.rnn_type == 'revlstm':
                buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
            else:
                buffers.append((buf_h1, buf_h2))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum([
            32 * seq_length * batch_size * self.h_size
            for l in range(self.nlayers)
        ])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        with torch.set_grad_enabled(self.training):
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf,
                        hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    if l == self.nlayers - 1:
                        saved_hiddens.append(hidden[:, :self.slice_dim])

                    next_hidden, stats = rnn(
                        curr_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0, mask)

                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(
                            next_hidden['output_hidden'])

                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

            saved_hiddens.append(hiddens[-1][:, :self.slice_dim])

        return hiddens, saved_hiddens, buffers, main_buf, slice_buf, output_dict

    def rev_forward(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(
            input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2,
                                     input_seq.device)
        slice_buf = InformationBuffer(batch_size,
                                      self.h_size // 2 - self.slice_dim,
                                      input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            buf_h1 = slice_buf if l == self.nlayers - 1 else main_buf
            buf_h2 = buf_c1 = buf_c2 = main_buf
            if self.rnn_type == 'revlstm':
                buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
            else:
                buffers.append((buf_h1, buf_h2))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum([
            32 * seq_length * batch_size * self.h_size
            for l in range(self.nlayers)
        ])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf,
                        hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    if l == self.nlayers - 1 and self.slice_dim > 0:
                        saved_hiddens.append(hidden[:, :self.slice_dim])

                    next_hidden, stats = rnn(
                        curr_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0, mask)

                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(
                            next_hidden['output_hidden'])

                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']
            if self.slice_dim > 0:
                saved_hiddens.append(hiddens[-1][:, :self.slice_dim])

        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = main_buf.bit_usage() + slice_buf.bit_usage() +\
            32*self.slice_dim*batch_size*max_length

        return hiddens, saved_hiddens, buffers, main_buf, slice_buf, output_dict

    def rev_forward_allh(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(
            input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2,
                                     input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            if l == self.nlayers - 1:
                buffers.append(None)
            else:
                buf_h1 = buf_h2 = buf_c1 = buf_c2 = main_buf
                if self.rnn_type == 'revlstm':
                    buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
                else:
                    buffers.append((buf_h1, buf_h2))

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf,
                        hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    if l == self.nlayers - 1:
                        saved_hiddens.append(hidden.detach().clone())

                    next_hidden, stats = rnn(curr_input, hidden, buf, 0, mask)
                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(
                            next_hidden['output_hidden'])
                    hiddens[l] = next_hidden['recurrent_hidden']

            saved_hiddens.append(hiddens[-1])

        total_bits = sum([
            32 * seq_length * batch_size * self.h_size
            for l in range(self.nlayers)
        ])
        output_dict = {
            'normal_bits': total_bits,
            'optimal_bits': total_bits,
            'used_bits': total_bits
        }
        if self.rnn_type == 'revlstm':
            for k in ['normal_bits', 'optimal_bits', 'used_bits']:
                output_dict[k] *= 2
        output_dict['last_h'] = hiddens

        return hiddens, saved_hiddens, buffers, main_buf, None, output_dict

    def init_hiddens(self, batch_size):
        weight = next(self.parameters())
        if self.rnn_type == 'revlstm':
            return [
                weight.new(batch_size, 2 * self.h_size).zero_().int()
                for l in range(self.nlayers)
            ]
        else:
            return [
                weight.new(batch_size, self.h_size).zero_().int()
                for l in range(self.nlayers)
            ]

    def set_masks(self, batch_size, device='cuda'):
        self.lockdropi.set_mask(batch_size, device=device)
        self.lockdroph.set_mask(batch_size, device=device)

    def reverse(self, input_seq, lengths, last_hiddens, last_hidden_grads,
                saved_hiddens, buffers):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size, 1)
            lengths (IntTensor): size is (batch_size,)
            last_hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            last_hidden_grads (list): list of FloatTensors (each with size (batch_size, h_size))
                of length nlayers
            saved_hiddens (list): list of IntTensors (each with size (batch_size, slice_dim))
                of length seq_length + 1
            buffers (list): list of InformationBuffers of length nlayers
        """
        hiddens = last_hiddens
        hidden_grads = last_hidden_grads

        for t in reversed(range(len(input_seq))):
            mask = None if lengths is None else (t < lengths).int()
            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]
                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l - 1]
                        drop_input = self.lockdroph(
                            ConvertToFloat.apply(curr_input[:, :self.h_size],
                                                 hidden_radix))
                    else:
                        curr_input = input_seq[t]
                        drop_input = self.lockdropi(self.encoder(curr_input))

                    prev_hidden = rnn.reverse(
                        drop_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0,
                        saved_hiddens[t] if l == self.nlayers - 1
                        and self.slice_dim > 0 else None, mask)

                # Rerun the forward pass from the previous hidden to the hidden at time t
                # to construct the computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdroph(
                        ConvertToFloat.apply(curr_input[:, :self.h_size],
                                             hidden_radix))
                else:
                    drop_input = self.lockdropi(self.encoder(curr_input))

                curr_hidden, _ = rnn(drop_input, prev_hidden, mask=mask)
                curr_hidden_grad = hidden_grads[l]
                if l == self.nlayers - 1 and self.slice_dim > 0:
                    curr_hidden_grad[:, :self.slice_dim] += saved_hiddens[
                        t + 1].grad
                torch.autograd.backward(curr_hidden['recurrent_hidden'],
                                        grad_tensors=curr_hidden_grad)
                hiddens[l] = prev_hidden.detach()
                hidden_grads[l] = prev_hidden.grad.data
                if l != 0:
                    hidden_grads[l - 1] += curr_input.grad.data

    def reverse_allh(self, input_seq, lengths, last_hiddens, last_hidden_grads,
                     saved_hiddens, buffers):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size, 1)
            lengths (IntTensor): size is (batch_size,)
            last_hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            last_hidden_grads (list): list of FloatTensors (each with size (batch_size, h_size))
                of length nlayers
            saved_hiddens (list): list of IntTensors (each with size (batch_size, slice_dim))
                of length seq_length + 1
            buffers (list): list of InformationBuffers of length nlayers
        """
        hiddens = last_hiddens
        hidden_grads = last_hidden_grads

        for t in reversed(range(len(input_seq))):
            mask = None if lengths is None else (t < lengths).int()
            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]
                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l - 1]
                        drop_input = self.lockdroph(
                            ConvertToFloat.apply(curr_input[:, :self.h_size],
                                                 hidden_radix))
                    else:
                        curr_input = input_seq[t]
                        drop_input = self.lockdropi(self.encoder(curr_input))

                    prev_hidden = rnn.reverse(
                        drop_input, hidden, buf, 0,
                        saved_hiddens[t] if l == self.nlayers - 1 else None,
                        mask)

                # Rerun the forward pass from the previous hidden to the hidden at time t
                # to construct the computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdroph(
                        ConvertToFloat.apply(curr_input[:, :self.h_size],
                                             hidden_radix))
                else:
                    drop_input = self.lockdropi(self.encoder(curr_input))

                curr_hidden, _ = rnn(drop_input, prev_hidden, mask=mask)
                curr_hidden_grad = hidden_grads[l]
                if l == self.nlayers - 1:
                    curr_hidden_grad += saved_hiddens[t + 1].grad
                torch.autograd.backward(curr_hidden['recurrent_hidden'],
                                        grad_tensors=curr_hidden_grad)
                hiddens[l] = prev_hidden.detach()
                hidden_grads[l] = prev_hidden.grad.data
                if l != 0:
                    hidden_grads[l - 1] += curr_input.grad.data
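The lockdropi/lockdroph/lockdropo modules are "locked" (variational) dropout: set_masks is called once per batch so that a single dropout mask is drawn and then reused at every time step, which is what keeps the forward pass and the reverse recomputation consistent. A self-contained sketch of that behaviour (an assumption about RevLockedDropout, not its actual code) is:

import torch
import torch.nn as nn

class LockedDropout(nn.Module):
    # Sketch of variational ("locked") dropout: one mask per batch, fixed across
    # all time steps, refreshed explicitly by set_mask().
    def __init__(self, p, size):
        super().__init__()
        self.p, self.size, self.mask = p, size, None

    def set_mask(self, batch_size, device='cpu'):
        keep = 1.0 - self.p
        bern = torch.bernoulli(torch.full((batch_size, self.size), keep, device=device))
        self.mask = bern / keep          # inverted-dropout scaling

    def forward(self, x):                # x: (batch_size, size), one time step
        if not self.training or self.p == 0:
            return x
        return x * self.mask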
Example #6
class RevDecoder(nn.Module):

    def __init__(self, rnn_type, h_size, nlayers, embedding, attn_type='general',
        context_size=None, dropouti=0, dropouth=0, dropouto=0, wdrop=0, dropouts=0,
        max_forget=0.875, context_type='slice', slice_dim=0, use_buffers=True):

        super(RevDecoder, self).__init__()
        if rnn_type == 'revgru':
            rnn = RevGRU
            module_names = ['ih2_to_zr1', 'irh2_to_g1', 'ih1_to_zr2', 'irh1_to_g2']
        elif rnn_type == 'revlstm':
            rnn = RevLSTM
            module_names = ['ih2_to_zgfop1', 'ih1_to_zgfop2']

        self.rnns = [rnn(h_size, h_size, max_forget) for l in range(nlayers)]
        self.rnns = [RevWeightDrop(rnn, module_names, wdrop) for rnn in self.rnns]
        self.rnns = nn.ModuleList(self.rnns)

        self.lockdropi = RevLockedDropout(dropouti, h_size)
        self.lockdroph = RevLockedDropout(dropouth, h_size)
        self.lockdropo = RevLockedDropout(dropouto, h_size)
        self.lockdrops = RevLockedDropout(dropouts, slice_dim)

        # Basic attributes.
        self.rnn_type = rnn_type
        self.decoder_type = 'rnn'
        self.context_type = context_type
        self.context_size = context_size
        self.nlayers = nlayers
        self.h_size = h_size
        self.embedding = embedding
        self.wdrop = wdrop
        self.use_buffers = use_buffers
        self.generator = nn.Linear(h_size, embedding.num_embeddings)
        self.slice_dim = slice_dim

        # Set up the standard attention.
        self.attn_type = attn_type
        if attn_type != 'none':
            self.attn = onmt.modules.MultiSizeAttention(h_size, context_size=context_size,
                attn_type=attn_type)

    def forward_test(self, input_seq, hiddens, context, enc_lengths, dec_lengths=None):
        # Set up.
        seq_length, batch_size = input_seq.size()
        self.set_masks(batch_size, input_seq.device)

        # Find last hidden states of model.
        top_hiddens = []
        input_seq = self.lockdropi(self.embedding(input_seq))
        with torch.set_grad_enabled(self.training):
            for t in range(len(input_seq)):
                mask = None if dec_lengths is None else (t < dec_lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, hidden) in enumerate(zip(self.rnns, hiddens)):
                    next_hidden, stats = rnn(curr_input, hidden, buf=None, slice_dim=0, mask=mask)
                    if l != self.nlayers-1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    else:
                        top_hiddens.append(next_hidden['output_hidden'])

                    hiddens[l] = next_hidden['recurrent_hidden']

        top_hiddens = self.lockdropo(torch.stack(top_hiddens))
        attn_hiddens, attn_scores = self.attn(top_hiddens.transpose(0, 1).contiguous(),
            context.transpose(0, 1), context_lengths=enc_lengths)

        attns = {"std": attn_scores}
        output = self.generator(attn_hiddens.view(-1, attn_hiddens.size(2)))
        return output, hiddens, attns


    def forward(self, input_seq, target_seq, hiddens, saved_hiddens, main_buf,
                token_weights, padding_idx, dec_lengths, enc_lengths, enc_input_seq, rev_enc):
        # Set up.
        seq_length, batch_size = input_seq.size()
        self.set_masks(batch_size, input_seq.device)

        total_loss = 0.
        num_words = 0.0
        num_correct = 0.0

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        # This means using the buffer 'main_buf' from the encoder.
        buffers = []
        if main_buf is None and self.use_buffers:
            main_buf = InformationBuffer(batch_size, self.h_size//2, input_seq.device)
        for l in range(len(hiddens)):
            if self.rnn_type == 'revlstm':
                buffers.append((main_buf, main_buf, main_buf, main_buf)) # Note: main_buf can be None
            else:
                buffers.append((main_buf, main_buf))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum([32*seq_length*batch_size*self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        top_hiddens = []
        input_seq = self.lockdropi(self.embedding(input_seq))
        with torch.set_grad_enabled(self.training):
            for t in range(len(input_seq)):
                mask = None if dec_lengths is None else (t < dec_lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    next_hidden, stats = rnn(curr_input, hidden, buf, slice_dim=0, mask=mask)
                    if l != self.nlayers-1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    else:
                        top_hidden_ = next_hidden['output_hidden']

                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

                top_hidden_ = self.lockdropo(top_hidden_).unsqueeze(0)

                context = self.construct_context(saved_hiddens, enc_input_seq, rev_enc)
                attn_hidden, _ = self.attn(top_hidden_.transpose(0, 1).contiguous(),
                    context.transpose(0, 1), context_lengths=enc_lengths)
                output = self.generator(attn_hidden[0])

                loss = F.cross_entropy(output, target_seq[t], weight=token_weights, size_average=False)
                total_loss += loss

                non_padding = target_seq[t].ne(padding_idx)
                pred = F.log_softmax(output, dim=1).max(1)[1]
                num_words += non_padding.sum().item()
                num_correct += pred.eq(target_seq[t]).masked_select(non_padding).sum().item()

        attns = {}
        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = float('inf')
        return total_loss, num_words, num_correct, attns, main_buf, output_dict


    def rev_forward(self, input_seq, hiddens, main_buf, dec_lengths=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            main_buf (InformationBuffer): storage for hidden states with size
                (batch_size, h_size)
            dec_lengths (IntTensor): size is (batch_size,)
        """
        # Set up.
        seq_length, batch_size = input_seq.size()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        # This means using the buffer 'main_buf' from the encoder.
        buffers = []
        if main_buf is None:
            main_buf = InformationBuffer(batch_size, self.h_size//2, input_seq.device)
        for l in range(len(hiddens)):
            if self.rnn_type == 'revlstm':
                buffers.append((main_buf, main_buf, main_buf, main_buf))
            else:
                buffers.append((main_buf, main_buf))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum(
            [32*seq_length*batch_size*self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.embedding(input_seq))
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if dec_lengths is None else (t < dec_lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    next_hidden, stats = rnn(curr_input, hidden, buf, slice_dim=0,
                        mask=mask)
                    if l != self.nlayers-1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])

                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = float('inf') # TODO(mmackay): figure out right way to compute bits
        return hiddens, buffers, main_buf, output_dict


    def reverse(self, input_seq, target_seq, last_hiddens, saved_hiddens, buffers,
                token_weights, padding_idx, dec_lengths, enc_lengths, enc_input_seq, rev_enc):
        """
        Arguments:
            input_seq (LongTensor): size is (dec_seq_length, batch_size)
            target_seq (IntTensor): size is (dec_seq_length, batch_size)
            last_hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            saved_hiddens (list): list of IntTensors (each with size (batch_size, slice_dim))
                of length enc_seq_length
            buffers (list): list of InformationBuffers of length nlayers
            token_weights (FloatTensor): of size (dec_ntokens,)
            dec_lengths (IntTensor): size is (batch_size,)
            enc_lengths (IntTensor): size is (batch_size,)
            enc_input_seq (LongTensor): size is (enc_seq_length, batch_size)
            rev_enc (RevEncoder): RevEncoder module used before the RevDecoder
        """

        batch_size = float(enc_lengths.shape[0])

        hiddens = last_hiddens
        hidden_grads = [next(self.parameters()).new_zeros(h.size()) for h in hiddens]
        loss_fun = lambda output, target: F.cross_entropy(output, target, weight=token_weights, size_average=False)

        saved_hiddens = [hidden.requires_grad_() for hidden in saved_hiddens]
        total_loss = 0.
        num_words = 0.0
        num_correct = 0.0

        for t in reversed(range(len(input_seq))):
            top_hidden = hiddens[-1].requires_grad_()
            top_hidden_ = ConvertToFloat.apply(top_hidden[:,:self.h_size], hidden_radix)
            top_hidden_ = self.lockdropo(top_hidden_).unsqueeze(0)

            context = self.construct_context(saved_hiddens, enc_input_seq, rev_enc)
            attn_hidden, _ = self.attn(top_hidden_.transpose(0, 1).contiguous(),
                context.transpose(0, 1), context_lengths=enc_lengths)
            output = self.generator(attn_hidden[0])

            last_loss = loss_fun(output, target_seq[t])
            last_loss.div(batch_size).backward()

            hidden_grads[-1] += top_hidden.grad
            total_loss += last_loss

            non_padding = target_seq[t].ne(padding_idx)
            pred = F.log_softmax(output, dim=1).max(1)[1]
            num_words += non_padding.sum().item()
            num_correct += pred.eq(target_seq[t]).masked_select(non_padding).sum().item()

            mask = None if dec_lengths is None else (t < dec_lengths).int()
            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]
                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l-1]
                        drop_input = self.lockdroph(ConvertToFloat.apply(
                            curr_input[:, :self.h_size], hidden_radix))
                    else:
                        curr_input = input_seq[t].squeeze()
                        drop_input = self.lockdropi(self.embedding(curr_input))

                    prev_hidden = rnn.reverse(drop_input, hidden, buf, slice_dim=0,
                        saved_hidden=None, mask=mask)

                # Rerun the forward pass from the previous hidden to the hidden at time t
                # to construct the computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdroph(ConvertToFloat.apply(
                        curr_input[:, :self.h_size], hidden_radix))
                else:
                    drop_input = self.lockdropi(self.embedding(curr_input))

                curr_hidden, _ = rnn(drop_input, prev_hidden, mask=mask)
                torch.autograd.backward(
                    curr_hidden['recurrent_hidden'], grad_tensors=hidden_grads[l])
                hiddens[l] = prev_hidden.detach()
                hidden_grads[l] = prev_hidden.grad.data
                if l != 0:
                    hidden_grads[l-1] += curr_input.grad.data

        return total_loss, num_words, num_correct, hiddens, hidden_grads, saved_hiddens

    def construct_context(self, saved_hiddens, enc_input_seq, rev_enc):
        if self.context_type == 'emb':
            context = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq))
        elif self.context_type == 'slice':
            context = ConvertToFloat.apply(torch.stack(saved_hiddens[1:]), hidden_radix)
            context = context[:,:,:self.slice_dim]
            context = self.lockdrops(context)
        elif self.context_type == 'slice_emb':
            embs = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq))
            slices = ConvertToFloat.apply(torch.stack(saved_hiddens[1:]), hidden_radix)
            slices = slices[:,:,:self.slice_dim]
            slices = self.lockdrops(slices)
            context = torch.cat([embs, slices], dim=2)
        return context

    def set_masks(self, batch_size, device='cuda'):
        self.lockdropi.set_mask(batch_size, device=device)
        self.lockdroph.set_mask(batch_size, device=device)
        self.lockdropo.set_mask(batch_size, device=device)
        self.lockdrops.set_mask(batch_size, device=device)
        for rnn in self.rnns:
            rnn.set_mask()

    @property
    def _input_size(self):
        """
        Private helper returning the number of expected features.
        """
        return self.embedding.embedding_dim
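The per-step loss bookkeeping in forward/reverse (token_weights, padding_idx, num_words, num_correct) is plain masked cross-entropy plus token-level accuracy. A self-contained sketch on a toy batch, mirroring those lines (reduction='sum' is the modern spelling of size_average=False used above), is:

import torch
import torch.nn.functional as F

vocab, batch, padding_idx = 10, 4, 0
output = torch.randn(batch, vocab)        # generator logits for one decoder step
target = torch.tensor([3, 0, 7, 2])       # token 0 is padding
token_weights = torch.ones(vocab)
token_weights[padding_idx] = 0            # padding tokens contribute no loss

loss = F.cross_entropy(output, target, weight=token_weights, reduction='sum')

non_padding = target.ne(padding_idx)
pred = F.log_softmax(output, dim=1).max(1)[1]
num_words = non_padding.sum().item()
num_correct = pred.eq(target).masked_select(non_padding).sum().item()
print(loss.item(), num_words, num_correct)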