Example #1
    def __init__(self,
                 session,
                 corpus_dir,
                 knbase_dir,
                 result_dir,
                 hparams_dir=None):
        self.session = session

        hparams = HParams(hparams_dir).hparams if hparams_dir else None

        # Prepare data and hyper parameters
        print("# Prepare dataset placeholder and hyper parameters ...")
        self.tokenized_data = TokenizedData(corpus_dir=corpus_dir,
                                            hparams=hparams,
                                            knbase_dir=knbase_dir,
                                            training=False)

        self.hparams = self.tokenized_data.hparams
        self.src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
            self.src_placeholder)
        self.infer_batch = self.tokenized_data.get_inference_batch(src_dataset)

        # Create model
        print("# Creating inference model ...")
        self.model = ModelCreator(training=False,
                                  tokenized_data=self.tokenized_data,
                                  batch_input=self.infer_batch)
        latest_ckpt = tf.train.latest_checkpoint(result_dir)
        print("# Restoring model weights ...")
        self.model.saver.restore(session, latest_ckpt)
        self.session.run(tf.tables_initializer())
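
A minimal usage sketch for this inference-side constructor follows. The wrapper class name BotPredictor, the directory paths, and the assumption that the batch object returned by get_inference_batch() exposes an initializer op are placeholders not confirmed by the example; the actual output tensors to fetch come from ModelCreator, which is not shown here.

import tensorflow as tf

with tf.Session() as sess:
    # Hypothetical class name and directory names, for illustration only.
    predictor = BotPredictor(session=sess,
                             corpus_dir="Data/Corpus",
                             knbase_dir="Data/KnowledgeBase",
                             result_dir="Data/Result")

    # Feed raw question strings through the placeholder created above; the
    # initializer attribute on infer_batch is an assumption.
    sess.run(predictor.infer_batch.initializer,
             feed_dict={predictor.src_placeholder: ["How are you?"]})
    # Fetching the decoded reply would run the inference outputs defined by
    # ModelCreator, which are outside the scope of this example.
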
    def __init__(self,
                 corpus_dir,
                 hparams=None,
                 knbase_dir=None,
                 training=True,
                 augment_factor=3,
                 buffer_size=8192):
        """
        Args:
            corpus_dir: Name of the folder storing corpus files for training.
            hparams: The object containing the loaded hyper parameters. If None, it will be 
                    initialized here.
            knbase_dir: Name of the folder storing data files for the knowledge base. Used for 
                    inference only.
            training: Whether to use this object for training.
            augment_factor: Times the training data appears. If 1 or less, no augmentation.
            buffer_size: The buffer size used for mapping process during data processing.
        """
        if hparams is None:
            self.hparams = HParams(corpus_dir).hparams
        else:
            self.hparams = hparams

        self.src_max_len = self.hparams.src_max_len
        self.tgt_max_len = self.hparams.tgt_max_len

        self.training = training
        self.text_set = None
        self.id_set = None

        vocab_file = os.path.join(corpus_dir, VOCAB_FILE)
        self.vocab_size, _ = check_vocab(vocab_file)
        self.vocab_table = lookup_ops.index_table_from_file(
            vocab_file, default_value=self.hparams.unk_id)
        # print("vocab_size = {}".format(self.vocab_size))

        if training:
            self.case_table = prepare_case_table()
            self.reverse_vocab_table = None
            self._load_corpus(corpus_dir, augment_factor)
            self._convert_to_tokens(buffer_size)

            self.upper_words = {}
            self.stories = {}
            self.jokes = []
        else:
            self.case_table = None
            self.reverse_vocab_table = \
                lookup_ops.index_to_string_table_from_file(vocab_file,
                                                           default_value=self.hparams.unk_token)
            assert knbase_dir is not None
            knbs = KnowledgeBase()
            knbs.load_knbase(knbase_dir)
            self.upper_words = knbs.upper_words
            self.stories = knbs.stories
            self.jokes = knbs.jokes
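
For reference, here is a short sketch of constructing this class directly in inference mode, mirroring the else branch above; the directory names are placeholders, and the placeholder/dataset wiring simply repeats what the caller in Example #1 does.

import tensorflow as tf

# Inference mode: no corpus loading, but the reverse vocab table and the
# knowledge base are prepared (else branch above).
data = TokenizedData(corpus_dir="Data/Corpus",
                     knbase_dir="Data/KnowledgeBase",
                     training=False)

src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
src_dataset = tf.contrib.data.Dataset.from_tensor_slices(src_placeholder)
infer_batch = data.get_inference_batch(src_dataset)  # as in Example #1
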
Example #3
    def __init__(self, corpus_dir, hparams_dir=None):
        hparams = HParams(hparams_dir).hparams if hparams_dir else None
        self.graph = tf.Graph()
        with self.graph.as_default():
            tokenized_data = TokenizedData(corpus_dir=corpus_dir,
                                           hparams=hparams)

            self.hparams = tokenized_data.hparams
            self.train_batch = tokenized_data.get_training_batch()
            self.model = ModelCreator(training=True,
                                      tokenized_data=tokenized_data,
                                      batch_input=self.train_batch)
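
A hedged sketch of driving this training setup follows. The class name BotTrainer, the directory path, and the assumption that the batch returned by get_training_batch() exposes an initializer op are not confirmed by the example; the actual training step (loss and update ops) lives in ModelCreator and is not shown.

import tensorflow as tf

trainer = BotTrainer(corpus_dir="Data/Corpus")  # hypothetical class name

with trainer.graph.as_default():
    with tf.Session(graph=trainer.graph) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        # train_batch is assumed to expose an initializer op, as iterator-based
        # inputs of this TF 1.x style typically do.
        sess.run(trainer.train_batch.initializer)
        # The per-step fetches would come from trainer.model (ModelCreator)
        # and are not part of this example.
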
    def __init__(self,
                 corpus_dir,
                 hparams=None,
                 training=True,
                 buffer_size=8192):
        """
        Args:
            corpus_dir: Name of the folder storing corpus files for training.
            hparams: The object containing the loaded hyper parameters. If None, it will be 
                    initialized here.
            training: Whether to use this object for training.
            buffer_size: The buffer size used for mapping process during data processing.
        """
        if hparams is None:
            self.hparams = HParams(corpus_dir).hparams
        else:
            self.hparams = hparams

        self.src_max_len = self.hparams.src_max_len
        self.tgt_max_len = self.hparams.tgt_max_len

        self.training = training
        self.text_set = None
        self.id_set = None

        vocab_file = os.path.join(corpus_dir, VOCAB_FILE)
        self.vocab_size, _ = check_vocab(vocab_file)
        self.vocab_table = lookup_ops.index_table_from_file(
            vocab_file, default_value=self.hparams.unk_id)
        # print("vocab_size = {}".format(self.vocab_size))

        if training:
            self.case_table = prepare_case_table()
            self.reverse_vocab_table = None
            self._load_corpus(corpus_dir)
            self._convert_to_tokens(buffer_size)
        else:
            self.case_table = None
            self.reverse_vocab_table = \
                lookup_ops.index_to_string_table_from_file(vocab_file,
                                                           default_value=self.hparams.unk_token)
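
And the training-mode counterpart for this simpler variant, matching the if branch above; the directory name is a placeholder, and get_training_batch() is taken from the caller in Example #3. Passing hparams=HParams(hparams_dir).hparams instead of None would skip loading the hyperparameters from corpus_dir, as that caller does.

# Training mode: loads and tokenizes the corpus; the reverse vocab table
# stays unset (if branch above).
data = TokenizedData(corpus_dir="Data/Corpus", training=True)
train_batch = data.get_training_batch()  # as used in Example #3
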