예제 #1
0
    def _init_models(self, params):
        """Build the three sub-graphs: an encoder VAE over x, a decoder VAE
        over y, and a Transformer mapping the encoder latent to the decoder's.

        NOTE(review): `params` is mutated in place (`max_len`/`max_dec_len`)
        before each BaseVAE is constructed, so construction order matters.
        """
        # self.global_step = None
        with tf.variable_scope('encodervae'):
            # Encoder-side VAE consumes the x placeholders.
            encodervae_inputs = (self.x_enc_inp, self.x_dec_inp, self.x_dec_out, self.global_step)
            params['max_len'] = args.enc_max_len
            params['max_dec_len'] = args.enc_max_len + 1
            self.encoder_model = BaseVAE(params, encodervae_inputs, "encoder")
        with tf.variable_scope('decodervae'):
            # Decoder-side VAE consumes the y placeholders.
            decodervae_inputs = (self.y_enc_inp, self.y_dec_inp, self.y_dec_out, self.global_step)
            params['max_len'] = args.dec_max_len
            params['max_dec_len'] = args.dec_max_len + 1

            if args.isPointer:
                # Pointer mode: also pass OOV-mapped input and the OOV mask.
                mask_oovs = self.encoder_model.dec_seq_len_mask
                self.decoder_model = BaseVAE(params, decodervae_inputs, "decoder", 
                            self.encoder_model.encoder_outputs, self.x_enc_inp_oovs, self.max_oovs, mask_oovs)
            elif args.isContext:
                # Context mode: decoder VAE receives the encoder outputs.
                self.decoder_model = BaseVAE(params, decodervae_inputs, "decoder", self.encoder_model.encoder_outputs)
            else:
                self.decoder_model = BaseVAE(params, decodervae_inputs, "decoder")

        with tf.variable_scope('transformer'):
            # Latent-to-latent mapping between the two VAEs.
            self.transformer = Transformer(self.encoder_model, self.decoder_model, params['graph_type'], self.global_step)
        with tf.variable_scope('decodervae/decoding', reuse=True):
            # Re-run the decoder on the transformer's predicted latent code.
            # ("predition" [sic] is the attribute name Transformer exposes.)
            self.training_logits = self.decoder_model._decoder_training(self.transformer.predition, reuse=True)
            self.predicted_ids_op, self.attens = self.decoder_model._decoder_inference(self.transformer.predition)
예제 #2
0
def get_model(input_size, output_size, config):
    """Construct a Transformer for the given source/target vocabulary sizes.

    `config` supplies the hidden size, head count, layer count, and dropout;
    the same layer count is used for the encoder and decoder stacks.
    """
    n_blocks = config.n_layers
    return Transformer(
        input_size,                  # source vocabulary size
        config.hidden_size,          # model width (no separate word_vec_size)
        output_size,                 # target vocabulary size
        n_splits=config.n_splits,    # number of attention heads
        n_enc_blocks=n_blocks,
        n_dec_blocks=n_blocks,
        dropout_p=config.dropout,
    )
예제 #3
0
def get_model(input_size, output_size, train_config):
    """Rebuild the Transformer from the training config and restore weights.

    NOTE(review): relies on a module-level `saved_data` checkpoint dict being
    loaded before this is called — confirm against the caller.
    """
    depth = train_config.n_layers
    net = Transformer(
        input_size,
        train_config.hidden_size,
        output_size,
        n_splits=train_config.n_splits,
        n_enc_blocks=depth,
        n_dec_blocks=depth,
        dropout_p=train_config.dropout,
    )
    net.load_state_dict(saved_data['model'])
    net.eval()
    return net
예제 #4
0
import torch


def data_gen(n_vocab, batch_size, n_batch, device):
    """Yield `n_batch` random batches of shape (batch_size, 10).

    Token ids are drawn from [2, n_vocab); column 0 is forced to 1
    (presumably a start token) and the last two columns to 0 (presumably
    padding). The same tensor is used as both source and target of Batch.
    """
    for _ in range(n_batch):
        seqs = torch.randint(2, n_vocab, [batch_size, 10])
        seqs[:, 0] = 1
        seqs[:, -2:] = 0
        seqs = seqs.to(device)
        yield Batch(seqs, seqs)


