    def test_grads():
        hiddens = rev_enc.init_hiddens(batch_size)
        enc_hiddens, saved_hiddens, buffers, main_buf, _ = rev_enc(enc_input_seq, enc_lengths, hiddens)
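        # Overwrite the encoder's returned buffer with a fresh one, presumably
        # so the decoder's storage in this test is independent of the encoder's.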
        main_buf = InformationBuffer(batch_size, h_size//2, 'cpu')
        
        hiddens, buffers, output_dict = rev_dec(dec_input_seq, enc_hiddens, main_buf, dec_lengths)
        hiddens, hidden_grads, saved_hiddens, buffers, rev_loss, rev_dict = rev_dec.reverse(dec_input_seq,
            target_seq, hiddens, saved_hiddens, buffers, token_weights, dec_lengths,
            enc_lengths, enc_input_seq, rev_enc)
        rev_grads = create_grad_dict(rev_dec)
        if context_type in ['slice', 'slice_emb']:
            saved_hidden_rev_grads = [hid.grad.data.clone().numpy() for hid in saved_hiddens[1:]]

        rev_dec.zero_grad()
        hiddens = rev_enc.init_hiddens(batch_size)
        enc_hiddens, saved_hiddens, buffers, main_buf, _ = rev_enc(enc_input_seq, enc_lengths, hiddens)
        main_buf = InformationBuffer(batch_size, h_size//2, 'cpu')
        
        saved_hiddens = [hid.requires_grad_() for hid in saved_hiddens]
        top_hiddens, last_hiddens, normal_dict = rev_dec.test_forward(dec_input_seq, enc_hiddens, main_buf, dec_lengths)
        nor_loss, normal_dict = rev_dec.test_compute_loss(top_hiddens, target_seq, saved_hiddens,
            token_weights, enc_lengths, enc_input_seq, rev_enc)
        nor_loss.backward()
        nor_grads = create_grad_dict(rev_dec)
        if context_type in ['slice', 'slice_emb']:
            saved_hidden_nor_grads = [hid.grad.data.clone().numpy() for hid in saved_hiddens[1:]]

        compare_grads(rev_grads, nor_grads)

        if context_type in ['slice', 'slice_emb']:
            for rev_grad, nor_grad in zip(saved_hidden_rev_grads, saved_hidden_nor_grads):
                print(angle_between(rev_grad.flatten(), nor_grad.flatten()))
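The helpers used in test_grads (create_grad_dict, compare_grads, angle_between) don't appear in this listing. A minimal sketch of what they might look like, assuming gradients are compared per parameter by name; the names and tolerance here are illustrative, not the repo's actual code:

    import numpy as np

    def create_grad_dict(model):
        # Snapshot every parameter gradient as a NumPy array, keyed by name.
        return {name: p.grad.data.clone().numpy()
                for name, p in model.named_parameters() if p.grad is not None}

    def compare_grads(grads_a, grads_b, tol=1e-5):
        # Print the largest elementwise discrepancy for each parameter.
        for name in grads_a:
            diff = np.abs(grads_a[name] - grads_b[name]).max()
            print(name, diff, 'OK' if diff < tol else 'MISMATCH')

    def angle_between(v1, v2):
        # Angle in radians between two flattened gradient vectors; a value
        # near 0 means the reverse and normal gradients point the same way.
        cos = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
        return np.arccos(np.clip(cos, -1.0, 1.0))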
    def test_forward(self, input_seq, lengths=None, hiddens=None):
        """
        Used for testing correctness of gradients in reverse computation.
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up. We don't set masks here: `forward` is assumed to have been
        # called first, and the masks it set must stay fixed in this method so
        # that the gradients match.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(input_seq) if lengths is None else lengths.max().item()
        # self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2,
                                     input_seq.device)
        slice_buf = InformationBuffer(batch_size,
                                      self.h_size // 2 - self.slice_dim,
                                      input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            buf_h1 = slice_buf if l == self.nlayers - 1 else main_buf
            buf_h2 = buf_c1 = buf_c2 = main_buf
            if self.rnn_type == 'revlstm':
                buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
            else:
                buffers.append((buf_h1, buf_h2))

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        for t in range(len(input_seq)):
            mask = None if lengths is None else (t < lengths).int()
            curr_input = input_seq[t]

            for l, (rnn, buf,
                    hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                if l == self.nlayers - 1:
                    saved_hiddens.append(hidden[:, :self.slice_dim])

                next_hidden, stats = rnn(
                    curr_input, hidden, buf,
                    self.slice_dim if l == self.nlayers - 1 else 0, mask)

                if l != self.nlayers - 1:
                    curr_input = self.lockdroph(next_hidden['output_hidden'])

                hiddens[l] = next_hidden['recurrent_hidden']

        saved_hiddens.append(hiddens[-1][:, :self.slice_dim])

        return hiddens, saved_hiddens, buffers, {}
    def test_construction():
        hiddens = rev_enc.init_hiddens(batch_size)
        enc_hiddens, saved_hiddens, buffers, main_buf, _ = rev_enc(enc_input_seq, enc_lengths, hiddens)

        main_buf = InformationBuffer(batch_size, h_size//2, 'cpu')
        hiddens, buffers, output_dict = rev_dec(dec_input_seq, enc_hiddens, main_buf, dec_lengths)

        hiddens = rev_enc.init_hiddens(batch_size)
        enc_hiddens, saved_hiddens, buffers, main_buf, _ = rev_enc(enc_input_seq, enc_lengths, hiddens)
        main_buf = InformationBuffer(batch_size, h_size//2, 'cpu')
        top_hiddens, last_hiddens, normal_dict = rev_dec.test_forward(dec_input_seq, enc_hiddens, main_buf, dec_lengths)

        for l in range(nlayers):
            print("LAYER: " + str(l))
            for for_h, nor_h in zip(output_dict['hid_seq'][l], normal_dict['hid_seq'][l]):
                print((for_h == nor_h).all())
    def test_reconstruction():
        '''
        Roadmap:
            Confirmed: hiddens constructed in forward and hiddens reconstructed in reverse are the same
            Confirmed: hiddens constructed in forward and test_forward are the same
            
            Confirm hiddens being passed to the attention module are the same between reverse and forward
            Confirm attn_hiddens are the same between reverse and test_compute_loss
            Confirm the loss value is the same between reverse and test_compute_loss
            Confirm gradients are the same

            Try changing token weights and ensure gradients remain the same
        '''
        main_buf = InformationBuffer(batch_size, h_size//2, 'cpu')
        hiddens, buffers, output_dict = rev_dec(dec_input_seq, enc_hiddens, main_buf, dec_lengths)

        hiddens, hidden_grads, saved_hiddens, buffers, loss, rev_dict = rev_dec.reverse(dec_input_seq,
            target_seq, hiddens, saved_hiddens, buffers, token_weights, dec_lengths,
            enc_lengths, enc_input_seq, rev_enc)

        for l in range(nlayers):
            for for_h, rev_h in zip(output_dict['hid_seq'][l], reversed(rev_dict['hid_seq'][l])):
                print((for_h == rev_h).all())
                print(for_h)
                print(rev_h)
    def test_loss():
        hiddens = rev_enc.init_hiddens(batch_size)
        enc_hiddens, saved_hiddens, buffers, main_buf, _ = rev_enc(enc_input_seq, enc_lengths, hiddens)
        main_buf = InformationBuffer(batch_size, h_size//2, 'cpu')
        
        hiddens, buffers, output_dict = rev_dec(dec_input_seq, enc_hiddens, main_buf, dec_lengths)
        hiddens, hidden_grads, saved_hiddens, buffers, rev_loss, rev_dict = rev_dec.reverse(dec_input_seq,
            target_seq, hiddens, saved_hiddens, buffers, token_weights, dec_lengths,
            enc_lengths, enc_input_seq, rev_enc)

        hiddens = rev_enc.init_hiddens(batch_size)
        enc_hiddens, saved_hiddens, buffers, main_buf, _ = rev_enc(enc_input_seq, enc_lengths, hiddens)
        main_buf = InformationBuffer(batch_size, h_size//2, 'cpu')
        
        top_hiddens, last_hiddens, normal_dict = rev_dec.test_forward(dec_input_seq, enc_hiddens, main_buf, dec_lengths)
        nor_loss, normal_dict = rev_dec.test_compute_loss(top_hiddens, target_seq, saved_hiddens,
            token_weights, enc_lengths, enc_input_seq, rev_enc)

        print(rev_loss.data.numpy())
        print(nor_loss.data.numpy())
Example #6
    def rev_forward(self, input_seq, hiddens, main_buf, dec_lengths=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            main_buf (InformationBuffer): storage for hidden states with size
                (batch_size, h_size)
            dec_lengths (IntTensor): size is (batch_size,)
        """
        # Set up.
        seq_length, batch_size = input_seq.size()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        # This means using the buffer 'main_buf' from the encoder.
        buffers = []
        if main_buf is None:
            main_buf = InformationBuffer(batch_size, self.h_size//2, input_seq.device)
        for l in range(len(hiddens)):
            if self.rnn_type == 'revlstm':
                buffers.append((main_buf, main_buf, main_buf, main_buf))
            else:
                buffers.append((main_buf, main_buf))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum(
            [32*seq_length*batch_size*self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.embedding(input_seq))
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if dec_lengths is None else (t < dec_lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    next_hidden, stats = rnn(curr_input, hidden, buf, slice_dim=0,
                        mask=mask)
                    if l != self.nlayers-1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])

                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = float('inf') # TODO(mmackay): figure out right way to compute bits
        return hiddens, buffers, main_buf, output_dict
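rev_forward leans on InformationBuffer to bank the information each reversible update would otherwise destroy. The real buffer packs variable-width integers; the toy sketch below only illustrates the interface the code above relies on (shared storage plus bit_usage for the accounting), and its push/pop method names are assumptions:

    class ToyInformationBuffer:
        # Illustrative only: one int64 tensor per step, not the packed
        # variable-width encoding the actual InformationBuffer uses.
        def __init__(self, batch_size, buf_dim, device):
            self.batch_size, self.buf_dim, self.device = batch_size, buf_dim, device
            self.storage = []

        def push(self, lost_bits):
            # Forward pass: store what the reversible update throws away.
            self.storage.append(lost_bits.to(self.device))

        def pop(self):
            # Reverse pass: retrieve it to reconstruct the previous hidden.
            return self.storage.pop()

        def bit_usage(self):
            # Bits currently held; the real buffer counts its packed storage.
            return sum(64 * t.numel() for t in self.storage)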
    def forward(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2,
                                     input_seq.device)
        slice_buf = InformationBuffer(batch_size,
                                      self.h_size // 2 - self.slice_dim,
                                      input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            buf_h1 = slice_buf if l == self.nlayers - 1 else main_buf
            buf_h2 = buf_c1 = buf_c2 = main_buf
            if self.rnn_type == 'revlstm':
                buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
            else:
                buffers.append((buf_h1, buf_h2))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum([
            32 * seq_length * batch_size * self.h_size
            for l in range(self.nlayers)
        ])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf,
                        hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    if l == self.nlayers - 1:
                        saved_hiddens.append(hidden[:, :self.slice_dim])

                    next_hidden, stats = rnn(
                        curr_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0, mask)

                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(
                            next_hidden['output_hidden'])

                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

            saved_hiddens.append(hiddens[-1][:, :self.slice_dim])

        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = (main_buf.bit_usage() + slice_buf.bit_usage()
                                    + 32 * self.slice_dim * batch_size * max_length)

        return hiddens, saved_hiddens, buffers, main_buf, output_dict
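To make the used_bits bookkeeping concrete: the slices saved for attention are kept at full 32-bit precision at every step, which is the third term above. A quick back-of-the-envelope with illustrative sizes:

    # Illustrative numbers only.
    h_size, slice_dim, batch_size, max_length = 200, 30, 64, 50
    # Saving the top slice_dim units of the last layer at every step costs:
    slice_bits = 32 * slice_dim * batch_size * max_length   # 3,072,000 bits
    # For comparison, storing every hidden exactly would cost, per layer:
    normal_bits = 32 * max_length * batch_size * h_size     # 20,480,000 bits
    print(slice_bits, normal_bits, slice_bits / normal_bits)  # ratio 0.15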
Example #8
    def forward(self, input_seq, hiddens):
        """
        Arguments:
            input_seq (LongTensor): of shape (seq_length, batch_size)
            hiddens (list): list of Tensors of length nlayers

        """
        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        seq_length, batch_size = input_seq.size()
        buffers = [None for _ in range(self.nlayers)]
        if self.use_buffers:
            buffers = []
            unique_buffers = []
            for l in range(self.nlayers):
                if l in [0, self.nlayers - 1]:
                    buf_dim = self.h_size // 2 if l != self.nlayers - 1 else self.in_size // 2
                    buf = InformationBuffer(batch_size, buf_dim,
                                            input_seq.device)
                    buf_tup = ((buf, buf) if self.rnn_type == 'revgru'
                               else (buf, buf, buf, buf))
                    unique_buffers.append(buf)
                else:
                    buf_tup = buffers[l - 1]
                buffers.append(buf_tup)

        # Embed input sequence.
        input_seq = embedded_dropout(
            self.embed,
            input_seq,
            dropout=self.dropoute if self.training else 0)
        input_seq = self.lockdrop(input_seq, self.dropouti)

        # Process input sequence through model. Start with finding all hidden states
        # for current layer. Then use these hidden states as inputs to the next layer.
        output_dict = {"optimal_bits": 0}
        last_hiddens = []
        curr_seq = input_seq
        for rnn in self.rnns:
            rnn.set_weights()
        for l, (rnn, buf) in enumerate(zip(self.rnns, buffers)):
            curr_hiddens = []
            prev_hidden = hiddens[l]

            for t in range(len(curr_seq)):
                curr_hidden, stats = rnn(curr_seq[t], prev_hidden, buf)
                prev_hidden = curr_hidden['recurrent_hidden']
                curr_hiddens.append(curr_hidden['output_hidden'])
                output_dict['optimal_bits'] += stats['optimal_bits']

            last_hiddens.append(prev_hidden)
            curr_seq = torch.stack(curr_hiddens, dim=0)  # [length, batch, hidden]

            if l != self.nlayers - 1:
                curr_seq = self.lockdrop(curr_seq, self.dropouth)

        # Use the last layer hiddens as inputs to our classifier.
        curr_seq = self.lockdrop(curr_seq, self.dropout)
        decoded = self.out(
            curr_seq.view(curr_seq.size(0) * curr_seq.size(1), -1))
        output_dict['decoded'] = decoded.view(curr_seq.size(0),
                                              curr_seq.size(1), -1)
        output_dict['last_h'] = last_hiddens

        # Collect stats over entire sequence.
        if self.use_buffers:
            output_dict['used_bits'] = sum(
                [buf.bit_usage() for buf in unique_buffers])
        output_dict['normal_bits'] = sum([
            32 * seq_length * batch_size *
            (self.h_size if l != self.nlayers - 1 else self.in_size)
            for l in range(self.nlayers)
        ])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        return output_dict
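Because the returned dictionary carries both counts, the saving from buffered reversible storage can be read off directly. A hedged usage sketch; `model`, `input_seq`, and `hiddens` are assumed to be set up elsewhere:

    # Assumes `model` is this module and `input_seq`, `hiddens` already exist.
    # (output_dict['used_bits'] is only present when model.use_buffers is True.)
    output_dict = model(input_seq, hiddens)
    ratio = output_dict['used_bits'] / output_dict['normal_bits']
    print('buffer storage vs. storing all hiddens: {:.1%}'.format(ratio))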
Example #9
    # ConvertToFixed, RevGRU, and hidden_radix are assumed to be imported from
    # this repo; the snippet does not show those imports.
    batch_size = 16  # assumed: not defined in this snippet
    h_size = 200
    slice_dim = 30
    in_size = 100
    seq_length = 100

    initial_hidden = ConvertToFixed.apply(torch.randn(batch_size, h_size),
                                          hidden_radix)
    input_seq = torch.randn(seq_length, batch_size, in_size)
    masks = [
        torch.randn(batch_size).bernoulli_(0.5).int()
        for _ in range(seq_length)
    ]

    from buffer import InformationBuffer
    buf_h1 = InformationBuffer(batch_size=batch_size,
                               buf_dim=h_size // 2 - slice_dim,
                               device='cpu')
    buf_h2 = InformationBuffer(batch_size=batch_size,
                               buf_dim=h_size // 2,
                               device='cpu')
    buf = buf_h1, buf_h2
    rnn = RevGRU(in_size, h_size, max_forget=0.96875)

    hidden = initial_hidden
    saved_hiddens = []
    for t in range(seq_length):
        if slice_dim > 0:
            saved_hiddens.append(hidden[:, :slice_dim])
        hidden_dict, _ = rnn(input_seq[t], hidden, buf, slice_dim, masks[t])
        hidden = hidden_dict['recurrent_hidden']
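The payoff of the buffers is that this loop can be run backwards. A sketch of the matching reverse pass; the exact signature of `reverse`, in particular how the saved slices and masks are passed back in, is an assumption here, mirroring how reverse is called elsewhere in this repo:

    # Assumed signature: reverse(input, hidden, buf, slice_dim, saved_hidden, mask).
    for t in reversed(range(seq_length)):
        hidden = rnn.reverse(input_seq[t], hidden, buf, slice_dim,
                             saved_hiddens[t], masks[t])
    # If the buffers round-trip exactly, we are back at the starting state.
    print((hidden == initial_hidden).all())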
Example #10
    def forward(self, input_seq, hiddens):
        """
        Arguments:
            input_seq (LongTensor): of shape (seq_length, batch_size)
            hiddens (list): list of Tensors of length nlayers

        """
        self.set_masks(input_seq.size(1), input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        seq_length, batch_size = input_seq.size()
        buffers = [None for _ in range(self.nlayers)]
        if self.use_buffers:
            buffers = []
            for l in range(self.nlayers):
                if l in [0, self.nlayers - 1]:
                    buf_dim = self.h_size // 2 if l != self.nlayers - 1 else self.in_size // 2
                    buf = InformationBuffer(batch_size, buf_dim,
                                            input_seq.device)
                    buf_tup = ((buf, buf) if self.rnn_type == 'revgru'
                               else (buf, buf, buf, buf))
                else:
                    buf_tup = buffers[l - 1]
                buffers.append(buf_tup)

        # Embed input sequence.
        input_seq = self.lockdropi(self.embed_drop(input_seq))

        # Process input sequence through model. Start with finding all hidden states
        # for current layer. Then use these hidden states as inputs to the next layer.
        output_dict = {"optimal_bits": 0}
        last_hiddens = []
        curr_seq = input_seq
        for l, (rnn, buf) in enumerate(zip(self.rnns, buffers)):
            curr_hiddens = []
            prev_hidden = hiddens[l]

            for t in range(len(curr_seq)):
                curr_hidden, stats = rnn(curr_seq[t], prev_hidden, buf)
                prev_hidden = curr_hidden['recurrent_hidden']
                curr_hiddens.append(curr_hidden['output_hidden'])
                output_dict['optimal_bits'] += stats['optimal_bits']

            last_hiddens.append(prev_hidden)
            curr_seq = torch.stack(curr_hiddens, dim=0)  # [length, batch, hidden]

            if l != self.nlayers - 1:
                curr_seq = self.lockdrophs[l](curr_seq)

        curr_seq = self.lockdrop(curr_seq)
        decoded = self.out(
            curr_seq.view(curr_seq.size(0) * curr_seq.size(1), -1))
        output_dict['decoded'] = decoded.view(curr_seq.size(0),
                                              curr_seq.size(1), -1)
        output_dict['last_h'] = last_hiddens

        return output_dict
Example #11
    def forward_and_backward(self, input_seq, target_seq, hiddens):
        """
        Arguments:
            input_seq (LongTensor): of shape (seq_length, batch_size)
            target_seq (LongTensor): of shape (seq_length, batch_size)
            hiddens (tuple): tuple of Tensors of length nlayers
        """
        hiddens = list(hiddens)
        self.set_masks(
            input_seq.size(1),
            input_seq.device)  # COMMENT OUT IF TESTING USING FORWARD

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        seq_length, batch_size = input_seq.size()
        buffers = []
        for l in range(self.nlayers):
            if l in [0, self.nlayers - 1]:
                buf_dim = self.h_size // 2 if l != self.nlayers - 1 else self.in_size // 2
                buf = InformationBuffer(batch_size, buf_dim, input_seq.device)
                buf_tup = ((buf, buf) if self.rnn_type == 'revgru'
                           else (buf, buf, buf, buf))
            else:
                buf_tup = buffers[l - 1]
            buffers.append(buf_tup)

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum([
            32 * seq_length * batch_size *
            (self.h_size if l != self.nlayers - 1 else self.in_size)
            for l in range(self.nlayers)
        ])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        # TODO: figure out way to have wdrop not mask at each step if this takes significant time
        with torch.no_grad():
            for t in range(len(input_seq)):
                curr_input = self.lockdropi(self.embed_drop(input_seq[t]))
                for l, (rnn, buf, lockdroph, hidden) in enumerate(
                        zip(self.rnns, buffers, self.lockdrophs, hiddens)):
                    next_hidden, stats = rnn(curr_input, hidden, buf)
                    if l != self.nlayers - 1:
                        curr_input = lockdroph(next_hidden['output_hidden'])
                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = 0
        consumed_bufs = []
        for buf in buffers:
            for group_buf in buf:
                if group_buf not in consumed_bufs:
                    output_dict['used_bits'] += group_buf.bit_usage()
                    consumed_bufs.append(group_buf)

        scaled_ce = lambda output, target: (1. / seq_length) * F.cross_entropy(
            output, target)

        # Loop back through time, reversing computations with help of buffers and using
        # autodiff to compute gradients.
        total_loss = 0
        grad_hiddens = [
            next(self.parameters()).new_zeros(h.size()) for h in hiddens
        ]
        for t in reversed(range(seq_length)):
            top_hidden = hiddens[-1].requires_grad_()
            top_hidden_ = ConvertToFloat.apply(top_hidden[:, :self.in_size],
                                               hidden_radix)
            top_hidden_ = self.lockdrop(top_hidden_)

            output = self.out(top_hidden_)
            last_loss = scaled_ce(output, target_seq[t])
            last_loss.backward()
            grad_hiddens[-1] += top_hidden.grad

            total_loss += last_loss

            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]

                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l - 1]
                        drop_input = self.lockdrophs[l - 1](
                            ConvertToFloat.apply(curr_input[:, :self.h_size],
                                                 hidden_radix))
                    else:
                        curr_input = input_seq[t]
                        drop_input = self.lockdropi(
                            self.embed_drop(curr_input))
                    prev_hidden = rnn.reverse(drop_input, hidden, buf)

                # Rerun forwards pass from previous hidden to hidden at time t to construct
                # computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdrophs[l - 1](ConvertToFloat.apply(
                        curr_input[:, :self.h_size], hidden_radix))
                else:
                    drop_input = self.lockdropi(self.embed_drop(curr_input))
                curr_hidden, _ = rnn(drop_input, prev_hidden)
                torch.autograd.backward(curr_hidden['recurrent_hidden'],
                                        grad_tensors=grad_hiddens[l])
                hiddens[l] = prev_hidden.detach()
                grad_hiddens[l] = prev_hidden.grad.data

                if l != 0:
                    grad_hiddens[l - 1] += curr_input.grad.data

        output_dict['loss'] = total_loss
        return output_dict
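The pattern above (reconstruct the previous hidden without a graph, rerun one step under autodiff, and push the accumulated hidden gradient back through it) can be sanity-checked in isolation. A self-contained toy with an exactly invertible affine step h' = w*h + x, checked against ordinary backprop; all names here are local to the sketch:

    import torch

    torch.manual_seed(0)
    w = torch.tensor(0.5, requires_grad=True)     # stand-in "parameter"
    xs = [torch.randn(4) for _ in range(5)]

    # Forward without building a graph; keep only the final hidden.
    h = torch.zeros(4)
    with torch.no_grad():
        for x in xs:
            h = w * h + x

    # Reverse loop: per-step loss, exact reconstruction, one-step re-forward.
    grad_h, total_loss = torch.zeros(4), 0.0
    for x in reversed(xs):
        h.requires_grad_()
        step_loss = h.pow(2).sum()                # loss that reads this hidden
        step_loss.backward()
        grad_h = grad_h + h.grad                  # add to future-step gradient
        with torch.no_grad():
            prev = (h - x) / w                    # invert the step exactly
        prev.requires_grad_()
        curr = w * prev + x                       # rebuild one step's graph
        torch.autograd.backward(curr, grad_tensors=grad_h)
        grad_h = prev.grad.detach().clone()
        h = prev.detach()
        total_loss += step_loss.item()

    # Reference: ordinary backprop through the stored sequence.
    w_ref = torch.tensor(0.5, requires_grad=True)
    h_ref, loss_ref = torch.zeros(4), 0.0
    for x in xs:
        h_ref = w_ref * h_ref + x
        loss_ref = loss_ref + h_ref.pow(2).sum()
    loss_ref.backward()
    print(torch.allclose(w.grad, w_ref.grad, atol=1e-5))  # True up to float error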
Example #12
    def forward(self, input_seq, target_seq, hiddens, saved_hiddens, main_buf,
                token_weights, padding_idx, dec_lengths, enc_lengths, enc_input_seq, rev_enc):
        # Set up.
        seq_length, batch_size = input_seq.size()
        self.set_masks(batch_size, input_seq.device)

        total_loss = 0.0
        num_words = 0.0
        num_correct = 0.0

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        # This means using the buffer 'main_buf' from the encoder.
        buffers = []
        if main_buf is None and self.use_buffers:
            main_buf = InformationBuffer(batch_size, self.h_size//2, input_seq.device)
        for l in range(len(hiddens)):
            if self.rnn_type == 'revlstm':
                buffers.append((main_buf, main_buf, main_buf, main_buf)) # Note: main_buf can be None
            else:
                buffers.append((main_buf, main_buf))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum([32*seq_length*batch_size*self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        top_hiddens = []
        input_seq = self.lockdropi(self.embedding(input_seq))
        with torch.set_grad_enabled(self.training):
            for t in range(len(input_seq)):
                mask = None if dec_lengths is None else (t < dec_lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    next_hidden, stats = rnn(curr_input, hidden, buf, slice_dim=0, mask=mask)
                    if l != self.nlayers-1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    else:
                        top_hidden_ = next_hidden['output_hidden']

                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

                top_hidden_ = self.lockdropo(top_hidden_).unsqueeze(0)

                context = self.construct_context(saved_hiddens, enc_input_seq, rev_enc)
                attn_hidden, _ = self.attn(top_hidden_.transpose(0, 1).contiguous(),
                    context.transpose(0, 1), context_lengths=enc_lengths)
                output = self.generator(attn_hidden[0])

                loss = F.cross_entropy(output, target_seq[t], weight=token_weights, size_average=False)
                total_loss += loss

                non_padding = target_seq[t].ne(padding_idx)
                pred = F.log_softmax(output, dim=1).max(1)[1]
                num_words += non_padding.sum().item()
                num_correct += pred.eq(target_seq[t]).masked_select(non_padding).sum().item()

        attns = {}
        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = float('inf')
        return total_loss, num_words, num_correct, attns, main_buf, output_dict
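The returned counts give the usual report metrics directly. A usage sketch; `dec` stands for an instance of this module, and the argument tensors are assumed to be in scope:

    import math

    total_loss, num_words, num_correct, attns, main_buf, out_dict = dec(
        dec_input_seq, target_seq, hiddens, saved_hiddens, main_buf,
        token_weights, padding_idx, dec_lengths, enc_lengths,
        enc_input_seq, rev_enc)
    accuracy = 100.0 * num_correct / num_words
    ppl = math.exp(min(total_loss.item() / num_words, 100))  # clamp overflow
    print('acc {:.2f}  ppl {:.2f}'.format(accuracy, ppl))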