class RevEncoder(nn.Module):
    def __init__(self, rnn_type, h_size, nlayers, embedding, slice_dim=100,
                 max_forget=0.875, use_buffers=True, dropouti=0, dropouth=0, wdrop=0):
        super(RevEncoder, self).__init__()
        if rnn_type == 'revgru':
            rnn = RevGRU
            module_names = ['ih2_to_zr1', 'irh2_to_g1', 'ih1_to_zr2', 'irh1_to_g2']
        elif rnn_type == 'revlstm':
            rnn = RevLSTM
            module_names = ['ih2_to_zgfop1', 'ih1_to_zgfop2']

        self.rnns = [rnn(h_size, h_size, max_forget) for l in range(nlayers)]
        self.rnns = [RevWeightDrop(rnn, module_names, wdrop) for rnn in self.rnns]
        self.rnns = nn.ModuleList(self.rnns)

        self.lockdropi = RevLockedDropout(dropouti, h_size)
        self.lockdroph = RevLockedDropout(dropouth, h_size)

        self.rnn_type = rnn_type
        self.h_size = h_size
        self.nlayers = nlayers
        self.encoder = embedding
        self.slice_dim = slice_dim
        self.use_buffers = use_buffers
        self.dropouti = dropouti
        self.dropouth = dropouth
        self.wdrop = wdrop

    def forward(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2, input_seq.device)
        slice_buf = InformationBuffer(batch_size, self.h_size // 2 - self.slice_dim,
            input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            buf_h1 = slice_buf if l == self.nlayers - 1 else main_buf
            buf_h2 = buf_c1 = buf_c2 = main_buf
            if self.rnn_type == 'revlstm':
                buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
            else:
                buffers.append((buf_h1, buf_h2))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum(
            [32 * seq_length * batch_size * self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    if l == self.nlayers - 1:
                        saved_hiddens.append(hidden[:, :self.slice_dim])
                    next_hidden, stats = rnn(curr_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0, mask)
                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

        saved_hiddens.append(hiddens[-1][:, :self.slice_dim])
        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = main_buf.bit_usage() + slice_buf.bit_usage() + \
            32 * self.slice_dim * batch_size * max_length

        return hiddens, saved_hiddens, buffers, main_buf, output_dict

    def init_hiddens(self, batch_size):
        weight = next(self.parameters())
        if self.rnn_type == 'revlstm':
            return [weight.new(torch.zeros(batch_size, 2 * self.h_size)).zero_().int()
                for l in range(self.nlayers)]
        else:
            return [weight.new(torch.zeros(batch_size, self.h_size)).zero_().int()
                for l in range(self.nlayers)]

    def set_masks(self, batch_size, device='cuda'):
        self.lockdropi.set_mask(batch_size, device=device)
        self.lockdroph.set_mask(batch_size, device=device)
        for rnn in self.rnns:
            rnn.set_mask()

    def reverse(self, input_seq, lengths, last_hiddens, last_hidden_grads, saved_hiddens,
                saved_hidden_grads, buffers):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size, 1)
            lengths (IntTensor): size is (batch_size,)
            last_hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            last_hidden_grads (list): list of FloatTensors (each with size (batch_size, h_size))
                of length nlayers
            saved_hiddens (list): list of IntTensors (each with size (batch_size, slice_dim))
                of length seq_length + 1
            saved_hidden_grads (list): list of FloatTensors (each with size (batch_size, slice_dim))
                of length seq_length + 1
            buffers (list): list of InformationBuffers of length nlayers
        """
        hiddens = last_hiddens
        hidden_grads = last_hidden_grads
        # TODO(mmackay): replace saved_hidden_grads with just using the .grad attribute
        # of the hiddens in saved_hiddens
        for t in reversed(range(len(input_seq))):
            mask = None if lengths is None else (t < lengths).int()
            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]

                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l - 1]
                        drop_input = self.lockdroph(
                            ConvertToFloat.apply(curr_input[:, :self.h_size], hidden_radix))
                    else:
                        curr_input = input_seq[t]
                        drop_input = self.lockdropi(self.encoder(curr_input))
                    prev_hidden = rnn.reverse(drop_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0,
                        saved_hiddens[t] if l == self.nlayers - 1 else None, mask)

                # Rerun forward pass from previous hidden to hidden at time t to construct
                # computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdroph(
                        ConvertToFloat.apply(curr_input[:, :self.h_size], hidden_radix))
                else:
                    drop_input = self.lockdropi(self.encoder(curr_input))
                curr_hidden, _ = rnn(drop_input, prev_hidden, mask=mask)

                curr_hidden_grad = hidden_grads[l]
                if l == self.nlayers - 1:
                    curr_hidden_grad[:, :self.slice_dim] += saved_hidden_grads[t + 1]
                torch.autograd.backward(curr_hidden['recurrent_hidden'],
                    grad_tensors=curr_hidden_grad)

                hiddens[l] = prev_hidden.detach()
                hidden_grads[l] = prev_hidden.grad.data
                if l != 0:
                    hidden_grads[l - 1] += curr_input.grad.data

    def test_forward(self, input_seq, lengths=None, hiddens=None):
        """
        Used for testing correctness of gradients in the reverse computation.

        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size, 1)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up. We don't set masks here: it is assumed that forward is called before
        # this method, which sets the masks, and they must remain the same in this
        # method to ensure gradients are equal.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(input_seq) if lengths is None else lengths.max().item()
        # self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2, input_seq.device)
        slice_buf = InformationBuffer(batch_size, self.h_size // 2 - self.slice_dim,
            input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            buf_h1 = slice_buf if l == self.nlayers - 1 else main_buf
            buf_h2 = buf_c1 = buf_c2 = main_buf
            if self.rnn_type == 'revlstm':
                buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
            else:
                buffers.append((buf_h1, buf_h2))

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        for t in range(len(input_seq)):
            mask = None if lengths is None else (t < lengths).int()
            curr_input = input_seq[t]

            for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                if l == self.nlayers - 1:
                    saved_hiddens.append(hidden[:, :self.slice_dim])
                next_hidden, stats = rnn(curr_input, hidden, buf,
                    self.slice_dim if l == self.nlayers - 1 else 0, mask)
                if l != self.nlayers - 1:
                    curr_input = self.lockdroph(next_hidden['output_hidden'])
                hiddens[l] = next_hidden['recurrent_hidden']

        saved_hiddens.append(hiddens[-1][:, :self.slice_dim])
        return hiddens, saved_hiddens, buffers, {}
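
# Illustrative sketch (not part of the original module): the intended call pattern for
# the encoder above, assuming `enc` is a RevEncoder and `src`/`src_lengths` follow the
# shapes documented in the docstrings. The downstream gradients (`last_h_grads`,
# `slice_grads`) would normally come from the decoder and attention; zero placeholders
# are used here only to keep the sketch self-contained.
def _example_encoder_round_trip(enc, src, src_lengths):
    # Forward pass: hidden states are stored in compressed form in integer buffers.
    hiddens, saved_hiddens, buffers, main_buf, out = enc(src, src_lengths)
    # Gradients w.r.t. the final hiddens and the saved top-layer slices.
    last_h_grads = [torch.zeros_like(h, dtype=torch.float32) for h in hiddens]
    slice_grads = [torch.zeros_like(s, dtype=torch.float32) for s in saved_hiddens]
    # Reverse pass: earlier hidden states are reconstructed from the buffers and
    # parameter gradients are accumulated without stored activations.
    enc.reverse(src, src_lengths, hiddens, last_h_grads, saved_hiddens, slice_grads, buffers)
    return out['used_bits'], out['normal_bits']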
class RevDecoder(nn.Module):
    def __init__(self, rnn_type, h_size, nlayers, embedding, attn_type='general',
                 context_size=None, dropouti=0, dropouth=0, dropouto=0, wdrop=0,
                 dropouts=0, max_forget=0.875, context_type='slice', slice_dim=0,
                 use_buffers=True):
        super(RevDecoder, self).__init__()
        if rnn_type == 'revgru':
            rnn = RevGRU
            module_names = ['ih2_to_zr1', 'irh2_to_g1', 'ih1_to_zr2', 'irh1_to_g2']
        elif rnn_type == 'revlstm':
            rnn = RevLSTM
            module_names = ['ih2_to_zgfop1', 'ih1_to_zgfop2']

        self.rnns = [rnn(h_size, h_size, max_forget) for l in range(nlayers)]
        self.rnns = [RevWeightDrop(rnn, module_names, wdrop) for rnn in self.rnns]
        self.rnns = nn.ModuleList(self.rnns)

        self.lockdropi = RevLockedDropout(dropouti, h_size)
        self.lockdroph = RevLockedDropout(dropouth, h_size)
        self.lockdropo = RevLockedDropout(dropouto, h_size)
        self.lockdrops = RevLockedDropout(dropouts, slice_dim)

        # Basic attributes.
        self.rnn_type = rnn_type
        self.decoder_type = 'rnn'
        self.context_type = context_type
        self.context_size = context_size
        self.nlayers = nlayers
        self.h_size = h_size
        self.embedding = embedding
        self.wdrop = wdrop
        self.use_buffers = use_buffers
        self.generator = nn.Linear(h_size, embedding.num_embeddings)

        # Set up the standard attention.
        self.attn_type = attn_type
        if attn_type != 'none':
            self.attn = onmt.modules.MultiSizeAttention(
                h_size, context_size=context_size, attn_type=attn_type)

    def forward(self, input_seq, hiddens, main_buf, dec_lengths=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            main_buf (InformationBuffer): storage for hidden states with size
                (batch_size, h_size)
            dec_lengths (IntTensor): size is (batch_size,)
        """
        # Set up.
        seq_length, batch_size = input_seq.size()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        # This means using the buffer 'main_buf' from the encoder.
        buffers = []
        for l in range(len(hiddens)):
            if self.rnn_type == 'revlstm':
                buffers.append((main_buf, main_buf, main_buf, main_buf))
            else:
                buffers.append((main_buf, main_buf))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum(
            [32 * seq_length * batch_size * self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2
        output_dict['hid_seq'] = []
        for l in range(self.nlayers):
            output_dict['hid_seq'].append([hiddens[l].data.clone().numpy()])

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.embedding(input_seq))
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if dec_lengths is None else (t < dec_lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    next_hidden, stats = rnn(curr_input, hidden, buf, slice_dim=0, mask=mask)
                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['hid_seq'][l].append(hiddens[l].data.clone().numpy())
                    output_dict['optimal_bits'] += stats['optimal_bits']

        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = float('inf')  # TODO(mmackay): figure out right way to compute bits

        return hiddens, buffers, output_dict

    def reverse(self, input_seq, target_seq, last_hiddens, saved_hiddens, buffers,
                token_weights, dec_lengths, enc_lengths, enc_input_seq, rev_enc):
        """
        Arguments:
            input_seq (LongTensor): size is (dec_seq_length, batch_size)
            target_seq (LongTensor): size is (dec_seq_length, batch_size)
            last_hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            saved_hiddens (list): list of IntTensors (each with size (batch_size, slice_dim))
                of length enc_seq_length + 1
            buffers (list): list of InformationBuffers of length nlayers
            token_weights (FloatTensor): size is (dec_ntokens,)
            dec_lengths (IntTensor): size is (batch_size,)
            enc_lengths (IntTensor): size is (batch_size,)
            enc_input_seq (LongTensor): size is (enc_seq_length, batch_size)
            rev_enc (RevEncoder): the RevEncoder module used before this RevDecoder
        """
        hiddens = last_hiddens
        hidden_grads = [next(self.parameters()).new_zeros(h.size()) for h in hiddens]
        loss_fun = lambda output, target: F.cross_entropy(
            output, target, weight=token_weights, size_average=False)
        saved_hiddens = [hidden.requires_grad_() for hidden in saved_hiddens]

        total_loss = 0.
        output_dict = {'hid_seq': []}
        for l in range(self.nlayers):
            output_dict['hid_seq'].append([last_hiddens[l].data.clone().numpy()])
        output_dict['drop_hs'] = []

        for t in reversed(range(len(input_seq))):
            top_hidden = hiddens[-1].requires_grad_()
            top_hidden_ = ConvertToFloat.apply(top_hidden[:, :self.h_size], hidden_radix)
            top_hidden_ = self.lockdropo(top_hidden_).unsqueeze(0)
            output_dict['drop_hs'].append(top_hidden_.data.clone().numpy())

            context = self.construct_context(saved_hiddens, enc_input_seq, rev_enc)
            attn_hidden, _ = self.attn(top_hidden_.transpose(0, 1).contiguous(),
                context.transpose(0, 1), context_lengths=enc_lengths)
            output = self.generator(attn_hidden[0])
            last_loss = loss_fun(output, target_seq[t])
            last_loss.backward()
            hidden_grads[-1] += top_hidden.grad
            total_loss += last_loss

            mask = None if dec_lengths is None else (t < dec_lengths).int()
            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]

                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l - 1]
                        drop_input = self.lockdroph(
                            ConvertToFloat.apply(curr_input[:, :self.h_size], hidden_radix))
                    else:
                        curr_input = input_seq[t].squeeze()
                        drop_input = self.lockdropi(self.embedding(curr_input))
                    prev_hidden = rnn.reverse(drop_input, hidden, buf, slice_dim=0,
                        saved_hidden=None, mask=mask)

                # Rerun forward pass from previous hidden to hidden at time t to construct
                # computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdroph(
                        ConvertToFloat.apply(curr_input[:, :self.h_size], hidden_radix))
                else:
                    drop_input = self.lockdropi(self.embedding(curr_input))
                curr_hidden, _ = rnn(drop_input, prev_hidden, mask=mask)
                torch.autograd.backward(curr_hidden['recurrent_hidden'],
                    grad_tensors=hidden_grads[l])

                hiddens[l] = prev_hidden.detach()
                hidden_grads[l] = prev_hidden.grad.data
                if l != 0:
                    hidden_grads[l - 1] += curr_input.grad.data
                output_dict['hid_seq'][l].append(prev_hidden.data.clone().numpy())

        return hiddens, hidden_grads, saved_hiddens, buffers, total_loss, output_dict

    def construct_context(self, saved_hiddens, enc_input_seq, rev_enc):
        if self.context_type == 'emb':
            context = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq))
        elif self.context_type == 'slice':
            context = self.lockdrops(ConvertToFloat.apply(
                torch.stack(saved_hiddens[1:]), hidden_radix))
        elif self.context_type == 'slice_emb':
            embs = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq))
            slices = self.lockdrops(ConvertToFloat.apply(
                torch.stack(saved_hiddens[1:]), hidden_radix))
            context = torch.cat([embs, slices], dim=2)
        return context

    def test_forward(self, input_seq, hiddens, main_buf, dec_lengths=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            main_buf (InformationBuffer): storage for hidden states with size
                (batch_size, h_size)
            dec_lengths (IntTensor): size is (batch_size,)
        """
        # Set up.
        seq_length, batch_size = input_seq.size()
        # self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        # This means using the buffer 'main_buf' from the encoder.
        buffers = []
        for l in range(len(hiddens)):
            if self.rnn_type == 'revlstm':
                buffers.append((main_buf, main_buf, main_buf, main_buf))
            else:
                buffers.append((main_buf, main_buf))

        output_dict = {'hid_seq': []}
        for l in range(self.nlayers):
            output_dict['hid_seq'].append([hiddens[l].data.clone().numpy()])

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.embedding(input_seq))
        top_hiddens = []
        for t in range(len(input_seq)):
            mask = None if dec_lengths is None else (t < dec_lengths).int()
            curr_input = input_seq[t]

            for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                next_hidden, stats = rnn(curr_input, hidden, buf, slice_dim=0, mask=mask)
                if l != self.nlayers - 1:
                    curr_input = self.lockdroph(next_hidden['output_hidden'])
                hiddens[l] = next_hidden['recurrent_hidden']
                output_dict['hid_seq'][l].append(hiddens[l].data.clone().numpy())
                if l == self.nlayers - 1:
                    top_hiddens.append(next_hidden['output_hidden'])

        return top_hiddens, hiddens, output_dict

    def test_compute_loss(self, top_hiddens, target_seq, saved_hiddens, token_weights,
                          enc_lengths, enc_input_seq, rev_enc):
        """
        Used to test correctness of gradients in the reverse computation.
""" top_hiddens = self.lockdropo(torch.stack(top_hiddens)) output_dict = {'drop_hs': []} for t in range(len(top_hiddens)): output_dict['drop_hs'].append(top_hiddens[t].data.clone().numpy()) if self.context_type == 'emb': context = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq)) elif self.context_type == 'slice': context = self.lockdrops(ConvertToFloat.apply( torch.stack(saved_hiddens[1:]), hidden_radix)) elif self.context_type == 'slice_emb': embs = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq)) slices = self.lockdrops(ConvertToFloat.apply( torch.stack(saved_hiddens[1:]), hidden_radix)) context = torch.cat([embs, slices], dim=2) attn_hiddens, _ = self.attn(top_hiddens.transpose(0, 1).contiguous(), context.transpose(0, 1), context_lengths=enc_lengths) output = self.generator(attn_hiddens.view(-1, attn_hiddens.size(2))) loss = F.cross_entropy(output, target_seq.view(-1), weight=token_weights, size_average=False) return loss, output_dict @property def _input_size(self): """ Private helper returning the number of expected features. """ return self.embeddings.embedding_dim # IMPORTANT: This is where the issues are def set_masks(self, batch_size, device='cuda'): self.lockdropi.set_mask(batch_size, device=device) self.lockdroph.set_mask(batch_size, device=device) self.lockdropo.set_mask(batch_size, device=device) self.lockdrops.set_mask(batch_size, device=device) for rnn in self.rnns: rnn.set_mask() def init_hiddens(self, batch_size): ''' For testing purposes, should never be called in practice''' weight = next(self.parameters()) if self.rnn_type == 'revlstm': return [weight.new(torch.zeros(batch_size, 2 * self.h_size)).zero_().int() for l in range(self.nlayers)] else: return [weight.new(torch.zeros(batch_size, self.h_size)).zero_().int() for l in range(self.nlayers)]
class RevEncoder(nn.Module):
    def __init__(self, rnn_type, h_size, nlayers, embedding, slice_dim=100,
                 max_forget=0.875, use_buffers=True, dropouti=0, dropouth=0, wdrop=0,
                 context_type='slice'):
        super(RevEncoder, self).__init__()
        if rnn_type == 'revgru':
            rnn = RevGRU
            module_names = ['ih2_to_zr1', 'irh2_to_g1', 'ih1_to_zr2', 'irh1_to_g2']
        elif rnn_type == 'revlstm':
            rnn = RevLSTM
            module_names = ['ih2_to_zgfop1', 'ih1_to_zgfop2']

        self.rnns = [rnn(h_size, h_size, max_forget) for l in range(nlayers)]
        self.rnns = nn.ModuleList(self.rnns)

        self.lockdropi = RevLockedDropout(dropouti, h_size)
        self.lockdroph = RevLockedDropout(dropouth, h_size)

        self.rnn_type = rnn_type
        self.h_size = h_size
        self.nlayers = nlayers
        self.encoder = embedding
        self.slice_dim = slice_dim
        self.use_buffers = use_buffers
        self.dropouti = dropouti
        self.dropouth = dropouth
        self.wdrop = wdrop
        self.context_type = context_type

        if slice_dim == h_size:
            self.rev_forward = self.rev_forward_allh
            self.forward = self.rev_forward_allh
            self.reverse = self.reverse_allh

    def forward_test(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        all_hiddens = []
        with torch.set_grad_enabled(self.training):
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, hidden) in enumerate(zip(self.rnns, hiddens)):
                    if l == self.nlayers - 1:
                        all_hiddens.append(hidden)
                    next_hidden, stats = rnn(curr_input, hidden, buf=None,
                        slice_dim=self.slice_dim if l == self.nlayers - 1 else 0,
                        mask=mask)
                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    hiddens[l] = next_hidden['recurrent_hidden']

        all_hiddens.append(hiddens[-1])
        hidden_context = ConvertToFloat.apply(torch.stack(all_hiddens[1:]), hidden_radix)
        if self.context_type == 'hidden':
            context = hidden_context
        elif self.context_type == 'emb':
            context = input_seq
        elif self.context_type == 'slice':
            context = hidden_context[:, :, :self.slice_dim]
        elif self.context_type == 'slice_emb':
            context = torch.cat([input_seq, hidden_context[:, :, :self.slice_dim]], dim=2)

        return hiddens, context

    def forward(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2, input_seq.device)
        slice_buf = InformationBuffer(batch_size, self.h_size // 2 - self.slice_dim,
            input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            if self.training and self.use_buffers:
                buf_h1 = slice_buf if l == self.nlayers - 1 else main_buf
                buf_h2 = buf_c1 = buf_c2 = main_buf
            else:
                buf_h1 = buf_h2 = buf_c1 = buf_c2 = None
            if self.rnn_type == 'revlstm':
                buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
            else:
                buffers.append((buf_h1, buf_h2))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum(
            [32 * seq_length * batch_size * self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        with torch.set_grad_enabled(self.training):
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    if l == self.nlayers - 1:
                        saved_hiddens.append(hidden[:, :self.slice_dim])
                    next_hidden, stats = rnn(curr_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0, mask)
                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

        saved_hiddens.append(hiddens[-1][:, :self.slice_dim])
        return hiddens, saved_hiddens, buffers, main_buf, slice_buf, output_dict

    def rev_forward(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2, input_seq.device)
        slice_buf = InformationBuffer(batch_size, self.h_size // 2 - self.slice_dim,
            input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            buf_h1 = slice_buf if l == self.nlayers - 1 else main_buf
            buf_h2 = buf_c1 = buf_c2 = main_buf
            if self.rnn_type == 'revlstm':
                buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
            else:
                buffers.append((buf_h1, buf_h2))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum(
            [32 * seq_length * batch_size * self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    if l == self.nlayers - 1 and self.slice_dim > 0:
                        saved_hiddens.append(hidden[:, :self.slice_dim])
                    next_hidden, stats = rnn(curr_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0, mask)
                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

        if self.slice_dim > 0:
            saved_hiddens.append(hiddens[-1][:, :self.slice_dim])
        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = main_buf.bit_usage() + slice_buf.bit_usage() + \
            32 * self.slice_dim * batch_size * max_length

        return hiddens, saved_hiddens, buffers, main_buf, slice_buf, output_dict

    def rev_forward_allh(self, input_seq, lengths=None, hiddens=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            lengths (IntTensor): size is (batch_size,)
            hiddens (list): list of Tensors of length nlayers
        """
        # Set-up.
        seq_length, batch_size = input_seq.size()
        hiddens = self.init_hiddens(batch_size) if hiddens is None else hiddens
        max_length = len(input_seq) if lengths is None else lengths.max().item()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        main_buf = InformationBuffer(batch_size, self.h_size // 2, input_seq.device)
        buffers = []
        for l in range(len(hiddens)):
            if l == self.nlayers - 1:
                buffers.append(None)
            else:
                buf_h1 = buf_h2 = buf_c1 = buf_c2 = main_buf
                if self.rnn_type == 'revlstm':
                    buffers.append((buf_h1, buf_h2, buf_c1, buf_c2))
                else:
                    buffers.append((buf_h1, buf_h2))

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.encoder(input_seq))
        saved_hiddens = []
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if lengths is None else (t < lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    if l == self.nlayers - 1:
                        saved_hiddens.append(hidden.detach().clone())
                    next_hidden, stats = rnn(curr_input, hidden, buf, 0, mask)
                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    hiddens[l] = next_hidden['recurrent_hidden']

        saved_hiddens.append(hiddens[-1])

        total_bits = sum(
            [32 * seq_length * batch_size * self.h_size for l in range(self.nlayers)])
        output_dict = {'normal_bits': total_bits, 'optimal_bits': total_bits,
            'used_bits': total_bits}
        if self.rnn_type == 'revlstm':
            for k in ['normal_bits', 'optimal_bits', 'used_bits']:
                output_dict[k] *= 2
        output_dict['last_h'] = hiddens

        return hiddens, saved_hiddens, buffers, main_buf, None, output_dict

    def init_hiddens(self, batch_size):
        weight = next(self.parameters())
        if self.rnn_type == 'revlstm':
            return [weight.new(batch_size, 2 * self.h_size).zero_().int()
                for l in range(self.nlayers)]
        else:
            return [weight.new(batch_size, self.h_size).zero_().int()
                for l in range(self.nlayers)]

    def set_masks(self, batch_size, device='cuda'):
        self.lockdropi.set_mask(batch_size, device=device)
        self.lockdroph.set_mask(batch_size, device=device)

    def reverse(self, input_seq, lengths, last_hiddens, last_hidden_grads, saved_hiddens,
                buffers):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size, 1)
            lengths (IntTensor): size is (batch_size,)
            last_hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            last_hidden_grads (list): list of FloatTensors (each with size (batch_size, h_size))
                of length nlayers
            saved_hiddens (list): list of IntTensors (each with size (batch_size, slice_dim))
                of length seq_length + 1
            buffers (list): list of InformationBuffers of length nlayers
        """
        hiddens = last_hiddens
        hidden_grads = last_hidden_grads
        for t in reversed(range(len(input_seq))):
            mask = None if lengths is None else (t < lengths).int()
            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]

                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l - 1]
                        drop_input = self.lockdroph(
                            ConvertToFloat.apply(curr_input[:, :self.h_size], hidden_radix))
                    else:
                        curr_input = input_seq[t]
                        drop_input = self.lockdropi(self.encoder(curr_input))
                    prev_hidden = rnn.reverse(drop_input, hidden, buf,
                        self.slice_dim if l == self.nlayers - 1 else 0,
                        saved_hiddens[t] if l == self.nlayers - 1 and self.slice_dim > 0 else None,
                        mask)

                # Rerun forward pass from previous hidden to hidden at time t to construct
                # computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdroph(
                        ConvertToFloat.apply(curr_input[:, :self.h_size], hidden_radix))
                else:
                    drop_input = self.lockdropi(self.encoder(curr_input))
                curr_hidden, _ = rnn(drop_input, prev_hidden, mask=mask)

                curr_hidden_grad = hidden_grads[l]
                if l == self.nlayers - 1 and self.slice_dim > 0:
                    curr_hidden_grad[:, :self.slice_dim] += saved_hiddens[t + 1].grad
                torch.autograd.backward(curr_hidden['recurrent_hidden'],
                    grad_tensors=curr_hidden_grad)

                hiddens[l] = prev_hidden.detach()
                hidden_grads[l] = prev_hidden.grad.data
                if l != 0:
                    hidden_grads[l - 1] += curr_input.grad.data

    def reverse_allh(self, input_seq, lengths, last_hiddens, last_hidden_grads,
                     saved_hiddens, buffers):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size, 1)
            lengths (IntTensor): size is (batch_size,)
            last_hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            last_hidden_grads (list): list of FloatTensors (each with size (batch_size, h_size))
                of length nlayers
            saved_hiddens (list): list of IntTensors (each with size (batch_size, slice_dim))
                of length seq_length + 1
            buffers (list): list of InformationBuffers of length nlayers
        """
        hiddens = last_hiddens
        hidden_grads = last_hidden_grads
        for t in reversed(range(len(input_seq))):
            mask = None if lengths is None else (t < lengths).int()
            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]

                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l - 1]
                        drop_input = self.lockdroph(
                            ConvertToFloat.apply(curr_input[:, :self.h_size], hidden_radix))
                    else:
                        curr_input = input_seq[t]
                        drop_input = self.lockdropi(self.encoder(curr_input))
                    prev_hidden = rnn.reverse(drop_input, hidden, buf, 0,
                        saved_hiddens[t] if l == self.nlayers - 1 else None, mask)

                # Rerun forward pass from previous hidden to hidden at time t to construct
                # computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdroph(
                        ConvertToFloat.apply(curr_input[:, :self.h_size], hidden_radix))
                else:
                    drop_input = self.lockdropi(self.encoder(curr_input))
                curr_hidden, _ = rnn(drop_input, prev_hidden, mask=mask)

                curr_hidden_grad = hidden_grads[l]
                if l == self.nlayers - 1:
                    curr_hidden_grad += saved_hiddens[t + 1].grad
                torch.autograd.backward(curr_hidden['recurrent_hidden'],
                    grad_tensors=curr_hidden_grad)

                hiddens[l] = prev_hidden.detach()
                hidden_grads[l] = prev_hidden.grad.data
                if l != 0:
                    hidden_grads[l - 1] += curr_input.grad.data
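
# Illustrative sketch (not part of the original module): reading the bit accounting
# that rev_forward returns to estimate the memory saved by buffer storage. `enc`,
# `src`, and `src_lengths` are hypothetical stand-ins with the documented shapes.
def _example_memory_ratio(enc, src, src_lengths):
    _, _, _, _, _, out = enc.rev_forward(src, src_lengths)
    # 'normal_bits' is the cost of storing every hidden state at 32 bits per entry;
    # 'used_bits' is what the information buffers (plus saved slices) actually used.
    return out['used_bits'] / float(out['normal_bits'])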
class RevDecoder(nn.Module):
    def __init__(self, rnn_type, h_size, nlayers, embedding, attn_type='general',
                 context_size=None, dropouti=0, dropouth=0, dropouto=0, wdrop=0,
                 dropouts=0, max_forget=0.875, context_type='slice', slice_dim=0,
                 use_buffers=True):
        super(RevDecoder, self).__init__()
        if rnn_type == 'revgru':
            rnn = RevGRU
            module_names = ['ih2_to_zr1', 'irh2_to_g1', 'ih1_to_zr2', 'irh1_to_g2']
        elif rnn_type == 'revlstm':
            rnn = RevLSTM
            module_names = ['ih2_to_zgfop1', 'ih1_to_zgfop2']

        self.rnns = [rnn(h_size, h_size, max_forget) for l in range(nlayers)]
        self.rnns = [RevWeightDrop(rnn, module_names, wdrop) for rnn in self.rnns]
        self.rnns = nn.ModuleList(self.rnns)

        self.lockdropi = RevLockedDropout(dropouti, h_size)
        self.lockdroph = RevLockedDropout(dropouth, h_size)
        self.lockdropo = RevLockedDropout(dropouto, h_size)
        self.lockdrops = RevLockedDropout(dropouts, slice_dim)

        # Basic attributes.
        self.rnn_type = rnn_type
        self.decoder_type = 'rnn'
        self.context_type = context_type
        self.context_size = context_size
        self.nlayers = nlayers
        self.h_size = h_size
        self.embedding = embedding
        self.wdrop = wdrop
        self.use_buffers = use_buffers
        self.generator = nn.Linear(h_size, embedding.num_embeddings)
        self.slice_dim = slice_dim

        # Set up the standard attention.
        self.attn_type = attn_type
        if attn_type != 'none':
            self.attn = onmt.modules.MultiSizeAttention(
                h_size, context_size=context_size, attn_type=attn_type)

    def forward_test(self, input_seq, hiddens, context, enc_lengths, dec_lengths=None):
        # Set up.
        seq_length, batch_size = input_seq.size()
        self.set_masks(batch_size, input_seq.device)

        # Find last hidden states of model.
        top_hiddens = []
        input_seq = self.lockdropi(self.embedding(input_seq))
        with torch.set_grad_enabled(self.training):
            for t in range(len(input_seq)):
                mask = None if dec_lengths is None else (t < dec_lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, hidden) in enumerate(zip(self.rnns, hiddens)):
                    next_hidden, stats = rnn(curr_input, hidden, buf=None, slice_dim=0,
                        mask=mask)
                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    else:
                        top_hiddens.append(next_hidden['output_hidden'])
                    hiddens[l] = next_hidden['recurrent_hidden']

        top_hiddens = self.lockdropo(torch.stack(top_hiddens))
        attn_hiddens, attn_scores = self.attn(top_hiddens.transpose(0, 1).contiguous(),
            context.transpose(0, 1), context_lengths=enc_lengths)
        attns = {"std": attn_scores}
        output = self.generator(attn_hiddens.view(-1, attn_hiddens.size(2)))
        return output, hiddens, attns

    def forward(self, input_seq, target_seq, hiddens, saved_hiddens, main_buf,
                token_weights, padding_idx, dec_lengths, enc_lengths, enc_input_seq,
                rev_enc):
        # Set up.
        seq_length, batch_size = input_seq.size()
        self.set_masks(batch_size, input_seq.device)
        total_loss = 0.
        num_words = 0.0
        num_correct = 0.0

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        # This means using the buffer 'main_buf' from the encoder.
        buffers = []
        if main_buf is None and self.use_buffers:
            main_buf = InformationBuffer(batch_size, self.h_size // 2, input_seq.device)
        for l in range(len(hiddens)):
            if self.rnn_type == 'revlstm':
                # Note: main_buf can be None.
                buffers.append((main_buf, main_buf, main_buf, main_buf))
            else:
                buffers.append((main_buf, main_buf))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum(
            [32 * seq_length * batch_size * self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        top_hiddens = []
        input_seq = self.lockdropi(self.embedding(input_seq))
        with torch.set_grad_enabled(self.training):
            for t in range(len(input_seq)):
                mask = None if dec_lengths is None else (t < dec_lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    next_hidden, stats = rnn(curr_input, hidden, buf, slice_dim=0, mask=mask)
                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    else:
                        top_hidden_ = next_hidden['output_hidden']
                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

                top_hidden_ = self.lockdropo(top_hidden_).unsqueeze(0)
                context = self.construct_context(saved_hiddens, enc_input_seq, rev_enc)
                attn_hidden, _ = self.attn(top_hidden_.transpose(0, 1).contiguous(),
                    context.transpose(0, 1), context_lengths=enc_lengths)
                output = self.generator(attn_hidden[0])
                loss = F.cross_entropy(output, target_seq[t], weight=token_weights,
                    size_average=False)
                total_loss += loss

                non_padding = target_seq[t].ne(padding_idx)
                pred = F.log_softmax(output, dim=1).max(1)[1]
                num_words += non_padding.sum().item()
                num_correct += pred.eq(target_seq[t]).masked_select(non_padding).sum().item()

        attns = {}
        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = float('inf')

        return total_loss, num_words, num_correct, attns, main_buf, output_dict

    def rev_forward(self, input_seq, hiddens, main_buf, dec_lengths=None):
        """
        Arguments:
            input_seq (LongTensor): size is (seq_length, batch_size)
            hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            main_buf (InformationBuffer): storage for hidden states with size
                (batch_size, h_size)
            dec_lengths (IntTensor): size is (batch_size,)
        """
        # Set up.
        seq_length, batch_size = input_seq.size()
        self.set_masks(batch_size, input_seq.device)

        # Initialize information buffers. To limit unused space at the end of each buffer,
        # use the same information buffer for all hiddens of the same size.
        # This means using the buffer 'main_buf' from the encoder.
        buffers = []
        if main_buf is None:
            main_buf = InformationBuffer(batch_size, self.h_size // 2, input_seq.device)
        for l in range(len(hiddens)):
            if self.rnn_type == 'revlstm':
                buffers.append((main_buf, main_buf, main_buf, main_buf))
            else:
                buffers.append((main_buf, main_buf))

        # Initialize output dictionary.
        output_dict = {'optimal_bits': 0}
        output_dict['normal_bits'] = sum(
            [32 * seq_length * batch_size * self.h_size for l in range(self.nlayers)])
        if self.rnn_type == 'revlstm':
            output_dict['normal_bits'] *= 2

        # Find last hidden states of model.
        input_seq = self.lockdropi(self.embedding(input_seq))
        with torch.no_grad():
            for t in range(len(input_seq)):
                mask = None if dec_lengths is None else (t < dec_lengths).int()
                curr_input = input_seq[t]

                for l, (rnn, buf, hidden) in enumerate(zip(self.rnns, buffers, hiddens)):
                    next_hidden, stats = rnn(curr_input, hidden, buf, slice_dim=0, mask=mask)
                    if l != self.nlayers - 1:
                        curr_input = self.lockdroph(next_hidden['output_hidden'])
                    hiddens[l] = next_hidden['recurrent_hidden']
                    output_dict['optimal_bits'] += stats['optimal_bits']

        output_dict['last_h'] = hiddens
        output_dict['used_bits'] = float('inf')  # TODO(mmackay): figure out right way to compute bits

        return hiddens, buffers, main_buf, output_dict

    def reverse(self, input_seq, target_seq, last_hiddens, saved_hiddens, buffers,
                token_weights, padding_idx, dec_lengths, enc_lengths, enc_input_seq,
                rev_enc):
        """
        Arguments:
            input_seq (LongTensor): size is (dec_seq_length, batch_size)
            target_seq (IntTensor): size is (dec_seq_length, batch_size)
            last_hiddens (list): list of IntTensors (each with size (batch_size, h_size))
                of length nlayers
            saved_hiddens (list): list of IntTensors (each with size (batch_size, slice_dim))
                of length enc_seq_length
            buffers (list): list of InformationBuffers of length nlayers
            token_weights (FloatTensor): size is (dec_ntokens,)
            dec_lengths (IntTensor): size is (batch_size,)
            enc_lengths (IntTensor): size is (batch_size,)
            enc_input_seq (LongTensor): size is (enc_seq_length, batch_size)
            rev_enc (RevEncoder): the RevEncoder module used before this RevDecoder
        """
        batch_size = float(enc_lengths.shape[0])
        hiddens = last_hiddens
        hidden_grads = [next(self.parameters()).new_zeros(h.size()) for h in hiddens]
        loss_fun = lambda output, target: F.cross_entropy(
            output, target, weight=token_weights, size_average=False)
        saved_hiddens = [hidden.requires_grad_() for hidden in saved_hiddens]

        total_loss = 0.
        num_words = 0.0
        num_correct = 0.0

        for t in reversed(range(len(input_seq))):
            top_hidden = hiddens[-1].requires_grad_()
            top_hidden_ = ConvertToFloat.apply(top_hidden[:, :self.h_size], hidden_radix)
            top_hidden_ = self.lockdropo(top_hidden_).unsqueeze(0)

            context = self.construct_context(saved_hiddens, enc_input_seq, rev_enc)
            attn_hidden, _ = self.attn(top_hidden_.transpose(0, 1).contiguous(),
                context.transpose(0, 1), context_lengths=enc_lengths)
            output = self.generator(attn_hidden[0])
            last_loss = loss_fun(output, target_seq[t])
            last_loss.div(batch_size).backward()
            hidden_grads[-1] += top_hidden.grad
            total_loss += last_loss

            non_padding = target_seq[t].ne(padding_idx)
            pred = F.log_softmax(output, dim=1).max(1)[1]
            num_words += non_padding.sum().item()
            num_correct += pred.eq(target_seq[t]).masked_select(non_padding).sum().item()

            mask = None if dec_lengths is None else (t < dec_lengths).int()
            for l in reversed(range(self.nlayers)):
                rnn, buf, hidden = self.rnns[l], buffers[l], hiddens[l]

                # Reconstruct previous hidden state.
                with torch.no_grad():
                    if l != 0:
                        curr_input = hiddens[l - 1]
                        drop_input = self.lockdroph(
                            ConvertToFloat.apply(curr_input[:, :self.h_size], hidden_radix))
                    else:
                        curr_input = input_seq[t].squeeze()
                        drop_input = self.lockdropi(self.embedding(curr_input))
                    prev_hidden = rnn.reverse(drop_input, hidden, buf, slice_dim=0,
                        saved_hidden=None, mask=mask)

                # Rerun forward pass from previous hidden to hidden at time t to construct
                # computation graph and compute gradients.
                prev_hidden.requires_grad_()
                if l != 0:
                    curr_input.requires_grad_()
                    drop_input = self.lockdroph(
                        ConvertToFloat.apply(curr_input[:, :self.h_size], hidden_radix))
                else:
                    drop_input = self.lockdropi(self.embedding(curr_input))
                curr_hidden, _ = rnn(drop_input, prev_hidden, mask=mask)
                torch.autograd.backward(curr_hidden['recurrent_hidden'],
                    grad_tensors=hidden_grads[l])

                hiddens[l] = prev_hidden.detach()
                hidden_grads[l] = prev_hidden.grad.data
                if l != 0:
                    hidden_grads[l - 1] += curr_input.grad.data

        return total_loss, num_words, num_correct, hiddens, hidden_grads, saved_hiddens

    def construct_context(self, saved_hiddens, enc_input_seq, rev_enc):
        if self.context_type == 'emb':
            context = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq))
        elif self.context_type == 'slice':
            context = ConvertToFloat.apply(torch.stack(saved_hiddens[1:]), hidden_radix)
            context = context[:, :, :self.slice_dim]
            context = self.lockdrops(context)
        elif self.context_type == 'slice_emb':
            embs = rev_enc.lockdropi(rev_enc.encoder(enc_input_seq))
            slices = ConvertToFloat.apply(torch.stack(saved_hiddens[1:]), hidden_radix)
            slices = slices[:, :, :self.slice_dim]
            slices = self.lockdrops(slices)
            context = torch.cat([embs, slices], dim=2)
        return context

    def set_masks(self, batch_size, device='cuda'):
        self.lockdropi.set_mask(batch_size, device=device)
        self.lockdroph.set_mask(batch_size, device=device)
        self.lockdropo.set_mask(batch_size, device=device)
        self.lockdrops.set_mask(batch_size, device=device)
        for rnn in self.rnns:
            rnn.set_mask()

    @property
    def _input_size(self):
        """ Private helper returning the number of expected features. """
        return self.embedding.embedding_dim
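
# Illustrative sketch (not part of the original modules): one reversible encoder-decoder
# training step wiring the classes above together via their rev_forward/reverse split.
# `enc`, `dec`, the batch tensors, `token_weights`, `padding_idx`, and `optimizer` are
# hypothetical stand-ins; shapes follow the docstrings above.
def _example_seq2seq_step(enc, dec, src, tgt_in, tgt_out, src_lengths, tgt_lengths,
                          token_weights, padding_idx, optimizer):
    optimizer.zero_grad()
    # Forward passes run under torch.no_grad() inside rev_forward and store compressed
    # hidden states in the shared information buffers.
    enc_hiddens, saved_hiddens, enc_buffers, main_buf, slice_buf, _ = enc.rev_forward(
        src, src_lengths)
    dec_hiddens, dec_buffers, main_buf, _ = dec.rev_forward(
        tgt_in, [h.clone() for h in enc_hiddens], main_buf, dec_lengths=tgt_lengths)
    # Reverse passes reconstruct hidden states step by step and accumulate parameter
    # gradients; the decoder hands the gradients of its initial hiddens and of the
    # saved encoder slices back to the encoder.
    loss, num_words, num_correct, enc_last, enc_last_grads, saved_hiddens = dec.reverse(
        tgt_in, tgt_out, dec_hiddens, saved_hiddens, dec_buffers, token_weights,
        padding_idx, tgt_lengths, src_lengths, src, enc)
    enc.reverse(src, src_lengths, enc_last, enc_last_grads, saved_hiddens, enc_buffers)
    optimizer.step()
    return loss.item() / max(num_words, 1.0)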