if __name__ == '__main__':
    # Train a small copy-task Transformer, then prepare a greedy-decode setup.
    vocab_size = 10

    net = Transformer(vocab_size)
    loss_fn = LabelSmoothing(vocab_size, 0.)
    opt = scheduled_adam_optimizer(net)
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(dev)

    for ep in range(10):
        print("Epoch: {}".format(ep))
        batches = data_gen(vocab_size, 128, 10000, dev)
        run_epoch(batches, net, loss_fn, opt)

    # Seed sequences for decoding: fixed input, target starts with token 1.
    in_seq = torch.LongTensor([[1, 7, 5, 2, 3, 4, 5, 0]]).to(dev)
    out_seq = torch.zeros([1, 20], dtype=torch.int64).to(dev)
    out_seq[:, 0] = 1
    net.eval()
예제 #5
0
def run(model_dir, max_len, source_train_path, target_train_path,
        source_val_path, target_val_path, enc_max_vocab, dec_max_vocab,
        encoder_emb_size, decoder_emb_size, encoder_units, decoder_units,
        batch_size, epochs, learning_rate, decay_step, decay_percent,
        log_interval, save_interval, compare_interval):
    """Train the Transformer translation model end to end.

    Builds the dataset iterators and model, then drives training through the
    (legacy) ignite ``Trainer`` with periodic logging, checkpointing and
    sample-prediction hooks. Checkpoints are restored from / saved to
    ``model_dir``.
    """
    train_iter, val_iter, source_vocab, target_vocab = create_dataset(
        batch_size, enc_max_vocab, dec_max_vocab, source_train_path,
        target_train_path, source_val_path, target_val_path)
    transformer = Transformer(max_length=max_len,
                              enc_vocab=source_vocab,
                              dec_vocab=target_vocab,
                              enc_emb_size=encoder_emb_size,
                              dec_emb_size=decoder_emb_size,
                              enc_units=encoder_units,
                              dec_units=decoder_units)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(transformer.parameters(), lr=learning_rate)
    lr_decay = StepLR(opt, step_size=decay_step, gamma=decay_percent)

    if torch.cuda.is_available():
        transformer.cuda()
        loss_fn.cuda()

    def training_update_function(batch):
        """One optimization step; returns (probs, loss, targets) for hooks."""
        transformer.train()
        # NOTE(review): the scheduler is stepped per iteration, before
        # opt.step() — inherited from the original; confirm intent.
        lr_decay.step()
        opt.zero_grad()

        softmaxed_predictions, predictions = transformer(batch.src, batch.trg)

        # Flatten (batch, time, vocab) -> (batch*time, vocab) for CE loss.
        flattened_predictions = predictions.view(-1, len(target_vocab.itos))
        flattened_target = batch.trg.view(-1)

        loss = loss_fn(flattened_predictions, flattened_target)

        loss.backward()
        opt.step()

        # Fixed: `loss.data[0]` raises IndexError on 0-dim tensors in
        # PyTorch >= 0.4; `.item()` is the supported scalar accessor (and is
        # what the sibling training script in this file already uses).
        return softmaxed_predictions.data, loss.item(), batch.trg.data

    def validation_inference_function(batch):
        """Compute the validation loss on one batch (no gradient step)."""
        transformer.eval()
        softmaxed_predictions, predictions = transformer(batch.src, batch.trg)

        flattened_predictions = predictions.view(-1, len(target_vocab.itos))
        flattened_target = batch.trg.view(-1)

        loss = loss_fn(flattened_predictions, flattened_target)

        # Same fix as above: .item() instead of the removed .data[0].
        return loss.item()

    trainer = Trainer(train_iter, training_update_function, val_iter,
                      validation_inference_function)
    # Restore the latest checkpoint (if any) before training starts.
    trainer.add_event_handler(TrainingEvents.TRAINING_STARTED,
                              restore_checkpoint_hook(transformer, model_dir))
    # Moving-average training-loss logging every `log_interval` iterations;
    # history[1] selects the loss element of the update function's output.
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        log_training_simple_moving_average,
        window_size=10,
        metric_name="CrossEntropy",
        should_log=lambda trainer: trainer.current_iteration % log_interval == 0,
        history_transform=lambda history: history[1])
    # Periodic checkpointing.
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        save_checkpoint_hook(transformer, model_dir),
        should_save=lambda trainer: trainer.current_iteration % save_interval == 0)
    # Periodically print a sample prediction for manual comparison.
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        print_current_prediction_hook(target_vocab),
        should_print=lambda trainer: trainer.current_iteration % compare_interval == 0)
    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED,
                              log_validation_simple_moving_average,
                              window_size=10,
                              metric_name="CrossEntropy")
    # Always save a final checkpoint when training completes.
    trainer.add_event_handler(TrainingEvents.TRAINING_COMPLETED,
                              save_checkpoint_hook(transformer, model_dir),
                              should_save=lambda trainer: True)
    trainer.run(max_epochs=epochs, validate_every_epoch=True)
