Example #3
def trainDemo(lang, dataSet, nlVocab, codeVocab, train_variables):
    print("Training...")
    encoder1 = EncoderRNN(codeVocab.n_words, setting.HIDDDEN_SIAZE)
    attn_decoder1 = AttnDecoderRNN(setting.HIDDDEN_SIAZE,
                                   nlVocab.n_words,
                                   1,
                                   dropout_p=0.1)

    if setting.USE_CUDA:
        encoder1 = encoder1.cuda()
        attn_decoder1 = attn_decoder1.cuda()

    trainIters(lang,
               dataSet,
               train_variables,
               encoder1,
               attn_decoder1,
               2000000,
               print_every=5000)
Example #5
# Initialize a network and start training.
#
# Remember that the input sentences were heavily filtered. For this small
# dataset we can use relatively small networks of 256 hidden nodes and a
# single GRU layer. After about 40 minutes on a MacBook CPU we'll get some
# reasonable results.
#
# .. Note::
#    If you run this notebook you can train, interrupt the kernel,
#    evaluate, and continue training later. Comment out the lines where the
#    encoder and decoder are initialized and run ``trainIters`` again.
#

hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1)

TRAIN = False
if "-t" in sys.argv:
    TRAIN = True

TRAIN_ITER = 7500
if len(sys.argv) == 3:
    TRAIN_ITER = int(sys.argv[2])

if use_cuda:
    encoder1 = encoder1.cuda()
    attn_decoder1 = attn_decoder1.cuda()

if os.path.exists("encoder.pt") and os.path.exists("decoder.pt") and not TRAIN:
    print("Found saved models")
class CharLevel_autoencoder(nn.Module):
    def __init__(self, criterion, num_symbols, use_cuda):
        ''' Overview of the autoencoder forward pass:
            1. The input batch is embedded
            2. The CNN + pooling encoder is applied to the embedded input
            3. The BiGRU encoder is run over the CNN encoder's activations
            4. The attention GRU decoder consumes one embedded symbol per
               time step (the decoder embedding embeds the symbol at t)
            5. Batch cross entropy is calculated and returned
            '''
        super(CharLevel_autoencoder, self).__init__()
        self.char_embedding_dim = 128
        self.pooling_stride = 5
        self.seq_len = 300
        self.num_symbols = num_symbols
        self.use_cuda = use_cuda

        self.filter_widths = list(range(1, 9))
        # due to CUDA memory limits, every filter width has 50 fewer filters
        self.num_filters_per_width = [150, 150, 200, 200, 250, 250, 250, 250]

        self.encoder_embedding = nn.Embedding(num_symbols,
                                              self.char_embedding_dim)
        self.cnn_encoder = cnn_encoder(
            filter_widths=self.filter_widths,
            num_filters_per_width=self.num_filters_per_width,
            char_embedding_dim=self.char_embedding_dim,
            use_cuda=use_cuda)

        self.decoder_hidden_size = int(
            np.sum(np.array(self.num_filters_per_width)))
        self.rnn_encoder = rnn_encoder(hidden_size=self.decoder_hidden_size)

        # decoder embedding dim dictated by output dim of encoder
        self.decoder_embedding = nn.Embedding(num_symbols,
                                              self.decoder_hidden_size)
        self.attention_decoder = AttnDecoderRNN(
            num_symbols=num_symbols,
            hidden_size=self.decoder_hidden_size,
            output_size=self.seq_len // self.pooling_stride)

        self.criterion = criterion
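
        # Shape walkthrough (a sketch under the defaults above; the batch size
        # of 64 is hard-coded in encode/decode below rather than configurable):
        #   input ids        (64, 300)
        #   encoder_embedded (64, 1, 128, 300)  after unsqueeze(1) + transpose(2, 3)
        #   encoded          (64, 1700, 60)     1700 = sum(num_filters_per_width),
        #                                       60 = seq_len // pooling_stride
        #   encoder_outputs  (64, 60, 3400)     BiGRU, 2 directions * 1700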

    def encode(self, data, seq_len):
        encoder_embedded = self.encoder_embedding(data).unsqueeze(1).transpose(
            2, 3)
        encoded = self.cnn_encoder.forward(encoder_embedded, self.seq_len)
        encoded = encoded.squeeze(2)

        encoder_hidden = self.rnn_encoder.initHidden()
        encoder_outputs = Variable(
            torch.zeros(64, seq_len // self.pooling_stride,
                        2 * self.decoder_hidden_size))
        if self.use_cuda:
            encoder_outputs = encoder_outputs.cuda()
            encoder_hidden = encoder_hidden.cuda()

        for symbol_ind in range(self.seq_len // self.pooling_stride):
            output, encoder_hidden = self.rnn_encoder.forward(
                encoded[:, :, symbol_ind], encoder_hidden)
            encoder_outputs[:, symbol_ind, :] = output[0]
        return encoder_outputs, encoder_hidden

    def decode(self, target_data, decoder_hidden, encoder_outputs, i):
        use_teacher_forcing = random.random() < 0.7
        # a non-bool i is a batch index (eval mode): disable teacher forcing
        if not isinstance(i, bool):
            use_teacher_forcing = False

        output = []
        # the SOS symbol maps to index 32; one per example in the batch of 64
        input_embedded = Variable(torch.LongTensor([32]).repeat(64),
                                  requires_grad=False)
        if self.use_cuda:
            input_embedded = input_embedded.cuda()
        input_embedded = self.decoder_embedding(input_embedded)

        for symbol_index in range(self.seq_len):
            # # current symbol, current hidden state, outputs from encoder
            decoder_output, decoder_hidden, attn_weights = self.attention_decoder.forward(
                input_embedded, decoder_hidden, encoder_outputs)
            output.append(decoder_output)

            if use_teacher_forcing:
                input_symbol = Variable(target_data[:, symbol_index],
                                        requires_grad=False)
                if self.use_cuda:
                    input_symbol = input_symbol.cuda()

            else:
                values, input_symbol = decoder_output.max(1)
            input_embedded = self.decoder_embedding(input_symbol)

        # at the current batch: gather all true and predicted symbols into one
        # vector, then return the batch cross entropy; first mask out the
        # padding (symbol 31) at the end of every sentence
        actual_sentence_mask = torch.ne(target_data, 31).byte()
        # expand the mask over the 125 output classes
        threeD_mask = actual_sentence_mask.unsqueeze(2).repeat(1, 1, 125)
        predicted = torch.stack(output, dim=1)

        # if validation loader is called, dump predictions
        if type(i) != bool:
            values, indices = predicted.max(2)
            print(indices.data.shape)
            pickle.dump(indices.data.numpy(),
                        open("./data/%s_predicted.p" % (i), "wb"),
                        protocol=4)

        if self.use_cuda:
            target_data = target_data.cuda()
            actual_sentence_mask = actual_sentence_mask.cuda()
            threeD_mask = threeD_mask.cuda()

        # calculate cross entropy on non-padding symbols
        masked_target = torch.masked_select(target_data, actual_sentence_mask)
        predicted = predicted.masked_select(Variable(threeD_mask))
        predicted = predicted.view(-1, 125)
        loss = self.criterion(predicted, Variable(masked_target))

        return loss
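
# The masking above amounts to: flatten the batch, drop padding positions, and
# compute cross entropy on what remains. A minimal self-contained sketch of
# that pattern (PAD id 31 and the 125-way output are the constants hard-coded
# above; the shapes and the bool-mask indexing are illustrative, not the class API):
import torch
import torch.nn as nn

PAD_ID, NUM_CLASSES = 31, 125
criterion = nn.CrossEntropyLoss()

logits = torch.randn(64, 300, NUM_CLASSES)           # (batch, seq, classes)
targets = torch.randint(0, NUM_CLASSES, (64, 300))   # padded with PAD_ID
mask = targets.ne(PAD_ID)                            # (batch, seq) bool mask

loss = criterion(logits[mask], targets[mask])        # non-padding symbols only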
Example #8
def evaluateRandomly(encoder, decoder, n=10):  # wrapper assumed: the original fragment began mid-function
    # sample random pairs and compare the model's output to the reference
    for _ in range(n):
        pair = random.choice(pairs)
        print(pair)
        print('>', pair[2])
        print('=', pair[0])
        output_words, attentions = evaluate(encoder, decoder, pair[2])
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')


voc_path = abs_file_path + "/data/data_clean.txt"
voc = Voc("total")
voc.initVoc(voc_path)
pairs = prepareData(abs_file_path)
print(len(pairs))

hidden_size = 256
encoder1 = EncoderRNN(voc.num_words, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(hidden_size, voc.num_words,
                               dropout=0.1).to(device)
trainIters(encoder1, attn_decoder1, 75000)

encoder_save_path = "encoder3.pth"
decoder_save_path = "decoder3.pth"
torch.save(encoder1, current_dir + '/' + encoder_save_path)
torch.save(attn_decoder1, current_dir + "/" + decoder_save_path)
model1 = torch.load(current_dir + "/" + encoder_save_path)
model2 = torch.load(current_dir + "/" + decoder_save_path)
evaluateRandomly(model1.to(torch.device("cpu")),
                 model2.to(torch.device("cpu")))
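
# Design note: torch.save(encoder1, ...) above pickles the whole module object,
# which is brittle across code changes. A sketch of the state_dict alternative,
# reusing the names and paths from this example:
torch.save(encoder1.state_dict(), current_dir + "/" + encoder_save_path)
model1 = EncoderRNN(voc.num_words, hidden_size)
model1.load_state_dict(torch.load(current_dir + "/" + encoder_save_path))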
Example #9
load_saved = True  # True: resume saved checkpoints; False: train from scratch
hidden_size = 256
day = 19
hour = "01"
nowTime = '2018-12-' + str(day) + '-' + str(hour)  # reused below when saving
if load_saved:
    encoder_save_path = "model/combineEncoder+" + nowTime + "+.pth"
    decoder_save_path = "model/combineDecoder+" + nowTime + "+.pth"
    combiner_save_path = "model/combineCombiner+" + nowTime + "+.pth"
    encoder1 = torch.load(current_dir + "/" + encoder_save_path)
    attn_decoder1 = torch.load(current_dir + "/" + decoder_save_path)
    CombineEncoder = torch.load(current_dir + "/" + combiner_save_path)
else:
    encoder1 = EncoderRNN(voc.num_words, hidden_size).to(device)
    attn_decoder1 = AttnDecoderRNN(hidden_size, voc.num_words).to(device)
    CombineEncoder = CombineEncoderRNN(hidden_size, hidden_size).to(device)
trainIters(encoder1, attn_decoder1, CombineEncoder, trainpairs, 30)

encoder_save_path = "model/AttencombineEncoder+" + nowTime + "hidden" + str(
    hidden_size) + "+.pth"
decoder_save_path = "model/AttencombineDecoder+" + nowTime + "hidden" + str(
    hidden_size) + "+.pth"
combiner_save_path = "model/AttencombineCombiner+" + nowTime + "hidden" + str(
    hidden_size) + "+.pth"
torch.save(encoder1, current_dir + '/' + encoder_save_path)
torch.save(attn_decoder1, current_dir + "/" + decoder_save_path)

torch.save(CombineEncoder, current_dir + "/" + combiner_save_path)
model1 = torch.load(current_dir + "/" + encoder_save_path)
model2 = torch.load(current_dir + "/" + decoder_save_path)
Example #10
            torch.save(decoder.state_dict(), args.save_path + '/decoder')

    #showPlot(plot_losses, plot_losses_test)

######################################################################
# Training
# =======================
hidden_size = 200 

encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, MAX_LENGTH, dropout_p=0.1).to(device)


torch.save(input_lang, args.save_path + '/input_lang')
torch.save(output_lang, args.save_path + '/output_lang')
torch.save(test_set, args.save_path + '/test_set')

print(args.print_every)
trainIters(encoder1, attn_decoder1, args.n_iters,  args.print_every, args.plot_every, save_every=args.save_every)

torch.save(encoder1.state_dict(), args.save_path + '/encoder')
torch.save(attn_decoder1.state_dict(), args.save_path + '/decoder')



class Seq2Seq(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, learning_rate,
                 teacher_forcing_ratio, device):
        super(Seq2Seq, self).__init__()

        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.device = device

        self.encoder = EncoderRNN(input_size, hidden_size)
        self.decoder = AttnDecoderRNN(hidden_size, output_size)

        self.encoder_optimizer = optim.SGD(self.encoder.parameters(),
                                           lr=learning_rate)
        self.decoder_optimizer = optim.SGD(self.decoder.parameters(),
                                           lr=learning_rate)

        self.criterion = nn.NLLLoss()

    def train(self,
              input_tensor,
              target_tensor,
              max_length=constants.MAX_LENGTH):
        encoder_hidden = self.encoder.initHidden()

        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        input_length = input_tensor.size(0)
        target_length = target_tensor.size(0)

        encoder_outputs = torch.zeros(max_length + 1,
                                      self.encoder.hidden_size,
                                      device=self.device)

        loss = 0

        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(
                input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0, 0]

        decoder_input = torch.tensor([[constants.SOS_TOKEN]],
                                     device=self.device)
        decoder_hidden = encoder_hidden

        use_teacher_forcing = np.random.random() < self.teacher_forcing_ratio

        if use_teacher_forcing:
            # Teacher forcing: feed the target as the next input
            for di in range(target_length):
                decoder_output, decoder_hidden, decoder_attention = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                loss += self.criterion(decoder_output, target_tensor[di])
                decoder_input = target_tensor[di]  # Teacher forcing
        else:
            # Without teacher forcing: use its own prediction as the next input
            for di in range(target_length):
                decoder_output, decoder_hidden, decoder_attention = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                topv, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze().detach()  # detach from history as input

                loss += self.criterion(decoder_output, target_tensor[di])

                if decoder_input.item() == constants.EOS_TOKEN:
                    break

        loss.backward()

        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

        return loss.item() / target_length

    def trainIters(self, env, evaluator):
        start_total_time = time.time() - env.total_training_time
        start_epoch_time = time.time()  # Reset every LOG_EVERY iterations
        start_train_time = time.time()  # Reset every LOG_EVERY iterations
        total_loss = 0  # Reset every LOG_EVERY iterations

        for iter in range(env.iters_completed + 1, constants.NUM_ITER + 1):
            row = env.train_methods.iloc[np.random.randint(
                len(env.train_methods))]
            input_tensor = row['source']
            target_tensor = row['name']

            loss = self.train(input_tensor, target_tensor)
            total_loss += loss

            if iter % constants.LOG_EVERY == 0:
                log('Completed {} iterations'.format(iter))

                train_time_elapsed = time.time() - start_train_time

                log('Evaluating on validation set')
                start_eval_time = time.time()

                names = evaluator.evaluate(self)
                # save_dataframe(names, constants.VALIDATION_NAMES_FILE)

                eval_time_elapsed = time.time() - start_eval_time

                env.history = env.history.append(
                    {
                        'Loss': total_loss / constants.LOG_EVERY,
                        'BLEU': names['BLEU'].mean(),
                        'ROUGE': names['ROUGE'].mean(),
                        'F1': names['F1'].mean(),
                        'num_names': len(names['GeneratedName'].unique())
                    },
                    ignore_index=True)

                epoch_time_elapsed = time.time() - start_epoch_time
                total_time_elapsed = time.time() - start_total_time

                env.total_training_time = total_time_elapsed

                history_last_row = env.history.iloc[-1]

                log_dict = OrderedDict([
                    ("Iteration", '{}/{} ({:.1f}%)'.format(
                        iter, constants.NUM_ITER,
                        iter / constants.NUM_ITER * 100)),
                    ("Average loss", history_last_row['Loss']),
                    ("Average BLEU", history_last_row['BLEU']),
                    ("Average ROUGE", history_last_row['ROUGE']),
                    ("Average F1", history_last_row['F1']),
                    ("Unique names", int(history_last_row['num_names'])),
                    ("Epoch time", time_str(epoch_time_elapsed)),
                    ("Training time", time_str(train_time_elapsed)),
                    ("Evaluation time", time_str(eval_time_elapsed)),
                    ("Total training time", time_str(total_time_elapsed))
                ])

                write_training_log(log_dict, constants.TRAIN_LOG_FILE)
                plot_and_save_histories(env.history)

                env.iters_completed = iter
                env.save_train()

                # Resetting counters
                total_loss = 0
                start_epoch_time = time.time()
                start_train_time = time.time()

    def forward(self,
                input_tensor,
                max_length=constants.MAX_LENGTH,
                return_attention=False):
        encoder_hidden = self.encoder.initHidden()

        input_length = input_tensor.size(0)

        encoder_outputs = torch.zeros(max_length + 1,
                                      self.encoder.hidden_size,
                                      device=self.device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(
                input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0, 0]

        decoder_input = torch.tensor([[constants.SOS_TOKEN]],
                                     device=self.device)
        decoder_hidden = encoder_hidden

        decoded_words = []
        attention_vectors = []

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.data.topk(1)

            decoded_words.append(topi.item())
            attention_vectors.append(decoder_attention.tolist()[0])

            if decoded_words[-1] == constants.EOS_TOKEN:
                break

            decoder_input = topi.squeeze().detach()

        if return_attention:
            return decoded_words, attention_vectors
        else:
            return decoded_words
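
# Teacher forcing in train() above reduces to a coin flip on what feeds the
# next decoder step. A self-contained toy sketch of the pattern (nn.GRUCell
# and the sizes here stand in for the project's decoder; SOS id 0 is a placeholder):
import torch
import torch.nn as nn
import numpy as np

decoder_cell = nn.GRUCell(8, 8)
embed, proj = nn.Embedding(20, 8), nn.Linear(8, 20)
criterion = nn.NLLLoss()

target = torch.randint(0, 20, (5,))                # one toy target sequence
hidden = torch.zeros(1, 8)
decoder_input = torch.tensor([0])                  # SOS placeholder
use_teacher_forcing = np.random.random() < 0.5

loss = 0
for di in range(len(target)):
    hidden = decoder_cell(embed(decoder_input), hidden)
    logits = nn.functional.log_softmax(proj(hidden), dim=1)
    loss = loss + criterion(logits, target[di].unsqueeze(0))
    if use_teacher_forcing:
        decoder_input = target[di].unsqueeze(0)    # feed the ground truth
    else:
        decoder_input = logits.argmax(1).detach()  # feed its own prediction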
Example #12
class CharLevel_autoencoder(nn.Module):
    def __init__(self, criterion, num_symbols, use_cuda):  #, seq_len):
        super(CharLevel_autoencoder, self).__init__()
        self.char_embedding_dim = 64
        self.pooling_stride = 5
        self.seq_len = 200
        self.num_symbols = num_symbols
        self.use_cuda = use_cuda

        self.filter_widths = list(range(1, 8))
        self.num_filters_per_width = 125  #[100, 100, 125, 125, 150, 150, 150, 150]

        self.encoder_embedding = nn.Embedding(num_symbols,
                                              self.char_embedding_dim)
        self.cnn_encoder = cnn_encoder(
            filter_widths=self.filter_widths,
            num_filters_per_width=self.num_filters_per_width,
            char_embedding_dim=self.char_embedding_dim)
        #seq_len = self.seq_len)

        self.decoder_hidden_size = len(
            self.filter_widths) * self.num_filters_per_width
        self.rnn_encoder = rnn_encoder(hidden_size=self.decoder_hidden_size)

        # decoder embedding dim dictated by output dim of encoder
        self.decoder_embedding = nn.Embedding(num_symbols,
                                              self.decoder_hidden_size)
        self.attention_decoder = AttnDecoderRNN(
            num_symbols=num_symbols,
            hidden_size=self.decoder_hidden_size,
            output_size=self.seq_len // self.pooling_stride)

        # if use_cuda:
        #       self.cnn_encoder = self.cnn_encoder.cuda()
        #       self.rnn_encoder = self.rnn_encoder.cuda()
        #       self.attention_decoder = self.attention_decoder.cuda()

        self.criterion = criterion
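
        # Shape note (under the defaults above): decoder_hidden_size =
        # len(filter_widths) * num_filters_per_width = 7 * 125 = 875, and the
        # attention decoder attends over seq_len // pooling_stride = 200 // 5
        # = 40 pooled encoder positions.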

    def encode(self, data, seq_len, collect_filters=False):
        encoder_embedded = self.encoder_embedding(data).unsqueeze(1).transpose(
            2, 3)
        encoded = self.cnn_encoder.forward(encoder_embedded, self.seq_len,
                                           collect_filters)
        encoded = encoded.squeeze(2)

        encoder_hidden = self.rnn_encoder.initHidden()
        encoder_outputs = Variable(
            torch.zeros(64, seq_len // self.pooling_stride,
                        2 * self.decoder_hidden_size))
        if self.use_cuda:
            encoder_outputs = encoder_outputs.cuda()
            encoder_hidden = encoder_hidden.cuda()

        for symbol_ind in range(self.seq_len // self.pooling_stride):
            output, encoder_hidden = self.rnn_encoder.forward(
                encoded[:, :, symbol_ind], encoder_hidden)
            #print(output.data.shape) # (81, 64, 128)
            encoder_outputs[:, symbol_ind, :] = output[0]
        return encoder_outputs, encoder_hidden

    def decode(self, noisy_data, target_data, encoder_hidden, encoder_outputs,
               seq_len):
        loss = 0
        decoder_hidden = encoder_hidden
        #print(target_data.data.shape)
        for amino_acid_index in range(self.seq_len):
            target_amino_acid = target_data[:, :, amino_acid_index]
            decoder_input = noisy_data.data[:, amino_acid_index].unsqueeze(1)
            decoder_embedded = self.decoder_embedding(decoder_input)

            # # current symbol, current hidden state, outputs from encoder
            decoder_output, decoder_hidden, attn_weights = self.attention_decoder.forward(
                decoder_embedded, decoder_hidden, encoder_outputs,
                self.seq_len // self.pooling_stride)
            #print(decoder_output.data.shape, target_amino_acid.data.shape)   # torch.Size([64, 23])

            loss += self.criterion(decoder_output, Variable(target_amino_acid))
        return loss


# preliminary model
# class cnn_autoencoder(nn.Module):
#       def __init__(self):
#             super(cnn_autoencoder, self).__init__()
#             self.encoder = cnn_encoder()
#             self.decoder = cnn_decoder()
#             self.embedding = nn.Embedding(22, 4)

#       def encode(self, data):
#             char_embeddings = self.embedding(data).unsqueeze(1).transpose(2,3)
#             encoded, unpool_indices = self.encoder.forward(char_embeddings)
#             return encoded, unpool_indices

#       def decode(self, data, unpool_indices):
#             reconstructed = self.decoder.forward(data, unpool_indices)
#             return reconstructed