Example #1
    def build_model(self):
        if self.use_embeddings:
            self.embedding = nn.Embedding.from_pretrained(self.embedding_wts)
        else:
            self.embedding = nn.Embedding(self.vocab.n_words,
                                          self.embedding_dim)
        self.encoders = []
        self.encoder_optimizers = []

        # Note: No embeddings used in the encoders
        for m in ['v', 's']:
            encoder = EncoderRNN(self.enc_input_dim[m], self.hidden_size,
                                 self.enc_n_layers, self.dropout, self.unit,
                                 m).to(self.device)
            encoder_optimizer = optim.Adam(encoder.parameters(), lr=self.lr)

            if self.modality == 'ss-vv':
                checkpoint = torch.load(self.pretrained_modality[m],
                                        map_location=self.device)
                encoder.load_state_dict(checkpoint['en'])
                encoder_optimizer.load_state_dict(checkpoint['en_op'])
            self.encoders.append(encoder)
            self.encoder_optimizers.append(encoder_optimizer)
        self.decoder = DecoderRNN(self.attn_model, self.embedding_dim,
                                  self.hidden_size, self.vocab.n_words,
                                  self.unit, self.dec_n_layers, self.dropout,
                                  self.embedding).to(self.device)
        text_checkpoint = torch.load(self.pretrained_modality['t'],
                                     map_location=self.device)
        self.decoder.load_state_dict(text_checkpoint['de'])
        self.project_factor = self.encoders[0].project_factor
        self.latent2hidden = nn.Linear(self.latent_dim, self.hidden_size *
                                       self.project_factor).to(self.device)
        self.epoch = 0
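
The checkpoint dictionaries loaded above (keys 'en', 'en_op', 'de') follow the layout written by save_model in Examples #4 and #12. A minimal sketch of producing a compatible encoder checkpoint; `encoder`, `encoder_optimizer` and `path` are assumed to exist:

import torch

def save_encoder_checkpoint(encoder, encoder_optimizer, epoch, path):
    # Keys mirror those read by build_model above.
    torch.save({
        'epoch': epoch,
        'en': encoder.state_dict(),
        'en_op': encoder_optimizer.state_dict(),
    }, path)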
Example #2
    def __init__(self, **kwargs):
        dp = DataPreprocessor()
        file_name_formatted = dp.write_to_file()
        dc = DataCleaner(file_name_formatted)
        dc.clean_data_pipeline().trim_rare_words()
        self.data_loader = DataLoader(dc.vocabulary, dc.pairs)
        self.dp = dp
        self.dc = dc
        load_embedding = kwargs.get('pretrained_embedding', False)
        embedding_file = kwargs.get('pretrained_embedding_file', None)

        load_enc_dec = kwargs.get('pretrained_enc_dec', False)
        load_enc_file = kwargs.get('pretrained_enc_file', None)
        load_dec_file = kwargs.get('pretrained_dec_file', None)

        self.model_name = kwargs.get('model_name', 'cb_model')
        attn_model = kwargs.get('attention_type', 'dot')
        self.hidden_size = kwargs.get('hidden_size', 500)
        self.encoder_nr_layers = kwargs.get('enc_nr_layers', 2)
        self.decoder_nr_layers = kwargs.get('dec_nr_layers', 2)
        dropout = kwargs.get('dropout', 0.1)
        self.batch_size = kwargs.get('batch_size', 64)
        self.clip = kwargs.get('clip', 50.0)
        self.teacher_forcing_ratio = kwargs.get('teacher_forcing_ratio', 1.0)
        self.learning_rate = kwargs.get('lr', 0.0001)
        self.decoder_learning_ratio = kwargs.get('decoder_learning_ratio', 5.0)
        self.nr_iteration = kwargs.get('nr_iterations', 4000)
        self.print_every = kwargs.get('print_every', 1)
        self.save_every = 500
        self.embedding = nn.Embedding(self.dc.vocabulary.num_words, self.hidden_size)
        if load_embedding:
            self.embedding.load_state_dict(torch.load(embedding_file))
        # Initialize encoder & decoder models
        encoder = EncoderRNN(self.hidden_size, self.embedding, self.encoder_nr_layers, dropout)
        decoder = DecoderRNN(
            attn_model,
            self.embedding,
            self.hidden_size,
            self.dc.vocabulary.num_words,
            self.decoder_nr_layers,
            dropout
        )

        if load_enc_dec:
            encoder.load_state_dict(torch.load(load_enc_file))
            decoder.load_state_dict(torch.load(load_dec_file))
        # Use appropriate device
        encoder = encoder.to(device)
        decoder = decoder.to(device)
        self.encoder = encoder
        self.decoder = decoder
        self.encoder_optimizer = optim.Adam(encoder.parameters(), lr=self.learning_rate)
        self.decoder_optimizer = optim.Adam(decoder.parameters(), lr=self.learning_rate * self.decoder_learning_ratio)
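
The class line is cut off in this snippet; assuming the enclosing class is called ChatbotTrainer (a hypothetical name), construction with a few overridden kwargs would look like:

# Hypothetical usage; every keyword falls back to the kwargs.get(...)
# defaults shown above.
trainer = ChatbotTrainer(hidden_size=500,
                         attention_type='dot',
                         batch_size=64,
                         lr=0.0001,
                         pretrained_embedding=False)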
Example #3
def main():
    """primary entry
    """
    voc, pairs = loadPreparedData()
    print('Building encoder and decoder ...')
    # Initialize word embeddings
    #embedding = nn.Embedding(voc.num_words, params.hidden_size)
    embedding = nn.Embedding(voc.num_words, params.embedding_size)
    # Initialize encoder & decoder models
    encoder = EncoderRNN(embedding, params.hidden_size,
                         params.encoder_n_layers, params.dropout)
    decoder = LuongAttnDecoderRNN(params.attn_model, embedding,
                                  params.hidden_size, voc.num_words,
                                  params.decoder_n_layers, params.dropout)
    # Use appropriate device
    encoder = encoder.to(params.device)
    decoder = decoder.to(params.device)
    print('Models built and ready to go!')

    # Ensure dropout layers are in train mode
    encoder.train()
    decoder.train()

    # Initialize optimizers
    print('Building optimizers ...')
    encoder_optimizer = optim.Adam(encoder.parameters(),
                                   lr=params.learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=params.learning_rate *
                                   params.decoder_learning_ratio)

    # Run training iterations
    print("Starting Training!")
    trainIters(voc,
               pairs,
               encoder,
               decoder,
               encoder_optimizer,
               decoder_optimizer,
               embedding,
               params.encoder_n_layers,
               params.decoder_n_layers,
               params.save_dir,
               params.n_iteration,
               params.batch_size,
               params.print_every,
               params.save_every,
               params.clip,
               params.corpus_name,
               load_filename=None)
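
main() assumes a `params` namespace (or module) carrying the configuration. A minimal stand-in that satisfies every attribute used above; the values are illustrative assumptions, not the original settings:

from argparse import Namespace
import torch

params = Namespace(
    embedding_size=500, hidden_size=500, encoder_n_layers=2,
    decoder_n_layers=2, dropout=0.1, attn_model='dot',
    learning_rate=0.0001, decoder_learning_ratio=5.0, n_iteration=4000,
    batch_size=64, print_every=1, save_every=500, clip=50.0,
    save_dir='save', corpus_name='corpus',
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))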
Example #4
class FeatureAutoEncoderNetwork(Sequence2SequenceNetwork):
    # This autoencoder is to be used only for video and speech vectors
    # Use base Sequence2SequenceNetwork class for autoencoding text
    def build_model(self):
        # Note: no embedding used here
        self.encoder = EncoderRNN(self.enc_input_dim, self.hidden_size,
                                  self.enc_n_layers, self.dropout, self.unit,
                                  self.modality).to(self.device)

        self.encoder_optimizer = optim.Adam(self.encoder.parameters(),
                                            lr=self.lr)

        self.epoch = 0  # define here to add resume training feature

    def load_pretrained_model(self):
        if self.load_model_name:
            checkpoint = torch.load(self.load_model_name,
                                    map_location=self.device)
            print('Loaded {}'.format(self.load_model_name))
            self.epoch = checkpoint['epoch']
            self.encoder.load_state_dict(checkpoint['en'])
            self.encoder_optimizer.load_state_dict(checkpoint['en_op'])

    def train_model(self):
        best_score = 1e-200
        plot_losses = []
        print_loss_total = 0  # Reset every epoch

        start = time.time()
        saving_skipped = 0
        for epoch in range(self.epoch, self.n_epochs):
            random.shuffle(self.pairs)
            for iter in range(0, self.n_iters, self.batch_size):
                training_batch = batch2TrainData(
                    self.vocab, self.pairs[iter:iter + self.batch_size],
                    self.modality)

                if len(training_batch[1]) < self.batch_size:
                    print('skipped a batch..')
                    continue

                # Extract fields from batch
                input_variable, lengths, target_variable, \
                    tar_lengths = training_batch

                # Run a training iteration with the current batch
                loss = self.train(input_variable, lengths, target_variable,
                                  iter)
                self.writer.add_scalar('{}loss'.format(self.data_dir), loss,
                                       iter)

                print_loss_total += loss

            print_loss_avg = print_loss_total * self.batch_size / self.n_iters
            print_loss_total = 0
            print('Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, self.n_epochs,
                                                       print_loss_avg))

            if self.modality == 'tt':
                # evaluate and save the model
                curr_score = self.evaluate_all()
            else:  # ss, vv
                curr_score = print_loss_avg

            if curr_score > best_score:
                saving_skipped = 0
                best_score = curr_score
                self.save_model(epoch)

            saving_skipped += 1

            if self.use_scheduler and saving_skipped > 3:
                saving_skipped = 0
                new_lr = self.lr * 0.5
                print('Entered the dungeon...')
                if new_lr > self.lr_lower_bound:  # lower bound on lr
                    self.lr = new_lr
                    print('lr decreased to => {}'.format(self.lr))

    def train(self, input_variable, lengths, target_variable, iter):
        input_variable = input_variable.to(self.device)
        lengths = lengths.to(self.device)
        target_variable = target_variable.to(self.device)

        # Zero gradients left over from the previous batch
        self.encoder_optimizer.zero_grad()

        # Forward pass through encoder
        encoder_outputs, encoder_hidden = self.encoder(input_variable, lengths)
        if self.unit == 'gru':
            latent = encoder_hidden
        else:
            (latent, cell_state) = encoder_hidden
        # Reconstruct the input from the latent vector.
        # Note: this Linear layer is re-created on every call, so its weights
        # are freshly initialized each batch and never updated by the
        # optimizer; see the sketch after this example for one possible fix.
        seq_len = input_variable.shape[0]
        self.latent2output = nn.Linear(self.latent_dim, self.enc_input_dim *
                                       seq_len).to(self.device)
        output = self.latent2output(latent)
        output = output.view(seq_len, self.batch_size, self.enc_input_dim)
        reconstructed_input = output

        loss = self.mean_square_error(reconstructed_input, target_variable)
        loss.backward()
        # Clip gradients: gradients are modified in place
        torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), self.clip)
        self.encoder_optimizer.step()
        return loss.item()

    def mean_square_error(self, inp, target):
        criterion = nn.MSELoss()
        inp = (inp.permute(1, 0, 2))
        target = (target.permute(1, 0, 2))
        return criterion(inp, target)

    def save_model(self, epoch):
        directory = self.save_dir
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(
            {
                'epoch': epoch,
                'en': self.encoder.state_dict(),
                'en_op': self.encoder_optimizer.state_dict()
            }, '{}{}-{}-{}-{}.pth'.format(directory, self.model_code,
                                          self.modality, self.langs, epoch))
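
The comment added in train() above flags that latent2output is re-initialized on every batch and is invisible to the optimizer. A sketch of one fix, creating the layer once in build_model under the assumption of a fixed maximum sequence length self.max_seq_len (not present in the original):

    # Sketch only: build the reconstruction layer once so it is trained.
    # `self.max_seq_len` is an assumed config value; the original derives
    # seq_len from each batch instead.
    def build_model(self):
        self.encoder = EncoderRNN(self.enc_input_dim, self.hidden_size,
                                  self.enc_n_layers, self.dropout, self.unit,
                                  self.modality).to(self.device)
        self.latent2output = nn.Linear(
            self.latent_dim,
            self.enc_input_dim * self.max_seq_len).to(self.device)
        self.encoder_optimizer = optim.Adam(
            list(self.encoder.parameters()) +
            list(self.latent2output.parameters()), lr=self.lr)
        self.epoch = 0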
Example #5
n_layers = 2
dropout_p = 0.05
teacher_forcing_ratio = .5
clip = 5.
criterion = nn.NLLLoss()

# Initialize models
encoder = EncoderRNN(input_lang.n_words, hidden_size, n_layers)
decoder = AttentionDecoderRNN(attn_model,
                              hidden_size,
                              output_lang.n_words,
                              n_layers,
                              dropout_p=dropout_p)

learning_rate = 1
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

# Load model parameters
encoder.load_state_dict(torch.load(
    './data/encoder_params_{}'.format(language)))
decoder.load_state_dict(torch.load(
    './data/decoder_params_{}'.format(language)))
decoder.attention.load_state_dict(
    torch.load('./data/attention_params_{}'.format(language)))

# Move models to GPU
# encoder.cuda()
# decoder.cuda()
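
A device-agnostic sketch of the commented-out GPU moves, following the .to(device) pattern the other examples use:

# Sketch: replaces the commented-out .cuda() calls above.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = encoder.to(device)
decoder = decoder.to(device)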

Example #6
decoder.cuda()
FINE_TUNE = True
if FINE_TUNE:
    encoder.embedding.weight.requires_grad = True

print('='*100)
print('Model log:\n')
print(encoder)
print(decoder)
print('- Encoder input embedding requires_grad={}'.format(encoder.embedding.weight.requires_grad))
print('- Decoder input embedding requires_grad={}'.format(decoder.embedding.weight.requires_grad))
print('- Decoder output embedding requires_grad={}'.format(decoder.W_s.weight.requires_grad))
print('='*100 + '\n')

# Initialize optimizers (we can experiment with different learning rates)
encoder_optim = optim.Adam([p for p in encoder.parameters() if p.requires_grad], lr=opts.learning_rate, weight_decay=opts.weight_decay)
decoder_optim = optim.Adam([p for p in decoder.parameters() if p.requires_grad], lr=opts.learning_rate, weight_decay=opts.weight_decay)




if not pred_test:
    # Start training
    from datetime import datetime
    # from tensorboardX import SummaryWriter
    # --------------------------
    # Configure tensorboard
    # --------------------------
    print("Start training......")
    model_name = 'seq2seq'
    datetime = ('%s' % datetime.now()).split('.')[0].replace(' ', '_')
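The tensorboard block above is left unfinished; a sketch completing it with the tensorboardX import hinted at in the comments (the log directory name is an assumption):

# Sketch only: wires up the SummaryWriter the commented section alludes to.
from tensorboardX import SummaryWriter
writer = SummaryWriter('runs/{}_{}'.format(model_name, datetime))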
Example #7
class BiLSTMModel(nn.Module):
    def __init__(self):
        super(BiLSTMModel, self).__init__()
        self.base_rnn = None
        self.wd = None

    def init_model(self,
                   wd,
                   hidden_size,
                   e_layers,
                   d_layers,
                   base_rnn,
                   pretrained_embeddings=None,
                   dropout_p=0.1):

        self.base_rnn = base_rnn
        self.wd = wd
        self.dropout_p = dropout_p
        if pretrained_embeddings is True:
            print("Loading GloVe Embeddings ...")
            pretrained_embeddings = load_glove_embeddings(
                wd.word2index, hidden_size)

        self.encoder = EncoderRNN(wd.n_words,
                                  hidden_size,
                                  n_layers=e_layers,
                                  base_rnn=base_rnn,
                                  pretrained_embeddings=pretrained_embeddings)

        # Features are [u; v; u*v; |u-v|] with u, v each of size
        # 2*hidden_size (forward+backward states), hence 8*hidden_size inputs.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(int(hidden_size * 8), int(hidden_size)),
            torch.nn.ReLU(), torch.nn.Dropout(dropout_p),
            torch.nn.Linear(int(hidden_size), 3), torch.nn.Softmax(dim=1))
        self.parameter_list = [
            self.encoder.parameters(),
            self.mlp.parameters()
        ]

        if USE_CUDA:
            self.encoder = self.encoder.cuda()
            self.mlp = self.mlp.cuda()

        return self

    def forward(self, batch, inference=False):
        # Convert batch from numpy to torch
        if inference is True:
            text_batch, text_lengths, hyp_batch, hyp_lengths = batch
        else:
            text_batch, text_lengths, hyp_batch, hyp_lengths, labels = batch
        batch_size = text_batch.size(1)

        # Pass the input batch through the encoder
        text_enc_fwd_outputs, text_enc_bkwd_outputs, text_encoder_hidden = self.encoder(
            text_batch, text_lengths)
        hyp_enc_fwd_outputs, hyp_enc_bkwd_outputs, hyp_encoder_hidden = self.encoder(
            hyp_batch, hyp_lengths)

        last_text_enc_fwd = text_enc_fwd_outputs[-1, :, :]
        last_text_enc_bkwd = text_enc_bkwd_outputs[0, :, :]
        last_text_enc = torch.cat((last_text_enc_fwd, last_text_enc_bkwd),
                                  dim=1)
        last_hyp_enc_fwd = hyp_enc_fwd_outputs[-1, :, :]
        last_hyp_enc_bkwd = hyp_enc_bkwd_outputs[0, :, :]
        last_hyp_enc = torch.cat((last_hyp_enc_fwd, last_hyp_enc_bkwd), dim=1)

        mult_feature, diff_feature = last_text_enc * last_hyp_enc, torch.abs(
            last_text_enc - last_hyp_enc)

        features = torch.cat(
            [last_text_enc, last_hyp_enc, mult_feature, diff_feature], dim=1)
        outputs = self.mlp(features)  # B x 3
        return outputs

    def get_loss_for_batch(self, batch):
        labels = batch[-1]
        outputs = self(batch)

        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(outputs, labels)

        return loss

    def torch_batch_from_numpy_batch(self, batch):
        batch = list(batch)

        variable_indices = [
            0, 2, 4
        ]  # tuple indices of variables need to be converted
        for i in variable_indices:
            var = Variable(torch.from_numpy(batch[i]))
            if USE_CUDA:
                var = var.cuda()
            batch[i] = var

        return batch

    # Trains on a single batch
    def train_batch(self, batch, tl_mode=False):
        self.train()

        batch = self.torch_batch_from_numpy_batch(batch)
        loss = self.get_loss_for_batch(batch)
        loss.backward()

        return loss.item()

    def validate(self, batch):
        self.eval()

        batch = self.torch_batch_from_numpy_batch(batch)
        return self.get_loss_for_batch(batch).item()

    def score(self, data):
        batch_size = 1
        batches = nli_batches(batch_size, data)

        total_correct = 0
        for batch in tqdm(batches):
            batch = self.torch_batch_from_numpy_batch(batch)
            num_correct = self._acc_for_batch(batch)
            total_correct += num_correct

        acc = total_correct / (len(batches) * batch_size)

        return acc

    def _acc_for_batch(self, batch):
        '''
        :param batch:
        :return: The number of correct predictions in a batch
        '''
        self.eval()

        outputs = self(batch)
        predictions = outputs.max(1)[1]

        labels = batch[-1]

        num_error = torch.nonzero(labels - predictions)
        num_correct = labels.size(0) - num_error.size(0)

        return num_correct

    def export_state(self, dir, label, epoch=-1):
        print("Saving models.")

        cwd = os.getcwd() + '/'

        enc_out = dir + ENC_1_FILE
        mlp_out = dir + MLP_FILE
        i2w_out = dir + I2W_FILE
        w2i_out = dir + W2I_FILE
        w2c_out = dir + W2C_FILE
        inf_out = dir + INF_FILE

        torch.save(self.encoder.state_dict(), enc_out)
        torch.save(self.mlp.state_dict(), mlp_out)

        i2w = open(i2w_out, 'wb')
        pickle.dump(self.wd.index2word, i2w)
        i2w.close()
        w2i = open(w2i_out, 'wb')
        pickle.dump(self.wd.word2index, w2i)
        w2i.close()
        w2c = open(w2c_out, 'wb')
        pickle.dump(self.wd.word2count, w2c)
        w2c.close()

        info = open(inf_out, 'w')
        using_lstm = 1 if self.base_rnn == nn.LSTM else 0
        info.write(
            str(self.encoder.hidden_size) + "\n" + str(self.encoder.n_layers) +
            "\n" + str(self.wd.n_words) + "\n" + str(using_lstm))
        if epoch > 0:
            info.write("\n" + str(epoch))
        info.close()

        files = [enc_out, mlp_out, i2w_out, w2i_out, w2c_out, inf_out]

        print("Bundling models")

        tf = tarfile.open(cwd + dir + label, mode='w')
        for file in files:
            tf.add(file)
        tf.close()

        for file in files:
            os.remove(file)

        print("Finished saving models.")

    def import_state(self, model_file, active_dir=TEMP_DIR, load_epoch=False):
        print("Loading models.")
        cwd = os.getcwd() + '/'
        tf = tarfile.open(model_file)

        # extract directly to current model directory
        for member in tf.getmembers():
            if member.isreg():
                member.name = os.path.basename(member.name)
                tf.extract(member, path=active_dir)

        info = open(active_dir + INF_FILE, 'r')
        lns = info.readlines()
        hidden_size, e_layers, n_words, using_lstm = [int(i) for i in lns[:4]]

        if load_epoch:
            epoch = int(lns[-1])

        i2w = open(cwd + TEMP_DIR + I2W_FILE, 'rb')
        w2i = open(cwd + TEMP_DIR + W2I_FILE, 'rb')
        w2c = open(cwd + TEMP_DIR + W2C_FILE, 'rb')
        i2w_dict = pickle.load(i2w)
        w2i_dict = pickle.load(w2i)
        w2c_dict = pickle.load(w2c)
        wd = WordDict(dicts=[w2i_dict, i2w_dict, w2c_dict, n_words])
        w2i.close()
        i2w.close()
        w2c.close()

        self.base_rnn = nn.LSTM if using_lstm == 1 else nn.GRU
        self.wd = wd
        self.encoder = EncoderRNN(wd.n_words,
                                  hidden_size,
                                  n_layers=e_layers,
                                  base_rnn=self.base_rnn)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(int(hidden_size * 8), int(hidden_size)),
            torch.nn.ReLU(), torch.nn.Dropout(0.1),
            torch.nn.Linear(int(hidden_size), 3), torch.nn.Softmax(dim=1))
        if not USE_CUDA:
            self.encoder.load_state_dict(
                torch.load(cwd + TEMP_DIR + ENC_1_FILE,
                           map_location=lambda storage, loc: storage))
            self.mlp.load_state_dict(
                torch.load(cwd + TEMP_DIR + MLP_FILE,
                           map_location=lambda storage, loc: storage))
        else:
            self.encoder.load_state_dict(
                torch.load(cwd + TEMP_DIR + ENC_1_FILE))
            self.mlp.load_state_dict(torch.load(cwd + TEMP_DIR + MLP_FILE))
            self.encoder = self.encoder.cuda()
            self.mlp = self.mlp.cuda()

        self.encoder.eval()
        self.mlp.eval()

        self.parameter_list = [
            self.encoder.parameters(),
            self.mlp.parameters()
        ]
        tf.close()

        print("Loaded models.")

        if load_epoch:
            return self, epoch
        else:
            return self

    def torch_batch_from_numpy_batch_without_label(self, batch):
        batch = list(batch)

        variable_indices = [0, 2]
        for i in variable_indices:
            var = Variable(torch.from_numpy(batch[i]))
            if USE_CUDA:
                var = var.cuda()
            batch[i] = var

        return batch

    def predict(self, data):
        batch_size = 1
        batches = nli_batches_without_label(batch_size, data)

        predictions = []
        for batch in tqdm(batches):
            batch = self.torch_batch_from_numpy_batch_without_label(batch)
            outputs = self(batch, inference=True)
            pred = outputs.max(1)[1]
            predictions.append(pred)

        return torch.cat(predictions)

    def add_new_vocabulary(self, genre):
        old_vocab_size = self.wd.n_words
        print("Previous vocabulary size: " + str(old_vocab_size))

        train_set = nli_preprocessor.get_multinli_text_hyp_labels(
            genre=genre
        )  #nli_preprocessor.get_multinli_training_set(max_lines=args.max_lines)
        matched_val_set = nli_preprocessor.get_multinli_matched_val_set(
        )  #genre_val_set(genre)

        unmerged_sentences = []
        for data in [train_set, matched_val_set]:
            unmerged_sentences.extend([data["text"], data["hyp"]])
        all_sentences = list(itertools.chain.from_iterable(unmerged_sentences))

        for line in all_sentences:
            self.wd.add_sentence(line)

        print("New vocabulary size: " + str(self.wd.n_words))

        print("Extending the Embedding layer with new vocabulary...")
        num_new_words = self.wd.n_words - old_vocab_size
        self.encoder.extend_embedding_layer(self.wd.word2index, num_new_words)

        self.new_vocab_size = num_new_words

    def freeze_source_params(self):
        for name, param in self.named_parameters():
            if "rnn" in name:
                param.requires_grad = False
            if ("M_k" in name or "M_v" in name) and "target_4" not in name:
                param.requires_grad = False
        for name, param in self.named_parameters():
            if param.requires_grad is True:
                print(name)
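
BiLSTMModel is configured via init_model rather than __init__; a hypothetical construction call (wd is a populated WordDict as in import_state, and the sizes are illustrative assumptions):

# Hypothetical usage of the class above.
model = BiLSTMModel().init_model(wd,
                                 hidden_size=300,
                                 e_layers=1,
                                 d_layers=1,
                                 base_rnn=nn.LSTM,
                                 pretrained_embeddings=True)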
Example #8
class Seq2Seq(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, learning_rate,
                 teacher_forcing_ratio, device):
        super(Seq2Seq, self).__init__()

        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.device = device

        self.encoder = EncoderRNN(input_size, hidden_size)
        self.decoder = AttnDecoderRNN(hidden_size, output_size)

        self.encoder_optimizer = optim.SGD(self.encoder.parameters(),
                                           lr=learning_rate)
        self.decoder_optimizer = optim.SGD(self.decoder.parameters(),
                                           lr=learning_rate)

        self.criterion = nn.NLLLoss()

    def train(self,
              input_tensor,
              target_tensor,
              max_length=constants.MAX_LENGTH):
        encoder_hidden = self.encoder.initHidden()

        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        input_length = input_tensor.size(0)
        target_length = target_tensor.size(0)

        encoder_outputs = torch.zeros(max_length + 1,
                                      self.encoder.hidden_size,
                                      device=self.device)

        loss = 0

        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(
                input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0, 0]

        decoder_input = torch.tensor([[constants.SOS_TOKEN]],
                                     device=self.device)
        decoder_hidden = encoder_hidden

        use_teacher_forcing = np.random.random() < self.teacher_forcing_ratio

        if use_teacher_forcing:
            # Teacher forcing: feed the target as the next input
            for di in range(target_length):
                decoder_output, decoder_hidden, decoder_attention = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                loss += self.criterion(decoder_output, target_tensor[di])
                decoder_input = target_tensor[di]  # Teacher forcing
        else:
            # Without teacher forcing: use its own prediction as the next input
            for di in range(target_length):
                decoder_output, decoder_hidden, decoder_attention = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                topv, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze().detach(
                )  # detach from history as input

                loss += self.criterion(decoder_output, target_tensor[di])

                if decoder_input.item() == constants.EOS_TOKEN:
                    break

        loss.backward()

        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

        return loss.item() / target_length

    def trainIters(self, env, evaluator):
        start_total_time = time.time() - env.total_training_time
        start_epoch_time = time.time()  # Reset every LOG_EVERY iterations
        start_train_time = time.time()  # Reset every LOG_EVERY iterations
        total_loss = 0  # Reset every LOG_EVERY iterations

        for iter in range(env.iters_completed + 1, constants.NUM_ITER + 1):
            row = env.train_methods.iloc[np.random.randint(
                len(env.train_methods))]
            input_tensor = row['source']
            target_tensor = row['name']

            loss = self.train(input_tensor, target_tensor)
            total_loss += loss

            if iter % constants.LOG_EVERY == 0:
                log('Completed {} iterations'.format(iter))

                train_time_elapsed = time.time() - start_train_time

                log('Evaluating on validation set')
                start_eval_time = time.time()

                names = evaluator.evaluate(self)
                # save_dataframe(names, constants.VALIDATION_NAMES_FILE)

                eval_time_elapsed = time.time() - start_eval_time

                env.history = env.history.append(
                    {
                        'Loss': total_loss / constants.LOG_EVERY,
                        'BLEU': names['BLEU'].mean(),
                        'ROUGE': names['ROUGE'].mean(),
                        'F1': names['F1'].mean(),
                        'num_names': len(names['GeneratedName'].unique())
                    },
                    ignore_index=True)

                epoch_time_elapsed = time.time() - start_epoch_time
                total_time_elapsed = time.time() - start_total_time

                env.total_training_time = total_time_elapsed

                history_last_row = env.history.iloc[-1]

                log_dict = OrderedDict([
                    ("Iteration", '{}/{} ({:.1f}%)'.format(
                        iter, constants.NUM_ITER,
                        iter / constants.NUM_ITER * 100)),
                    ("Average loss", history_last_row['Loss']),
                    ("Average BLEU", history_last_row['BLEU']),
                    ("Average ROUGE", history_last_row['ROUGE']),
                    ("Average F1", history_last_row['F1']),
                    ("Unique names", int(history_last_row['num_names'])),
                    ("Epoch time", time_str(epoch_time_elapsed)),
                    ("Training time", time_str(train_time_elapsed)),
                    ("Evaluation time", time_str(eval_time_elapsed)),
                    ("Total training time", time_str(total_time_elapsed))
                ])

                write_training_log(log_dict, constants.TRAIN_LOG_FILE)
                plot_and_save_histories(env.history)

                env.iters_completed = iter
                env.save_train()

                # Reseting counters
                total_loss = 0
                start_epoch_time = time.time()
                start_train_time = time.time()

    def forward(self,
                input_tensor,
                max_length=constants.MAX_LENGTH,
                return_attention=False):
        encoder_hidden = self.encoder.initHidden()

        input_length = input_tensor.size(0)

        encoder_outputs = torch.zeros(max_length + 1,
                                      self.encoder.hidden_size,
                                      device=self.device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(
                input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0, 0]

        decoder_input = torch.tensor([[constants.SOS_TOKEN]],
                                     device=self.device)
        decoder_hidden = encoder_hidden

        decoded_words = []
        attention_vectors = []

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.data.topk(1)

            decoded_words.append(topi.item())
            attention_vectors.append(decoder_attention.tolist()[0])

            if decoded_words[-1] == constants.EOS_TOKEN:
                break

            decoder_input = topi.squeeze().detach()

        if return_attention:
            return decoded_words, attention_vectors
        else:
            return decoded_words
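A hypothetical driver for the Seq2Seq wrapper above; vocabulary sizes are assumptions, and input_tensor/target_tensor are assumed to be (seq_len, 1) index tensors as train() expects:

# Sketch only: sizes and tensors are illustrative.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Seq2Seq(input_size=5000, output_size=5000, hidden_size=256,
                learning_rate=0.01, teacher_forcing_ratio=0.5,
                device=device).to(device)
avg_loss = model.train(input_tensor, target_tensor)  # one pair per call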
Example #9
def pretrain():
    # Parse command line arguments
    argparser = argparse.ArgumentParser()

    # train
    argparser.add_argument('--mode',
                           '-m',
                           choices=('pretrain', 'adversarial', 'inference'),
                           type=str,
                           required=True)
    argparser.add_argument('--batch_size', '-b', type=int, default=168)
    argparser.add_argument('--num_epoch', '-e', type=int, default=10)
    argparser.add_argument('--print_every', type=int, default=100)
    argparser.add_argument('--use_cuda', default=True)
    argparser.add_argument('--g_learning_rate',
                           '-glr',
                           type=float,
                           default=0.001)
    argparser.add_argument('--d_learning_rate',
                           '-dlr',
                           type=float,
                           default=0.001)

    # resume
    argparser.add_argument('--resume', action='store_true', dest='resume')
    argparser.add_argument('--resume_dir', type=str)
    argparser.add_argument('--resume_epoch', type=int)

    # save
    argparser.add_argument('--exp_dir', type=str, required=True)

    # model
    argparser.add_argument('--emb_dim', type=int, default=128)
    argparser.add_argument('--hidden_dim', type=int, default=256)
    argparser.add_argument('--dropout_rate', '-drop', type=float, default=0.5)
    argparser.add_argument('--n_layers', type=int, default=1)
    argparser.add_argument('--response_max_len', type=int, default=15)

    # data
    argparser.add_argument('--train_query_file',
                           '-tqf',
                           type=str,
                           required=True)
    argparser.add_argument('--train_response_file',
                           '-trf',
                           type=str,
                           required=True)
    argparser.add_argument('--valid_query_file',
                           '-vqf',
                           type=str,
                           required=True)
    argparser.add_argument('--valid_response_file',
                           '-vrf',
                           type=str,
                           required=True)
    argparser.add_argument('--vocab_file', '-vf', type=str, default='')
    argparser.add_argument('--max_vocab_size', '-mv', type=int, default=100000)

    args = argparser.parse_args()

    # set up the output directory
    exp_dirname = os.path.join(args.exp_dir, args.mode,
                               time.strftime("%Y-%m-%d-%H-%M-%S"))
    os.makedirs(exp_dirname)

    # set up the logger
    tqdm_logging.config(logger,
                        os.path.join(exp_dirname, 'train.log'),
                        mode='w',
                        silent=False,
                        debug=True)

    if not args.vocab_file:
        logger.info("no vocabulary file")
        build_vocab(args.train_query_file,
                    args.train_response_file,
                    seperated=True)
        sys.exit()
    else:
        vocab, rev_vocab = load_vocab(args.vocab_file,
                                      max_vocab=args.max_vocab_size)

    vocab_size = len(vocab)

    word_embeddings = nn.Embedding(vocab_size,
                                   args.emb_dim,
                                   padding_idx=SYM_PAD)
    E = EncoderRNN(vocab_size,
                   args.emb_dim,
                   args.hidden_dim,
                   args.n_layers,
                   args.dropout_rate,
                   bidirectional=True,
                   variable_lengths=True)
    G = Generator(vocab_size,
                  args.response_max_len,
                  args.emb_dim,
                  2 * args.hidden_dim,
                  args.n_layers,
                  dropout_p=args.dropout_rate)

    if args.use_cuda:
        word_embeddings.cuda()
        E.cuda()
        G.cuda()

    loss_func = nn.NLLLoss(reduction='sum')
    params = list(word_embeddings.parameters()) + list(E.parameters()) + list(
        G.parameters())
    opt = torch.optim.Adam(params, lr=args.g_learning_rate)

    logger.info('----------------------------------')
    logger.info('Pre-train a neural conversation model')
    logger.info('----------------------------------')

    logger.info('Args:')
    logger.info(str(args))

    logger.info('Vocabulary from ' + args.vocab_file)
    logger.info('vocabulary size: %d' % vocab_size)
    logger.info('Loading text data from ' + args.train_query_file + ' and ' +
                args.train_response_file)

    # resume training from other experiment
    if args.resume:
        assert args.resume_epoch >= 0, 'If resuming training, please set resume_epoch'
        reload_model(args.resume_dir, args.resume_epoch, word_embeddings, E, G)
        start_epoch = args.resume_epoch + 1
    else:
        start_epoch = 0

    # dump args
    with open(os.path.join(exp_dirname, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    for e in range(start_epoch, args.num_epoch):
        logger.info('---------------------training--------------------------')
        train_data_generator = batcher(args.batch_size, args.train_query_file,
                                       args.train_response_file)
        logger.info("Epoch: %d/%d" % (e, args.num_epoch))
        step = 0
        total_loss = 0.0
        total_valid_char = []
        cur_time = time.time()
        while True:
            try:
                post_sentences, response_sentences = next(train_data_generator)
            except StopIteration:
                # save model
                save_model(exp_dirname, e, word_embeddings, E, G)
                # evaluation
                eval(args.valid_query_file, args.valid_response_file,
                     args.batch_size, word_embeddings, E, G, loss_func,
                     args.use_cuda, vocab, args.response_max_len)
                break

            post_ids = [sentence2id(sent, vocab) for sent in post_sentences]
            response_ids = [
                sentence2id(sent, vocab) for sent in response_sentences
            ]
            posts_var, posts_length = padding_inputs(post_ids, None)
            responses_var, responses_length = padding_inputs(
                response_ids, args.response_max_len)
            # sort by post length
            posts_length, perms_idx = posts_length.sort(0, descending=True)
            posts_var = posts_var[perms_idx]
            responses_var = responses_var[perms_idx]
            responses_length = responses_length[perms_idx]

            # append EOS after each sentence
            references_var = torch.cat([
                responses_var,
                Variable(torch.zeros(responses_var.size(0), 1).long(),
                         requires_grad=False)
            ],
                                       dim=1)
            for idx, length in enumerate(responses_length):
                references_var[idx, length] = SYM_EOS

            # show case
            #for p, r, ref in zip(posts_var.data.numpy()[:10], responses_var.data.numpy()[:10], references_var.data.numpy()[:10]):
            #    print ''.join(id2sentence(p, rev_vocab))
            #    print ''.join(id2sentence(r, rev_vocab))
            #    print ''.join(id2sentence(ref, rev_vocab))
            #    print

            if args.use_cuda:
                posts_var = posts_var.cuda()
                responses_var = responses_var.cuda()
                references_var = references_var.cuda()

            embedded_post = word_embeddings(posts_var)
            embedded_response = word_embeddings(responses_var)

            _, dec_init_state = E(embedded_post,
                                  input_lengths=posts_length.numpy())
            log_softmax_outputs = G.supervise(
                embedded_response, dec_init_state,
                word_embeddings)  # [B, T, vocab_size]

            outputs = log_softmax_outputs.view(-1, vocab_size)
            mask_pos = mask(references_var).view(-1).unsqueeze(-1)
            masked_output = outputs * (mask_pos.expand_as(outputs))
            loss = loss_func(masked_output,
                             references_var.view(-1)) / (posts_var.size(0))

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss.item() * posts_var.size(0)
            total_valid_char.append(mask_pos)

            if step % args.print_every == 0:
                total_loss_val = total_loss
                total_valid_char_val = torch.sum(
                    torch.cat(total_valid_char, dim=1)).item()
                logger.info(
                    'Step %5d: (per word) training perplexity %.2f (%.1f iters/sec)'
                    % (step, math.exp(total_loss_val / total_valid_char_val),
                       args.print_every / (time.time() - cur_time)))
                total_loss = 0.0
                total_valid_char = []
                total_case_num = 0
                cur_time = time.time()
            step = step + 1
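The mask() helper used in the loss above is not shown; from the way its output multiplies the log-softmax scores, it presumably marks non-padding positions. A sketch under that assumption:

# Assumed behaviour of mask(): 1.0 for real tokens, 0.0 at SYM_PAD.
def mask(references_var):
    return references_var.ne(SYM_PAD).float()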
Example #10
input_lang, output_lang, pairs = etl.prepare_data(args.language)

# Initialize models
encoder = EncoderRNN(input_lang.n_words, args.embedding_size, args.hidden_size,
                     args.n_layers, args.dropout)

decoder = AttentionDecoderRNN(output_lang.n_words, args.embedding_size,
                              args.hidden_size, args.attn_model, args.n_layers,
                              args.dropout)
# Move models to device
encoder = encoder.to(device)
decoder = decoder.to(device)

# Initialize optimizers and criterion
encoder_optimizer = optim.Adam(encoder.parameters(), lr=args.lr)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=args.lr)
criterion = nn.NLLLoss()

# Keep track of time elapsed and running averages
start = time.time()
plot_losses = []
print_loss_total = 0  # Reset every print_every
plot_loss_total = 0  # Reset every plot_every

# Begin training
for epoch in range(1, args.n_epochs + 1):
    # Get training data for this cycle
    training_pair = etl.tensor_from_pair(random.choice(pairs), input_lang,
                                         output_lang)
    input = training_pair[0]
Example #11
def train(input_sentences, output_sentences, input_vocab, output_vocab,
          input_reverse, output_reverse, hy, writer):
    dataset = NMTDataset(input_sentences, output_sentences, input_vocab,
                         output_vocab, input_reverse, output_reverse)
    loader = DataLoader(dataset,
                        batch_size=hy.batch_size,
                        shuffle=True,
                        drop_last=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_vocab_size = len(input_vocab.keys())
    output_vocab_size = len(output_vocab.keys())

    encoder = EncoderRNN(input_vocab_size, hy.embedding_size, hy.hidden_size,
                         hy.rnn_layers, hy.bidirectional, device)
    decoder = DecoderRNN(output_vocab_size, hy.embedding_size, hy.hidden_size,
                         hy.rnn_layers, hy.bidirectional, device)

    loss_function = nn.CrossEntropyLoss().to(device)
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=hy.lr)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=hy.lr)

    n_iterations = 0
    loss_history = []
    training_accuracy = 0.

    encoder.train()
    decoder.train()

    for epoch in range(1, hy.num_epochs + 1):
        for encoder_input, decoder_input, decoder_output in tqdm(
                loader, desc="{}/{}".format(epoch, hy.num_epochs)):
            encoder_input = encoder_input.to(device)
            decoder_input = decoder_input.to(device)
            decoder_output = decoder_output.to(device)

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            _, encoder_hidden = encoder(encoder_input)
            logits = decoder(decoder_input, encoder_hidden)

            loss = loss_function(
                logits.view(hy.batch_size * decoder_output.shape[1], -1),
                decoder_output.view(-1))

            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

            writer.add_scalar("TrainingLoss", loss.item(), n_iterations)
            n_iterations = n_iterations + 1
            loss_history.append(loss.item())

        training_accuracy = compute_model_accuracy(encoder, decoder, loader,
                                                   device, epoch, writer)
        torch.save(encoder.state_dict(),
                   "saved_runs/encoder_{}_weights.pt".format(epoch))
        torch.save(decoder.state_dict(),
                   "saved_runs/decoder_{}_weights.pt".format(epoch))

    return loss_history, training_accuracy
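
train() reads its hyperparameters off a `hy` object; a minimal stand-in with illustrative (assumed) values:

from argparse import Namespace

# Illustrative values only; hy just needs these attributes.
hy = Namespace(batch_size=64, embedding_size=300, hidden_size=512,
               rnn_layers=2, bidirectional=True, lr=1e-3, num_epochs=10)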
Example #12
class Sequence2SequenceNetwork(object):
    def __init__(self, config):
        self.init_writer()
        self.load_configuration(config)
        self.load_vocabulary()
        self.prepare_data()
        self.build_model()
        self.load_pretrained_model()
        self.train_model()
        self.save_model(self.n_epochs)
        self.evaluate_all()
        self.close_writer()

    def init_writer(self):
        self.writer = SummaryWriter()

    def load_configuration(self, config):
        # Load configuration
        self.iter_num = 0
        self.lr = config['lr']
        self.gpu = config['gpu']
        self.unit = config['unit']
        self.clip = config['clip']
        self.beta1 = config['beta1']
        self.beta2 = config['beta2']
        self.langs = config['langs']
        self.fusion = config['fusion']
        self.log_tb = config['log_tb']
        self.epsilon = config['epsilon']
        self.attn_model = config['attn']
        self.dropout = config['dropout']
        self.emb_mode = config['emb_mode']
        self.save_dir = config['save_dir']
        self.data_dir = config['data_dir']
        self.n_epochs = config['n_epochs']
        self.SOS_TOKEN = config['SOS_TOKEN']
        self.EOS_TOKEN = config['EOS_TOKEN']
        self.MAX_LENGTH = config['MAX_LENGTH']
        self.latent_dim = config['latent_dim']
        self.batch_size = config['batch_size']
        self.model_code = config['model_code']
        self.vocab_path = config['vocab_path']
        self.hidden_size = config['hidden_size']
        self.use_cuda = torch.cuda.is_available()
        self.log_tb_every = config['log_tb_every']
        self.enc_n_layers = config['enc_n_layers']
        self.dec_n_layers = config['dec_n_layers']
        self.dec_learning_ratio = config['dec_lr']
        self.bidirectional = config['bidirectional']
        self.enc_input_dim = config['enc_input_dim']
        self.embedding_dim = config['embedding_dim']
        self.use_scheduler = config['use_scheduler']
        self.use_embeddings = config['use_embeddings']
        self.lr_lower_bound = config['lr_lower_bound']
        self.teacher_forcing_ratio = config['tf_ratio']
        self.load_model_name = config['load_model_name']
        self.modality = config[
            'modalities']  # no splitting, as this is not the multimodal case
        if self.modality in ['ss-vv', 'v-s']:
            self.pretrained_modality = config['pretrained_modality']
        self.generate_word_embeddings = config['generate_word_embeddings']
        self.device = torch.device(
            'cuda:{}'.format(self.gpu) if self.use_cuda else 'cpu')

    def load_vocabulary(self):
        try:
            with open(self.vocab_path, 'rb') as f:
                self.vocab = pickle.load(f)
        except FileNotFoundError:  # build vocab if it doesn't exist
            self.vocab = buildVocab()

    def prepare_data(self):
        # Note: this workaround appears throughout and is acceptable
        # because this script is only ever run for unimodal cases
        self.pairs = prepareData(self.langs, [self.modality])[self.modality]
        num_pairs = len(self.pairs)
        self.pairs = self.pairs[:self.batch_size *
                                (num_pairs // self.batch_size)]
        random.shuffle(self.pairs)
        self.n_iters = len(self.pairs)
        print('\nLoading test data pairs')
        self.test_pairs = prepareData(self.langs, [self.modality],
                                      train=False)[self.modality]
        random.shuffle(self.test_pairs)
        print(random.choice(self.pairs))
        if self.use_embeddings:
            if self.generate_word_embeddings:
                self.embedding_wts = generateWordEmbeddings(
                    self.vocab, self.emb_mode)
            else:
                self.embedding_wts = loadWordEmbeddings(self.emb_mode)

    def build_model(self):
        if self.use_embeddings:
            self.embedding = nn.Embedding.from_pretrained(self.embedding_wts)
        else:
            self.embedding = nn.Embedding(self.vocab.n_words,
                                          self.embedding_dim)

        if self.modality == 't':  # Need embedding only for t2t mode
            self.encoder = EncoderRNN(self.embedding_dim,
                                      self.hidden_size,
                                      self.enc_n_layers,
                                      self.dropout,
                                      self.unit,
                                      self.modality,
                                      self.embedding,
                                      fusion_or_unimodal=True).to(self.device)
        else:
            # Note: no embedding used here
            self.encoder = EncoderRNN(self.enc_input_dim,
                                      self.hidden_size,
                                      self.enc_n_layers,
                                      self.dropout,
                                      self.unit,
                                      self.modality,
                                      fusion_or_unimodal=True).to(self.device)

        self.decoder = DecoderRNN(self.attn_model, self.embedding_dim,
                                  self.hidden_size, self.vocab.n_words,
                                  self.unit, self.dec_n_layers, self.dropout,
                                  self.embedding).to(self.device)
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(),
                                            lr=self.lr)
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(),
                                            lr=self.lr *
                                            self.dec_learning_ratio)

        self.epoch = 0  # define here to add resume training feature
        self.project_factor = self.encoder.project_factor
        self.latent2hidden = nn.Linear(self.latent_dim, self.hidden_size *
                                       self.project_factor).to(self.device)

    def load_pretrained_model(self):
        if self.load_model_name:
            checkpoint = torch.load(self.load_model_name,
                                    map_location=self.device)
            print('Loaded {}'.format(self.load_model_name))
            self.epoch = checkpoint['epoch']
            self.encoder.load_state_dict(checkpoint['en'])
            self.decoder.load_state_dict(checkpoint['de'])
            self.encoder_optimizer.load_state_dict(checkpoint['en_op'])
            self.decoder_optimizer.load_state_dict(checkpoint['de_op'])
            self.embedding.load_state_dict(checkpoint['embedding'])

    def train_model(self):
        best_score = 1e-200
        print_loss_total = 0  # Reset every epoch

        saving_skipped = 0
        for epoch in range(self.epoch, self.n_epochs):
            for iter in range(0, self.n_iters, self.batch_size):
                pairs = self.pairs[iter:iter + self.batch_size]
                # Skip the trailing incomplete batch, if any (prepare_data
                # already trims self.pairs to a multiple of batch_size)
                if len(pairs) < self.batch_size:
                    continue
                training_batch = batch2TrainData(self.vocab, pairs,
                                                 self.modality)

                # Extract fields from batch
                input_variable, lengths, target_variable, \
                    mask, max_target_len, _ = training_batch

                # Run a training iteration with the current batch
                loss = self.train(input_variable, lengths, target_variable,
                                  mask, max_target_len, iter)
                self.writer.add_scalar('{}loss'.format(self.data_dir), loss,
                                       iter)

                print_loss_total += loss

            print_loss_avg = print_loss_total * self.batch_size / self.n_iters
            print_loss_total = 0
            print('Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, self.n_epochs,
                                                       print_loss_avg))

            # evaluate and save the model
            curr_score = self.evaluate_all()

            self.writer.add_scalar('{}bleu_score'.format(self.data_dir),
                                   curr_score, epoch)

            if curr_score > best_score:
                saving_skipped = 0
                best_score = curr_score
                self.save_model(epoch)

            saving_skipped += 1

            if self.use_scheduler and saving_skipped > 3:
                saving_skipped = 0
                new_lr = self.lr * 0.5
                print('Entered the dungeon...')
                if new_lr > self.lr_lower_bound:  # lower bound on lr
                    self.lr = new_lr
                    print('lr decreased to => {}'.format(self.lr))

    def train(self, input_variable, lengths, target_variable, mask,
              max_target_len, iter):
        self.encoder.train()
        self.decoder.train()
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        input_variable = input_variable.to(self.device)
        lengths = lengths.to(self.device)
        target_variable = target_variable.to(self.device)
        mask = mask.to(self.device)

        # Initialize variables
        loss = 0
        print_losses = []
        n_totals = 0

        # Forward pass through encoder
        encoder_outputs, encoder_hidden = self.encoder(input_variable, lengths)

        # Create initial decoder input (start with SOS tokens for each sentence)
        decoder_input = torch.LongTensor([[self.SOS_TOKEN] * self.batch_size])
        decoder_input = decoder_input.to(self.device)

        # Set initial decoder hidden state to the encoder's final hidden state
        if self.unit == 'gru':
            decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        else:
            decoder_hidden = (encoder_hidden[0][:self.decoder.n_layers],
                              encoder_hidden[1][:self.decoder.n_layers])
        if iter % self.log_tb_every == 0:
            # Visualize latent space
            if self.unit == 'gru':
                vis_hidden = decoder_hidden[-1, :, :]
            else:
                vis_hidden = decoder_hidden[0][-1, :, :]
            self.writer.add_embedding(vis_hidden,
                                      tag='decoder_hidden_{}'.format(iter))

        use_teacher_forcing = random.random() < self.teacher_forcing_ratio

        if use_teacher_forcing:
            for t in range(max_target_len):
                decoder_output, decoder_hidden = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                # Teacher forcing: next input is current target
                decoder_input = target_variable[t].view(1, -1)
                # Calculate and accumulate loss
                mask_loss, nTotal = self.mask_nll_loss(decoder_output,
                                                       target_variable[t],
                                                       mask[t])
                loss += mask_loss
                print_losses.append(mask_loss.item() * nTotal)
                n_totals += nTotal
        else:
            for t in range(max_target_len):
                decoder_output, decoder_hidden = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                # No teacher forcing: next input is decoder's own current output
                _, topi = decoder_output.topk(1)
                decoder_input = torch.LongTensor(
                    [[topi[i][0] for i in range(self.batch_size)]])
                decoder_input = decoder_input.to(self.device)
                # Calculate and accumulate loss
                mask_loss, nTotal = self.mask_nll_loss(decoder_output,
                                                       target_variable[t],
                                                       mask[t])
                loss += mask_loss
                print_losses.append(mask_loss.item() * nTotal)
                n_totals += nTotal

        loss.backward()

        # Clip gradients: gradients are modified in place
        torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), self.clip)
        torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), self.clip)

        self.encoder_optimizer.step()
        self.decoder_optimizer.step()
        return sum(print_losses) / n_totals

    def mask_nll_loss(self, inp, target, mask):
        n_total = mask.sum()
        cross_entropy = -torch.log(
            torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
        loss = cross_entropy.masked_select(mask).sum()
        loss = loss.to(self.device)
        return loss, n_total.item()

    def save_model(self, epoch):
        directory = self.save_dir
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(
            {
                'epoch': epoch,
                'en': self.encoder.state_dict(),
                'de': self.decoder.state_dict(),
                'en_op': self.encoder_optimizer.state_dict(),
                'de_op': self.decoder_optimizer.state_dict(),
                'embedding': self.embedding.state_dict()
            }, '{}{}-{}-{}-{}.pth'.format(directory, self.model_code,
                                          self.modality, self.langs, epoch))

    def evaluate_all(self):
        self.encoder.eval()
        self.decoder.eval()
        searcher = GreedySearchDecoder(self.encoder, self.decoder, None,
                                       self.device, self.SOS_TOKEN)
        refs = []
        hyp = []
        for pair in self.test_pairs:
            output_words = self.evaluate(self.encoder, self.decoder, searcher,
                                         self.vocab, pair[0])
            if output_words:
                final_output = []
                for x in output_words:
                    if x == '<EOS>':
                        break
                    final_output.append(x)
                refs.append([pair[1].split()])
                hyp.append(final_output)
        bleu_scores = calculateBleuScores(refs, hyp)
        print('Bleu score: {bleu_1} | {bleu_2} | {bleu_3} | {bleu_4}'.format(
            **bleu_scores))
        eg_idx = random.choice(range(len(hyp)))
        print(hyp[eg_idx], refs[eg_idx])
        return bleu_scores['bleu_4']

    def evaluate(self,
                 encoder,
                 decoder,
                 searcher,
                 vocab,
                 sentence_or_vector,
                 max_length=conf['MAX_LENGTH']):
        with torch.no_grad():
            if self.modality == 't':  # `sentence_or_vector` ~> sentence
                # Format input sentence as a batch
                # words => indexes
                indexes_batch = [
                    indexesFromSentence(vocab, sentence_or_vector)
                ]
                if None in indexes_batch:
                    return None
                for idx, indexes in enumerate(indexes_batch):
                    indexes_batch[idx] = indexes_batch[idx] + [self.EOS_TOKEN]
                # Create lengths tensor
                lengths = torch.tensor(
                    [len(indexes) for indexes in indexes_batch])
                # Transpose dimensions of batch to match models' expectations
                input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
            else:  # `sentence_or_vector` ~> vector
                input_batch, lengths = inputVarVec([sentence_or_vector],
                                                   self.modality)

            # Use appropriate device
            input_batch = input_batch.to(self.device)
            lengths = lengths.to(self.device)
            # Decode sentence with searcher
            tokens, scores = searcher(input_batch, lengths, max_length)
            # indexes -> words
            decoded_words = [
                vocab.index2word[token.item()] for token in tokens
            ]
            return decoded_words

    def close_writer(self):
        self.writer.close()
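
Sequence2SequenceNetwork drives everything from its constructor, consuming a flat config dict whose keys are enumerated in load_configuration. A minimal sketch of such a dict (values are placeholders, not the original settings):

# Placeholder values; keys mirror load_configuration above.
config = {
    'lr': 1e-4, 'gpu': 0, 'unit': 'gru', 'clip': 50.0, 'beta1': 0.9,
    'beta2': 0.999, 'langs': 'en', 'fusion': None, 'log_tb': False,
    'epsilon': 1e-8, 'attn': 'dot', 'dropout': 0.1, 'emb_mode': 'glove',
    'save_dir': 'save/', 'data_dir': 'data/', 'n_epochs': 10,
    'SOS_TOKEN': 1, 'EOS_TOKEN': 2, 'MAX_LENGTH': 20, 'latent_dim': 256,
    'batch_size': 64, 'model_code': 's2s', 'vocab_path': 'vocab.pkl',
    'hidden_size': 512, 'log_tb_every': 100, 'enc_n_layers': 2,
    'dec_n_layers': 2, 'dec_lr': 5.0, 'bidirectional': True,
    'enc_input_dim': 300, 'embedding_dim': 300, 'use_scheduler': True,
    'use_embeddings': False, 'lr_lower_bound': 1e-7, 'tf_ratio': 1.0,
    'load_model_name': None, 'modalities': 't',
    'generate_word_embeddings': False,
}
network = Sequence2SequenceNetwork(config)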