예제 #6
0
 def test_optimizer(self):
     """Smoke test: the scheduled Adam optimizer can be built for a model."""
     model = Transformer(6)
     optimizer = scheduled_adam_optimizer(model)
예제 #7
0
def run(model_dir, max_len, source_train_path, target_train_path,
        source_val_path, target_val_path, enc_max_vocab, dec_max_vocab,
        encoder_emb_size, decoder_emb_size, encoder_units, decoder_units,
        batch_size, epochs, learning_rate, decay_step, decay_percent,
        val_interval, save_interval, compare_interval):
    """Train the Transformer NMT model using ignite's Engine API.

    Builds train/val iterators and the model, then runs an ignite training
    Engine with running-average loss metrics, a progress bar, validation
    every `val_interval` iterations, and checkpointing of the model,
    optimizer and scheduler into `model_dir` every `save_interval` steps.
    """

    logging.basicConfig(filename="validation.log",
                        filemode="w",
                        level=logging.INFO)

    train_iter, val_iter, source_vocab, target_vocab = create_dataset(
        batch_size, enc_max_vocab, dec_max_vocab, source_train_path,
        target_train_path, source_val_path, target_val_path)
    transformer = Transformer(max_length=max_len,
                              enc_vocab=source_vocab,
                              dec_vocab=target_vocab,
                              enc_emb_size=encoder_emb_size,
                              dec_emb_size=decoder_emb_size,
                              enc_units=encoder_units,
                              dec_units=decoder_units)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(transformer.parameters(), lr=learning_rate)
    lr_decay = StepLR(opt, step_size=decay_step, gamma=decay_percent)

    if torch.cuda.is_available():
        transformer.cuda()
        loss_fn.cuda()

    def training_step(engine, batch):
        """One optimization step; returns the scalar loss for metrics."""
        transformer.train()
        # NOTE(review): scheduler stepped per iteration, before opt.step() —
        # confirm this decay scheme is intended.
        lr_decay.step()
        opt.zero_grad()

        _, predictions = transformer(batch.src, batch.trg)

        # Flatten (batch, time, vocab) -> (batch*time, vocab) for CE loss.
        flattened_predictions = predictions.view(-1, len(target_vocab.itos))
        flattened_target = batch.trg.view(-1)

        loss = loss_fn(flattened_predictions, flattened_target)

        loss.backward()
        opt.step()

        return loss.cpu().item()

    def validation_step(engine, batch):
        """Validation loss plus accumulated predictions/targets.

        NOTE(review): predictions and targets are accumulated across batches
        via engine.state.output, so the returned lists grow over the
        validation run — presumably consumed by validation_result_hook.
        """
        transformer.eval()
        with torch.no_grad():
            softmaxed_predictions, predictions = transformer(
                batch.src, batch.trg)

            flattened_predictions = predictions.view(-1,
                                                     len(target_vocab.itos))
            flattened_target = batch.trg.view(-1)

            loss = loss_fn(flattened_predictions, flattened_target)

            if not engine.state.output:
                # First batch: start fresh accumulation lists.
                predictions = softmaxed_predictions.argmax(
                    -1).cpu().numpy().tolist()
                targets = batch.trg.cpu().numpy().tolist()
            else:
                # Subsequent batches: extend what previous steps produced.
                predictions = engine.state.output[
                    "predictions"] + softmaxed_predictions.argmax(
                        -1).cpu().numpy().tolist()
                targets = engine.state.output["targets"] + batch.trg.cpu(
                ).numpy().tolist()

            return {
                "loss": loss.cpu().item(),
                "predictions": predictions,
                "targets": targets
            }

    trainer = Engine(training_step)
    evaluator = Engine(validation_step)
    checkpoint_handler = ModelCheckpoint(model_dir,
                                         "Transformer",
                                         save_interval=save_interval,
                                         n_saved=10,
                                         require_empty=False)

    # Average per-iteration wall time over each epoch.
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # Attach training metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "train_loss")
    # Attach validation metrics
    RunningAverage(output_transform=lambda x: x["loss"]).attach(
        evaluator, "val_loss")

    pbar = ProgressBar()
    pbar.attach(trainer, ["train_loss"])

    # trainer.add_event_handler(Events.TRAINING_STARTED,
    #                           restore_checkpoint_hook(transformer, model_dir))
    # Run validation every `val_interval` iterations and log the results.
    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              handler=validation_result_hook(
                                  evaluator,
                                  val_iter,
                                  target_vocab,
                                  val_interval,
                                  logger=logging.info))

    # Checkpoint model + optimizer + scheduler together.
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "nmt": {
                                      "transformer": transformer,
                                      "opt": opt,
                                      "lr_decay": lr_decay
                                  }
                              })

    # Run the prediction
    trainer.run(train_iter, max_epochs=epochs)
