import pickle
import random

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

# cnn_encoder, rnn_encoder and AttnDecoderRNN are defined elsewhere in this repo


# earlier revision of the model, kept for reference: seq_len 200, 64-dim
# character embeddings and a fixed 125 filters per width
class CharLevel_autoencoder(nn.Module):
    def __init__(self, criterion, num_symbols, use_cuda):
        super(CharLevel_autoencoder, self).__init__()
        self.char_embedding_dim = 64
        self.pooling_stride = 5
        self.seq_len = 200
        self.num_symbols = num_symbols
        self.use_cuda = use_cuda

        self.filter_widths = list(range(1, 8))
        self.num_filters_per_width = 125  # alternative: [100, 100, 125, 125, 150, 150, 150, 150]

        self.encoder_embedding = nn.Embedding(num_symbols, self.char_embedding_dim)
        self.cnn_encoder = cnn_encoder(
            filter_widths=self.filter_widths,
            num_filters_per_width=self.num_filters_per_width,
            char_embedding_dim=self.char_embedding_dim)

        self.decoder_hidden_size = len(self.filter_widths) * self.num_filters_per_width
        self.rnn_encoder = rnn_encoder(hidden_size=self.decoder_hidden_size)

        # decoder embedding dim dictated by output dim of encoder
        self.decoder_embedding = nn.Embedding(num_symbols, self.decoder_hidden_size)
        self.attention_decoder = AttnDecoderRNN(
            num_symbols=num_symbols,
            hidden_size=self.decoder_hidden_size,
            output_size=self.seq_len // self.pooling_stride)
        self.criterion = criterion

    def encode(self, data, seq_len, collect_filters=False):
        encoder_embedded = self.encoder_embedding(data).unsqueeze(1).transpose(2, 3)
        encoded = self.cnn_encoder.forward(encoder_embedded, self.seq_len, collect_filters)
        encoded = encoded.squeeze(2)

        encoder_hidden = self.rnn_encoder.initHidden()
        # batch size is fixed at 64; outputs are bidirectional, hence 2x hidden
        encoder_outputs = Variable(
            torch.zeros(64, seq_len // self.pooling_stride, 2 * self.decoder_hidden_size))
        if self.use_cuda:
            encoder_outputs = encoder_outputs.cuda()
            encoder_hidden = encoder_hidden.cuda()

        for symbol_ind in range(self.seq_len // self.pooling_stride):
            output, encoder_hidden = self.rnn_encoder.forward(
                encoded[:, :, symbol_ind], encoder_hidden)
            encoder_outputs[:, symbol_ind, :] = output[0]
        return encoder_outputs, encoder_hidden

    def decode(self, noisy_data, target_data, encoder_hidden, encoder_outputs, seq_len):
        loss = 0
        decoder_hidden = encoder_hidden
        for amino_acid_index in range(self.seq_len):
            target_amino_acid = target_data[:, :, amino_acid_index]
            decoder_input = noisy_data.data[:, amino_acid_index].unsqueeze(1)
            decoder_embedded = self.decoder_embedding(decoder_input)

            # current symbol, current hidden state, outputs from encoder
            decoder_output, decoder_hidden, attn_weights = self.attention_decoder.forward(
                decoder_embedded, decoder_hidden, encoder_outputs,
                self.seq_len // self.pooling_stride)
            loss += self.criterion(decoder_output, Variable(target_amino_acid))
        return loss


# preliminary model
# class cnn_autoencoder(nn.Module):
#     def __init__(self):
#         super(cnn_autoencoder, self).__init__()
#         self.encoder = cnn_encoder()
#         self.decoder = cnn_decoder()
#         self.embedding = nn.Embedding(22, 4)
#
#     def encode(self, data):
#         char_embeddings = self.embedding(data).unsqueeze(1).transpose(2, 3)
#         encoded, unpool_indices = self.encoder.forward(char_embeddings)
#         return encoded, unpool_indices
#
#     def decode(self, data, unpool_indices):
#         reconstructed = self.decoder.forward(data, unpool_indices)
#         return reconstructed
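
# Shape sketch for the embedding reshape used in encode() above; a minimal,
# self-contained example assuming this revision's sizes (batch 64, seq_len 200,
# char_embedding_dim 64) and a hypothetical 23-symbol alphabet.
def _embedding_shape_demo():
    embedding = nn.Embedding(23, 64)
    data = Variable(torch.zeros(64, 200).long())
    embedded = embedding(data)                          # (64, 200, 64)
    # Conv2d expects (batch, channels, height, width): the embedding dim
    # becomes the height and the sequence position becomes the width
    conv_input = embedded.unsqueeze(1).transpose(2, 3)  # (64, 1, 64, 200)
    return conv_input.size()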
# current revision: 300-symbol sequences, 128-dim character embeddings and
# per-width filter counts
class CharLevel_autoencoder(nn.Module):
    def __init__(self, criterion, num_symbols, use_cuda):
        '''Overview of the autoencoder forward pass:
        1. The input batch is embedded.
        2. The CNN + pooling encoder is applied to the embeddings.
        3. A bidirectional GRU encoder is run over the CNN activations.
        4. The attention GRU decoder consumes one embedded symbol per step
           (the decoder embedding embeds the symbol at the current t).
        5. The batch cross entropy is computed and returned.
        '''
        super(CharLevel_autoencoder, self).__init__()
        self.char_embedding_dim = 128
        self.pooling_stride = 5
        self.seq_len = 300
        self.num_symbols = num_symbols
        self.use_cuda = use_cuda

        self.filter_widths = list(range(1, 9))
        # due to CUDA memory limits, each filter width gets 50 fewer filters
        self.num_filters_per_width = [150, 150, 200, 200, 250, 250, 250, 250]

        self.encoder_embedding = nn.Embedding(num_symbols, self.char_embedding_dim)
        self.cnn_encoder = cnn_encoder(
            filter_widths=self.filter_widths,
            num_filters_per_width=self.num_filters_per_width,
            char_embedding_dim=self.char_embedding_dim,
            use_cuda=use_cuda)

        self.decoder_hidden_size = int(np.sum(np.array(self.num_filters_per_width)))
        self.rnn_encoder = rnn_encoder(hidden_size=self.decoder_hidden_size)

        # decoder embedding dim dictated by output dim of encoder
        self.decoder_embedding = nn.Embedding(num_symbols, self.decoder_hidden_size)
        self.attention_decoder = AttnDecoderRNN(
            num_symbols=num_symbols,
            hidden_size=self.decoder_hidden_size,
            output_size=self.seq_len // self.pooling_stride)
        self.criterion = criterion

    def encode(self, data, seq_len):
        encoder_embedded = self.encoder_embedding(data).unsqueeze(1).transpose(2, 3)
        encoded = self.cnn_encoder.forward(encoder_embedded, self.seq_len)
        encoded = encoded.squeeze(2)

        encoder_hidden = self.rnn_encoder.initHidden()
        # batch size is fixed at 64; outputs are bidirectional, hence 2x hidden
        encoder_outputs = Variable(
            torch.zeros(64, seq_len // self.pooling_stride, 2 * self.decoder_hidden_size))
        if self.use_cuda:
            encoder_outputs = encoder_outputs.cuda()
            encoder_hidden = encoder_hidden.cuda()

        for symbol_ind in range(self.seq_len // self.pooling_stride):
            output, encoder_hidden = self.rnn_encoder.forward(
                encoded[:, :, symbol_ind], encoder_hidden)
            encoder_outputs[:, symbol_ind, :] = output[0]
        return encoder_outputs, encoder_hidden

    def decode(self, target_data, decoder_hidden, encoder_outputs, i):
        # a real batch index i (anything but a bool) signals eval mode:
        # teacher forcing is disabled and predictions are dumped below
        use_teacher_forcing = random.random() < 0.7
        if type(i) != bool:
            use_teacher_forcing = False

        output = []
        # SOS token = 32 after encoding
        sos = Variable(torch.LongTensor([32]).repeat(64), requires_grad=False)
        if self.use_cuda:
            sos = sos.cuda()
        input_embedded = self.decoder_embedding(sos)

        for symbol_index in range(self.seq_len):
            # current symbol, current hidden state, outputs from encoder
            decoder_output, decoder_hidden, attn_weights = self.attention_decoder.forward(
                input_embedded, decoder_hidden, encoder_outputs)
            output.append(decoder_output)

            if use_teacher_forcing:
                input_symbol = Variable(target_data[:, symbol_index], requires_grad=False)
                if self.use_cuda:
                    input_symbol = input_symbol.cuda()
            else:
                values, input_symbol = decoder_output.max(1)
            input_embedded = self.decoder_embedding(input_symbol)

        # conglomerate all true and predicted symbols in the batch into one
        # vector and return the batch cross entropy; first mask out the
        # padding (symbol 31) at the end of every sequence. The class
        # dimension of the mask must match decoder_output, so num_symbols
        # replaces the hard-coded 125 used previously.
        actual_sentence_mask = torch.ne(target_data, 31).byte()
        threeD_mask = actual_sentence_mask.unsqueeze(2).repeat(1, 1, self.num_symbols)
        predicted = torch.stack(output, dim=1)

        # when called from the validation loader, dump predictions to disk
        if type(i) != bool:
            values, indices = predicted.max(2)
            with open("./data/%s_predicted.p" % i, "wb") as handle:
                pickle.dump(indices.data.cpu().numpy(), handle, protocol=4)

        if self.use_cuda:
            target_data = target_data.cuda()
            actual_sentence_mask = actual_sentence_mask.cuda()
            threeD_mask = threeD_mask.cuda()

        # calculate cross entropy on non-padding symbols only
        masked_target = torch.masked_select(target_data, actual_sentence_mask)
        predicted = predicted.masked_select(Variable(threeD_mask))
        predicted = predicted.view(-1, self.num_symbols)
        loss = self.criterion(predicted, Variable(masked_target))
        return loss
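
# Minimal sketch of one training step for the class above, assuming `batch`
# and `target` are LongTensors of shape (64, seq_len) from a data loader and
# `optimizer` wraps model.parameters(); this helper is an illustrative
# assumption, not part of the repo's training script.
def train_one_batch(model, optimizer, batch, target):
    optimizer.zero_grad()
    encoder_outputs, encoder_hidden = model.encode(Variable(batch), model.seq_len)
    # i=False keeps decode() in training mode (stochastic teacher forcing)
    loss = model.decode(target, encoder_hidden, encoder_outputs, i=False)
    loss.backward()
    optimizer.step()
    return loss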
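
# Sketch of the padding-mask trick used at the end of decode(), on a tiny
# hand-made example; the 4-step sequence, pad symbol 31 and num_symbols=33
# are illustrative assumptions.
def _masked_loss_demo(num_symbols=33):
    target = torch.LongTensor([[5, 7, 31, 31]])        # one padded sequence
    logits = Variable(torch.randn(1, 4, num_symbols))  # (batch, seq, classes)
    mask = torch.ne(target, 31).byte()                 # 1 where not padding
    masked_target = torch.masked_select(target, mask)  # keeps [5, 7]
    masked_logits = logits.masked_select(
        Variable(mask.unsqueeze(2).repeat(1, 1, num_symbols)))
    masked_logits = masked_logits.view(-1, num_symbols)
    return nn.CrossEntropyLoss()(masked_logits, Variable(masked_target))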