def decode():
    # Load model config
    config = load_config(FLAGS)

    # Load source data to decode
    test_set = TextIterator(source=config['decode_input'],
                            batch_size=config['decode_batch_size'],
                            source_dict=config['source_vocabulary'],
                            maxlen=None,
                            n_words_source=config['num_encoder_symbols'])

    # Load inverse dictionary used in decoding
    target_inverse_dict = data_utils.load_inverse_dict(
        config['target_vocabulary'])

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Reload existing checkpoint
        model = load_model(sess, config)
        fout = []
        try:
            print('Decoding {}...'.format(FLAGS.decode_input))
            if FLAGS.write_n_best:
                # One output file per beam hypothesis:
                # <decode_output>_0 ... <decode_output>_<beam_width - 1>
                fout = [data_utils.fopen('%s_%d' % (FLAGS.decode_output, k), 'w')
                        for k in range(FLAGS.beam_width)]
            else:
                fout = [data_utils.fopen(FLAGS.decode_output, 'w')]

            for idx, source_seq in enumerate(test_set):
                source, source_len = prepare_batch(source_seq)
                # predicted_ids shape:
                #   GreedyDecoder:     [batch_size, max_time_step, 1]
                #   BeamSearchDecoder: [batch_size, max_time_step, beam_width]
                predicted_ids = model.predict(sess,
                                              encoder_inputs=source,
                                              encoder_inputs_length=source_len)

                # Write decoding results: hypothesis k of every sequence goes
                # to file fout[k]. With write_n_best off there is a single
                # output file, so the loop breaks after one pass.
                for k, f in reversed(list(enumerate(fout))):
                    for seq in predicted_ids:
                        f.write(data_utils.seq2words(seq[:, k],
                                                     target_inverse_dict) + '\n')
                    if not FLAGS.write_n_best:
                        break
                print('  {} lines decoded'.format((idx + 1) *
                                                  FLAGS.decode_batch_size))

            print('Decoding terminated')
        except IOError as e:
            print('Decoding stopped early: {}'.format(e))
        finally:
            for f in fout:
                f.close()
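
# A minimal entry-point sketch (an assumption, not shown in this excerpt):
# TF1-style scripts typically dispatch to decode() via tf.app.run(), which
# parses FLAGS and then calls main().
#
#   def main(_argv):
#       decode()
#
#   if __name__ == '__main__':
#       tf.app.run()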


class TextIterator:
    """Parallel-text minibatch iterator (the class wrapper is assumed; only
    __init__ appears in this excerpt)."""

    def __init__(self,
                 source,
                 target,
                 source_dict,
                 target_dict,
                 batch_size=128,
                 maxlen=100,
                 n_words_source=-1,
                 n_words_target=-1,
                 skip_empty=False,
                 shuffle_each_epoch=False,
                 sort_by_length=True,
                 maxibatch_size=20):
        if shuffle_each_epoch:
            self.source_orig = source
            self.target_orig = target
            self.source, self.target = shuffle.main(
                [self.source_orig, self.target_orig], temporary=True)
        else:
            self.source = data_utils.fopen(source, 'r')
            self.target = data_utils.fopen(target, 'r')

        self.source_dict = load_dict(source_dict)
        self.target_dict = load_dict(target_dict)

        self.batch_size = batch_size
        self.maxlen = maxlen
        self.skip_empty = skip_empty

        self.n_words_source = n_words_source
        self.n_words_target = n_words_target

        # Truncate each vocabulary to its n most frequent entries. Iterate
        # over a copy of the items so keys can be deleted while looping
        # (mutating a dict during iteration raises a RuntimeError in Python 3).
        if self.n_words_source > 0:
            for key, idx in list(self.source_dict.items()):
                if idx >= self.n_words_source:
                    del self.source_dict[key]

        if self.n_words_target > 0:
            for key, idx in list(self.target_dict.items()):
                if idx >= self.n_words_target:
                    del self.target_dict[key]

        self.shuffle = shuffle_each_epoch
        self.sort_by_length = sort_by_length

        # Read k = batch_size * maxibatch_size sentences into a buffer at a
        # time, so minibatches can be drawn from a length-sorted maxibatch.
        self.source_buffer = []
        self.target_buffer = []
        self.k = batch_size * maxibatch_size

        self.end_of_data = False
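
    # Usage sketch (file names are placeholders, not from the original code):
    #
    #   train_set = TextIterator('train.src', 'train.trg',
    #                            'vocab.src.json', 'vocab.trg.json',
    #                            batch_size=128, maxlen=50,
    #                            shuffle_each_epoch=True)
    #   for source_batch, target_batch in train_set:
    #       ...  # presumably yields (source, target) word-index batches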


class SeriesIterator:
    """Time-series minibatch iterator (only __init__ is shown in this
    excerpt; the class name here is a placeholder for the original)."""

    def __init__(self,
                 source,
                 time_step=10,
                 batch_size=20,
                 set_type=0,
                 seed=0,
                 test_set_partition=0.05,
                 balance=False,
                 noise=0):
        # Print arrays in full; threshold must be numeric, and the old
        # threshold='nan' value is rejected by NumPy.
        np.set_printoptions(threshold=np.inf)
        self.source = data_utils.fopen(source, 'r')
        self.time_step = time_step
        self.batch_size = batch_size
        self.iterator = 0
        self.end_of_data = False

        random.seed(seed)
        # Load the data up front; partitioning, balancing, and noise
        # injection are controlled by the arguments below.
        self.load(source, time_step, set_type, test_set_partition, balance,
                  noise)
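
    # Usage sketch (names are placeholders; set_type presumably selects the
    # train/test split and test_set_partition its relative size):
    #
    #   train_it = SeriesIterator('series.txt', time_step=10, batch_size=20,
    #                             set_type=0, seed=0,
    #                             test_set_partition=0.05)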