예제 #8
0
from modules.transformer import Transformer, create_masks
import tensorflow as tf

if __name__ == '__main__':
    # Build a small Transformer and push one random batch through it.
    model = Transformer(num_layers=2,
                        d_model=512,
                        num_heads=8,
                        dff=2048,
                        input_size=50,
                        output_size=512,
                        pe_input=10000,
                        pe_target=6000)

    # Random integer tokens for the encoder, random float frames for the decoder.
    token_ids = tf.random.uniform((3, 62), maxval=20, dtype=tf.int32)
    targets = tf.random.uniform((3, 90, 512))
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        token_ids, targets)

    prenet_output, stops, post_output, attention_weights = model(
        token_ids,
        targets,
        training=True,
        enc_padding_mask=enc_padding_mask,
        look_ahead_mask=combined_mask,
        dec_padding_mask=dec_padding_mask)

    # Expected: (batch_size, tar_seq_len, target_vocab_size)
    print(post_output.shape)
0
class VAESEQ:
    def __init__(self, params):
        """Wire up the whole model: placeholders, the two VAEs plus the
        latent transformer, and the merged / transformer-only objectives."""
        self.params = params
        # Sub-models are populated by _init_models().
        self.encoder_model = None
        self.decoder_model = None
        # NOTE(review): _init_models assigns self.transformer, so this
        # attribute stays None — verify it is still needed.
        self.transformer_model = None
        self._build_inputs()
        self._init_models(params)
        self._loss_optimizer(params['loss_type'])
        self.build_trans_loss(params['loss_type'])
        # self.print_parameters()


    def print_parameters(self):
        """Print the name and shape of every TF global variable."""
        print("print_parameters:")
        for var in tf.global_variables():
            print('%s: %s' % (var.name, var.get_shape()))
    
    def _build_inputs(self):
        """Create all input placeholders and the global-step variable.

        Decoder-side sequences (`*_dec_inp` / `*_dec_out`) are one token
        longer than the encoder-side ones — presumably shifted BOS/EOS
        copies; confirm against the data pipeline.
        """
        with tf.variable_scope('placeholder'):
            # placeholders x
            self.x_enc_inp = tf.placeholder(tf.int32, [None, args.enc_max_len], name="x_enc_inp")
            self.x_dec_inp = tf.placeholder(tf.int32, [None, args.enc_max_len+1], name="x_dec_inp")
            self.x_dec_out = tf.placeholder(tf.int32, [None, args.enc_max_len+1], name="x_dec_out")
            # placeholders y
            self.y_enc_inp = tf.placeholder(tf.int32, [None, args.dec_max_len], name="y_enc_inp")
            self.y_dec_inp = tf.placeholder(tf.int32, [None, args.dec_max_len+1], name="y_dec_inp")
            self.y_dec_out = tf.placeholder(tf.int32, [None, args.dec_max_len+1], name="y_dec_out")
            # attention data
            # self.attention_data = tf.placeholder(tf.int32, [None, args.dec_max_len+1, args.enc_max_len], name="atten_data")
            # Pointer mode needs the OOV-mapped source and the OOV count.
            if args.isPointer:
                self.x_enc_inp_oovs = tf.placeholder(tf.int32, [None, args.enc_max_len], name="x_enc_inp_oovs")
                self.max_oovs = tf.placeholder(tf.int32, None, name="max_oovs")
            # train step
            self.global_step = tf.Variable(0, trainable=False)

    def _init_models(self, params):
        """Build the encoder VAE, decoder VAE and the latent transformer.

        NOTE(review): `params` is mutated in place (`max_len`/`max_dec_len`)
        before each BaseVAE is constructed, so construction order matters.
        """
        # self.global_step = None
        with tf.variable_scope('encodervae'):
            # Encoder-side VAE consumes the x placeholders.
            encodervae_inputs = (self.x_enc_inp, self.x_dec_inp, self.x_dec_out, self.global_step)
            params['max_len'] = args.enc_max_len
            params['max_dec_len'] = args.enc_max_len + 1
            self.encoder_model = BaseVAE(params, encodervae_inputs, "encoder")
        with tf.variable_scope('decodervae'):
            # Decoder-side VAE consumes the y placeholders.
            decodervae_inputs = (self.y_enc_inp, self.y_dec_inp, self.y_dec_out, self.global_step)
            params['max_len'] = args.dec_max_len
            params['max_dec_len'] = args.dec_max_len + 1

            if args.isPointer:
                # Pointer mode: also pass OOV-mapped input and the OOV mask.
                mask_oovs = self.encoder_model.dec_seq_len_mask
                self.decoder_model = BaseVAE(params, decodervae_inputs, "decoder", 
                            self.encoder_model.encoder_outputs, self.x_enc_inp_oovs, self.max_oovs, mask_oovs)
            elif args.isContext:
                # Context mode: decoder VAE receives the encoder outputs.
                self.decoder_model = BaseVAE(params, decodervae_inputs, "decoder", self.encoder_model.encoder_outputs)
            else:
                self.decoder_model = BaseVAE(params, decodervae_inputs, "decoder")

        with tf.variable_scope('transformer'):
            # Latent-to-latent mapping between the two VAEs.
            self.transformer = Transformer(self.encoder_model, self.decoder_model, params['graph_type'], self.global_step)
        with tf.variable_scope('decodervae/decoding', reuse=True):
            # Re-run the decoder on the transformer's predicted latent code.
            # ("predition" [sic] is the attribute name Transformer exposes.)
            self.training_logits = self.decoder_model._decoder_training(self.transformer.predition, reuse=True)
            self.predicted_ids_op, self.attens = self.decoder_model._decoder_inference(self.transformer.predition)
    
    def _gradient_clipping(self, loss_op):
        """Return (clipped_gradients, trainable_params) for `loss_op`,
        with gradients clipped by global norm `args.clip_norm`."""
        trainable = tf.trainable_variables()
        grads = tf.gradients(loss_op, trainable)
        clipped, _ = tf.clip_by_global_norm(grads, args.clip_norm)
        return clipped, trainable

    def _loss_optimizer(self, model_type):
        """Build the merged training loss, its Adam train ops and summaries.

        `model_type` selects the transformer term (judging by the attribute
        names): 0 = MSE, 1 = Wasserstein, 2 = seq2seq cross-entropy, 3 = KL.
        """
        with tf.variable_scope('merge_loss'):
            # Mask padding positions of the target when averaging the
            # sequence loss.
            mask_fn = lambda l : tf.sequence_mask(l, args.dec_max_len + 1, dtype=tf.float32)
            dec_seq_len = tf.count_nonzero(self.y_dec_out, 1, dtype=tf.int32)
            mask = mask_fn(dec_seq_len) # b x t = 64 x ?
            self.merged_loss_seq =  tf.reduce_sum(tf.contrib.seq2seq.sequence_loss(
                logits = self.training_logits,
                targets = self.y_dec_out,
                weights = mask,
                average_across_timesteps = False,
                average_across_batch = True))
            if model_type == 0:
                # Transformer term is scaled by 1000 relative to the VAE terms.
                self.merged_loss = self.transformer.merged_mse*1000 + self.encoder_model.loss + self.decoder_model.loss
                self.merged_loss_transformer = self.transformer.merged_mse
            elif model_type == 1:
                self.merged_loss = self.transformer.wasserstein_loss*1000 + self.encoder_model.loss + self.decoder_model.loss
                self.merged_loss_transformer = self.transformer.wasserstein_loss
            elif model_type == 2:
                self.merged_loss = self.merged_loss_seq + self.encoder_model.loss + self.decoder_model.loss
                self.merged_loss_transformer = self.merged_loss_seq
            elif model_type == 3:
                self.merged_loss = self.transformer.kl_loss + self.encoder_model.loss + self.decoder_model.loss
                self.merged_loss_transformer = self.transformer.kl_loss


        with tf.variable_scope('optimizer'):
            # self.global_step = tf.Variable(0, trainable=False)
            # NOTE(review): both train ops increment the same global_step.
            clipped_gradients, params = self._gradient_clipping(self.merged_loss)
            self.merged_train_op = tf.train.AdamOptimizer().apply_gradients(
                zip(clipped_gradients, params), global_step=self.global_step)

            clipped_gradients, params = self._gradient_clipping(self.merged_loss_transformer)
            self.merged_train_op_transformer = tf.train.AdamOptimizer().apply_gradients(
                zip(clipped_gradients, params), global_step=self.global_step)

        with tf.variable_scope('summary'):
            tf.summary.scalar("trans_loss", self.merged_loss_seq)
            tf.summary.scalar("merged_loss", self.merged_loss)
            tf.summary.histogram("z_predition", self.transformer.predition)
            self.merged_summary_op = tf.summary.merge_all()

    def build_trans_loss(self, loss_type):
        """Attach the transformer's stand-alone training loss.

        The loss tensor is selected by ``self.params['loss_type']`` — the
        `loss_type` argument is kept for interface compatibility but, as in
        the original code, the value in ``self.params`` is what is used:
        0 = MSE, 1 = Wasserstein, 2 = seq2seq cross-entropy, 3 = KL.

        Raises:
            ValueError: for an unknown loss type (previously this fell
            through and crashed later with an unbound-variable error).
        """
        selected = self.params['loss_type']
        if selected == 0:
            train_loss = self.transformer.merged_mse
        elif selected == 1:
            train_loss = self.transformer.wasserstein_loss
        elif selected == 2:
            train_loss = self.merged_loss_seq
        elif selected == 3:
            train_loss = self.transformer.kl_loss
        else:
            raise ValueError("unknown loss_type: %r" % (selected,))
        with tf.variable_scope('transformer'):
            self.transformer.build_loss(train_loss)

    def show_parameters(self, sess):
        """Append every trainable variable's name and value to the param log."""
        with open("logs/param_log.txt", "a") as f:
            f.write("==============================\n")
            names = [v.name for v in tf.trainable_variables()]
            values = sess.run(names)
            for name, value in zip(names, values):
                f.write(name+"\n"+str(value)+"\n")

    def train_encoder(self, sess, x_enc_inp, x_dec_inp, x_dec_out, y_enc_inp, y_dec_inp, y_dec_out):
        feed_dict = {
            self.x_enc_inp: x_enc_inp,
            self.x_dec_inp: x_dec_inp,
            self.x_dec_out: x_dec_out,
            self.y_enc_inp: y_enc_inp,
            self.y_dec_inp: y_dec_inp,
            self.y_dec_out: y_dec_out
        }
        log = self.encoder_model.train_session(sess, feed_dict)
        return log
        
    def train_decoder(self, sess, x_enc_inp, x_dec_inp, x_dec_out, y_enc_inp, y_dec_inp, y_dec_out):
        feed_dict = {
            self.x_enc_inp: x_enc_inp,
            self.x_dec_inp: x_dec_inp,
            self.x_dec_out: x_dec_out,
            self.y_enc_inp: y_enc_inp,
            self.y_dec_inp: y_dec_inp,
            self.y_dec_out: y_dec_out
        }
        log = self.decoder_model.train_session(sess, feed_dict)
        return log
    
    def train_transformer(self, sess, x_enc_inp, x_dec_inp, x_dec_out, y_enc_inp, y_dec_inp, y_dec_out):
        feed_dict = {
            self.x_enc_inp: x_enc_inp,
            self.x_dec_inp: x_dec_inp,
            self.x_dec_out: x_dec_out,
            self.y_enc_inp: y_enc_inp,
            self.y_dec_inp: y_dec_inp,
            self.y_dec_out: y_dec_out
        }
        if self.params['loss_type'] == 0:
            train_loss = self.transformer.merged_mse
        elif self.params['loss_type'] == 1:
            train_loss = self.transformer.wasserstein_loss
        elif self.params['loss_type'] == 2:
            train_loss = self.merged_loss_seq
        elif self.params['loss_type'] == 3:
            train_loss = self.transformer.kl_loss
        log = self.transformer.train_session(sess, feed_dict, train_loss)
        return log
        
    def merged_train(self, sess, x_enc_inp, x_dec_inp, x_dec_out, y_enc_inp, y_dec_inp, y_dec_out):
        feed_dict = {
            self.x_enc_inp: x_enc_inp,
            self.x_dec_inp: x_dec_inp,
            self.x_dec_out: x_dec_out,
            self.y_enc_inp: y_enc_inp,
            self.y_dec_inp: y_dec_inp,
            self.y_dec_out: y_dec_out
        }
        _, summaries, loss, trans_loss, encoder_loss, decoder_loss, step = sess.run(
            [self.merged_train_op, self.merged_summary_op, self.merged_loss, self.merged_loss_seq, self.encoder_model.loss, self.decoder_model.loss, self.global_step],
                feed_dict)
        return {'summaries': summaries, 'merged_loss': loss, 'trans_loss': trans_loss, 
            'encoder_loss': encoder_loss, 'decoder_loss': decoder_loss, 'step': step}

    def merged_seq_train(self, sess, x_enc_inp, x_dec_inp, x_dec_out, y_enc_inp, y_dec_inp, y_dec_out, x_enc_inp_oovs, max_oovs):
        """Joint step like merged_train, additionally feeding the pointer
        OOV inputs when args.isPointer is enabled."""
        feed = {
            self.x_enc_inp: x_enc_inp, self.x_dec_inp: x_dec_inp, self.x_dec_out: x_dec_out,
            self.y_enc_inp: y_enc_inp, self.y_dec_inp: y_dec_inp, self.y_dec_out: y_dec_out,
        }
        if args.isPointer:
            feed[self.x_enc_inp_oovs] = x_enc_inp_oovs
            feed[self.max_oovs] = max_oovs

        fetches = [self.merged_train_op, self.merged_summary_op, self.merged_loss,
                   self.merged_loss_seq, self.encoder_model.loss,
                   self.decoder_model.loss, self.global_step]
        _, summaries, merged, seq, enc, dec, step = sess.run(fetches, feed)
        return {'summaries': summaries, 'merged_loss': merged, 'trans_loss': seq,
                'encoder_loss': enc, 'decoder_loss': dec, 'step': step}

    def merged_transformer_train(self, sess, x_enc_inp, x_dec_inp, x_dec_out, y_enc_inp, y_dec_inp, y_dec_out):
        feed_dict = {
            self.x_enc_inp: x_enc_inp,
            self.x_dec_inp: x_dec_inp,
            self.x_dec_out: x_dec_out,
            self.y_enc_inp: y_enc_inp,
            self.y_dec_inp: y_dec_inp,
            self.y_dec_out: y_dec_out
        }
        _, summaries, loss, step = sess.run(
            [self.merged_train_op_transformer, self.merged_summary_op, self.merged_loss_transformer, self.global_step],
                feed_dict)
        return {'summaries': summaries, 'trans_loss': loss, 'step': step}

    def show_encoder(self, sess, x, y, LOGGER):
        # self.encoder_model.generate(sess)
        infos = self.encoder_model.reconstruct(sess, x, y)
        # self.encoder_model.customized_reconstruct(sess, 'i love this film and i think it is one of the best films')
        # self.encoder_model.customized_reconstruct(sess, 'this movie is a waste of time and there is no point to watch it') 
        LOGGER.write(infos)
        print(infos.strip())

    def show_decoder(self, sess, x, y, LOGGER, x_raw):
        # self.decoder_model.generate(sess)
        feeddict = {}
        feeddict[self.x_enc_inp] = np.atleast_2d(x_raw)
        feeddict[self.y_enc_inp] = np.atleast_2d(x)
        infos = self.decoder_model.reconstruct(sess, x, y, feeddict)
        # self.decoder_model.customized_reconstruct(sess, 'i love this film and i think it is one of the best films')
        # self.decoder_model.customized_reconstruct(sess, 'this movie is a waste of time and there is no point to watch it')
        LOGGER.write(infos)
        print(infos.strip())

    def show_sample(self, sess, x, y, LOGGER):
        infos = self.transformer.sample_test(sess, x, y, self.encoder_model, self.decoder_model, self.predicted_ids_op)
        LOGGER.write(infos)
        print(infos.strip())

    def evaluation_encoder_vae(self, sess, enc_inp, outputfile):
        # Thin wrapper: delegate evaluation to the encoder VAE.
        self.encoder_model.evaluation(sess, enc_inp, outputfile)

    def evaluation_decoder_vae(self, sess, enc_inp, outputfile):
        # Thin wrapper: delegate evaluation to the decoder VAE.
        self.decoder_model.evaluation(sess, enc_inp, outputfile)

    def evaluation(self, sess, enc_inp, outputfile, enc_inp_oovs=None, max_oovs_len=None, data_oovs=None):
        """Decode `enc_inp` through the full pipeline and append one
        detokenized line per example to `outputfile`.

        In pointer mode (args.isPointer) the OOV-mapped input is also fed;
        ids past the vocabulary are looked up in `data_oovs[i]`, falling
        back to "UNK" when even that fails. Output is truncated at the
        first " </S> " marker.
        """
        idx2word = self.params['idx2word']

        # Dynamic batch size is read back from the encoder graph.
        batch_size = sess.run(self.encoder_model._batch_size, {self.x_enc_inp: enc_inp})
        feed_dict = {
            self.x_enc_inp: enc_inp,
            self.decoder_model.enc_seq_len: [args.dec_max_len],
            self.decoder_model._batch_size: batch_size,
        }
        if args.isPointer:
            feed_dict[self.x_enc_inp_oovs] = enc_inp_oovs
            feed_dict[self.max_oovs] = max_oovs_len

        predicted_ids_lt = sess.run(self.predicted_ids_op, feed_dict)

        # Open the output once instead of re-opening it per prediction
        # (same append semantics, far fewer syscalls).
        with open(outputfile, "a") as f:
            for i, predicted_ids in enumerate(predicted_ids_lt):
                try:
                    # Fast path: every id maps either into the vocab or into
                    # this example's OOV list.
                    predicted_ids_tokens = [
                        idx2word[idx] if idx < len(idx2word)
                        else data_oovs[i][idx - len(idx2word)]
                        for idx in predicted_ids]
                # Narrowed from a bare `except:` — only the lookup failures
                # the fallback actually handles (missing/short OOV list).
                except (IndexError, KeyError, TypeError):
                    predicted_ids_tokens = []
                    for idx in predicted_ids:
                        if idx < len(idx2word):
                            predicted_ids_tokens.append(idx2word[idx])
                        elif (idx - len(idx2word)) < len(data_oovs[i]):
                            predicted_ids_tokens.append(data_oovs[i][idx - len(idx2word)])
                        else:
                            predicted_ids_tokens.append("UNK")
                            print(idx, i, len(idx2word), len(data_oovs[i]), end="\t")
                            print(data_oovs)
                result = ' '.join(predicted_ids_tokens)
                # Cut everything after the end-of-sentence marker.
                end_index = result.find(" </S> ")
                if end_index != -1:
                    result = result[:end_index]
                f.write('%s\n' % result)

    def export_attentions(self, sess, enc_inp, dec_inp):
        """Run inference and return one [attention, source_tokens,
        predicted_words] triple per example."""
        idx2word = self.params['idx2word']
        idx2token = self.params['idx2token']

        batch_size = sess.run(self.encoder_model._batch_size, {self.x_enc_inp: enc_inp})
        feed = {
            self.x_enc_inp: enc_inp,
            self.decoder_model.enc_seq_len: [args.dec_max_len],
            self.decoder_model._batch_size: batch_size,
        }

        attens_values, predict_y = sess.run([self.attens, self.predicted_ids_op], feed)

        return [[att, [idx2token[t] for t in src], [idx2word[w] for w in pred]]
                for att, src, pred in zip(attens_values, enc_inp, predict_y)]

    def export_vectors(self, sess, enc_inp, dec_inp):
        code_mean, code_logvar = sess.run(
            [self.transformer.predition_mean, self.transformer.predition_logvar], 
            {self.x_enc_inp:enc_inp})
        desc_mean, desc_logvar = sess.run(
            [self.decoder_model.z_mean, self.decoder_model.z_logvar], 
            {self.y_enc_inp:dec_inp})
        return code_mean, code_logvar, desc_mean, desc_logvar


    # def evaluation_pointer(self, sess, enc_inp, outputfile, raw_inp):
    #     idx2word = self.params['idx2word']
        
    #     masks, aids, pids = sess.run([self.mask, self.attens_ids, self.predicted_ids], 
    #         {self.x_enc_inp:enc_inp, self.decoder_model.enc_seq_len: [args.dec_max_len], self.decoder_model._batch_size: args.batch_size})

    #     masks, aids, pids = masks[:,:,0], aids[:,:,0], pids[:,:,0]
    #     for i, (mask, aid, pid) in enumerate(zip(masks, aids, pids)):
    #         # print(i, mask, aid, pid)
    #         # print(raw_inp[i])
    #         with open(outputfile, "a") as f:
    #             result = ''
    #             for m, a, p in zip(mask, aid, pid):
    #                 # print(m, a, p)
    #                 if m == 1: result += idx2word[p] + " "
    #                 elif m ==0:
    #                     if a >= len(raw_inp[i]):
    #                         result += " UNK "
    #                     else:
    #                         result += raw_inp[i][a] + " "
    #                 else: print("ERRRRRRRRRRORR!!!")
    #             # result = ' '.join([idx2word[idx] for idx in predicted_ids])
    #             result = result.strip()
    #             # print(result)
    #             end_index = result.find(" </S> ")
    #             if end_index != -1:
    #                 result = result[:end_index]
    #             f.write('%s\n' % result)
    #             # f.write('%s\n' % ' '.join([idx2word[idx] for idx in predicted_ids]))