Example 1
    def __init__(self,
                 num_unique_documents,
                 vocab_size,
                 num_topics,
                 freqs=None,
                 load_embeds=False,
                 pretrained_embeddings=False,
                 save_graph_def=True,
                 embedding_size=128,
                 num_sampled=40,
                 learning_rate=1E-3,
                 lmbda=150.,
                 alpha=None,
                 power=.75,
                 batch_size=500,
                 logdir="logdir",
                 restore=False,
                 W_in=None,
                 factors_in=None,
                 additional_features_info=[],
                 additional_features_names=[]):
        """Summary
        
        Args:
            num_unique_documents (int): Number of unique documents in your dataset
            vocab_size (int): Number of unique words/tokens in your dataset
            num_topics (int): The set number of topics to cluster your data into
            freqs (list, optional): Python list of length vocab_size with frequencies of each token
            load_embeds (bool, optional): If True, embeddings will be loaded from the pretrained_embeddings argument
            pretrained_embeddings (np array, optional): Pretrained embeddings - shape should be (vocab_size, embedding_size)
            save_graph_def (bool, optional): If true, we will save the graph to logdir
            embedding_size (int, optional): Dimension of the embeddings. This will be shared between docs, words, and topics.
            num_sampled (int, optional): Negative sampling number for NCE Loss. 
            learning_rate (float, optional): Learning rate for optimizer
            lmbda (float, optional): Strength of the Dirichlet prior
            alpha (None, optional): Alpha of the Dirichlet process (defaults to 1/n_topics)
            power (float, optional): unigram sampler distortion
            batch_size (int, optional): Batch size coming into model
            logdir (str, optional): Location where models will be saved - note that the current datetime is appended on each run
            restore (bool, optional): When True, we will restore the model from the logdir parameter's location
            W_in (None, optional): Pretrained Doc Embedding weights (shape should be [num_unique_documents, embedding_size])
            factors_in (None, optional): Pretrained Topic Embedding (shape should be [num_topics, embedding_size])
            additional_features_info (list, optional): A list giving the number of unique elements
                                                       for each additional feature passed in
            additional_features_names (list, optional): A list of strings, the same length as
                                                        additional_features_info, naming the corresponding additional features
        
        """
        self.config = tf.ConfigProto()
        self.config.gpu_options.allow_growth = True
        self.sesh = tf.Session(config=self.config)
        self.moving_avgs = tf.train.ExponentialMovingAverage(0.9)

        self.num_unique_documents = num_unique_documents
        self.additional_features_info = additional_features_info
        self.num_additional_features = len(self.additional_features_info)
        self.additional_features_names = additional_features_names

        self.vocab_size = vocab_size
        self.num_topics = num_topics
        self.freqs = freqs
        self.load_embeds = load_embeds
        self.pretrained_embeddings = pretrained_embeddings
        self.save_graph_def = save_graph_def
        self.logdir = logdir
        self.embedding_size = embedding_size
        self.num_sampled = num_sampled
        self.learning_rate = learning_rate
        self.lmbda = lmbda
        self.alpha = alpha
        self.power = power
        self.batch_size = batch_size

        self.W_in = W_in
        self.factors_in = factors_in

        # Set to True once compute_normed_embeds has run so it doesn't run more than once
        self.compute_normed = False

        if not restore:
            # Get formatted datetime from right now
            self.date = datetime.now().strftime(r"%y%m%d_%H%M")
            # Rename logdir according to current date
            self.logdir = "{}_{}".format(self.logdir, self.date)

            self.w_embed = W.Word_Embedding(
                self.embedding_size,
                self.vocab_size,
                self.num_sampled,
                load_embeds=self.load_embeds,
                pretrained_embeddings=self.pretrained_embeddings,
                freqs=self.freqs,
                power=self.power)

            # Doc and topic mixture
            self.mixture_doc = M.EmbedMixture(self.num_unique_documents,
                                              self.num_topics,
                                              self.embedding_size,
                                              name="doc")

            # Create list to hold additional feature embedding mixture objects
            self.additional_features_list = []
            # Loop through all additional features
            for feature in range(self.num_additional_features):
                # Append the embedding mixture object for each, using info and names provided
                self.additional_features_list.append(
                    M.EmbedMixture(
                        self.additional_features_info[feature],
                        self.num_topics,
                        self.embedding_size,
                        name=self.additional_features_names[feature]))
            handles = self._build_graph()

            # Add graph variables to collection
            for handle in handles:
                tf.add_to_collection(Lda2vec.RESTORE_KEY, handle)

            (self.x, self.y, self.docs, self.additional_features, self.step,
             self.switch_loss, self.pivot, self.doc, self.context,
             self.loss_word2vec, self.fraction, self.loss_lda, self.loss,
             self.loss_avgs_op, self.optimizer, self.doc_embedding,
             self.topic_embedding, self.word_embedding, self.nce_weights,
             self.nce_biases, self.merged, *kg) = handles

            # If additional features were provided, unpack their handles here
            if len(kg) > 0:
                self.additional_features_list = kg[:len(kg) // 2]
                self.feature_lookup = kg[len(kg) // 2:]

        else:
            meta_graph = logdir + "/model.ckpt"
            tf.train.import_meta_graph(meta_graph + ".meta").restore(
                self.sesh, meta_graph)

            handles = self.sesh.graph.get_collection(Lda2vec.RESTORE_KEY)

            (self.x, self.y, self.docs, self.additional_features, self.step,
             self.switch_loss, self.pivot, self.doc, self.context,
             self.loss_word2vec, self.fraction, self.loss_lda, self.loss,
             self.loss_avgs_op, self.optimizer, self.doc_embedding,
             self.topic_embedding, self.word_embedding, self.nce_weights,
             self.nce_biases, self.merged, *kg) = handles

            # If additional features were provided, unpack their handles here
            if len(kg) > 0:
                self.additional_features_list = kg[:len(kg) // 2]
                self.feature_lookup = kg[len(kg) // 2:]
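
A minimal usage sketch for the constructor above (illustrative, not part of the original source). The values n_docs, n_vocab, token_freqs, and n_authors are assumed to come from your own preprocessing, and "author" is a hypothetical additional categorical feature.

# Hypothetical instantiation of the Lda2vec class defined above.
model = Lda2vec(num_unique_documents=n_docs,
                vocab_size=n_vocab,
                num_topics=20,
                freqs=token_freqs,                     # token frequencies, length n_vocab
                additional_features_info=[n_authors],  # unique values per extra feature
                additional_features_names=["author"])  # matching feature names
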
Example 2
    def __init__(self,
                 num_unique_documents,
                 vocab_size,
                 num_topics,
                 freqs=None,
                 save_graph_def=True,
                 embedding_size=128,
                 num_sampled=40,
                 learning_rate=0.001,
                 lmbda=200.0,
                 alpha=None,
                 power=0.75,
                 batch_size=500,
                 logdir='logdir',
                 restore=False,
                 fixed_words=False,
                 factors_in=None,
                 pretrained_embeddings=None):
        """Summary
        
        Args:
            num_unique_documents (int): Number of unique documents in your dataset
            vocab_size (int): Number of unique words/tokens in your dataset
            num_topics (int): The set number of topics to cluster your data into
            freqs (list, optional): Python list of length vocab_size with frequencies of each token
            save_graph_def (bool, optional): If true, we will save the graph to logdir
            embedding_size (int, optional): Dimension of the embeddings. This will be shared between docs, words, and topics.
            num_sampled (int, optional): Negative sampling number for NCE Loss.
            learning_rate (float, optional): Learning rate for optimizer
            lmbda (float, optional): Strength of the Dirichlet prior
            alpha (None, optional): Alpha of the Dirichlet process (defaults to 1/n_topics)
            power (float, optional): unigram sampler distortion
            batch_size (int, optional): Batch size coming into model
            logdir (str, optional): Location where models will be saved - note that the current datetime is appended on each run
            restore (bool, optional): When True, we will restore the model from the logdir parameter's location
            fixed_words (bool, optional): If True, the pretrained word embeddings are loaded as constants and are not updated during training
            factors_in (None, optional): Pretrained Topic Embedding (shape should be [num_topics, embedding_size])
            pretrained_embeddings (np array, optional): Pretrained word embeddings - shape should be (vocab_size, embedding_size)
        
        """
        self.config = tf.ConfigProto()
        self.config.gpu_options.allow_growth = True
        self.sesh = tf.Session(config=self.config)
        self.moving_avgs = tf.train.ExponentialMovingAverage(0.9)
        self.num_unique_documents = num_unique_documents
        self.vocab_size = vocab_size
        self.num_topics = num_topics
        self.freqs = freqs
        self.save_graph_def = save_graph_def
        self.logdir = logdir
        self.embedding_size = embedding_size
        self.num_sampled = num_sampled
        self.learning_rate = learning_rate
        self.lmbda = lmbda
        self.alpha = alpha
        self.power = power
        self.batch_size = batch_size
        self.pretrained_embeddings = pretrained_embeddings
        self.factors_in = factors_in
        self.compute_normed = False
        self.fixed_words = fixed_words

        if not restore:
            self.date = datetime.now().strftime('%y%m%d_%H%M')
            self.logdir = '{}_{}'.format(self.logdir, self.date)

            # Load pretrained embeddings if provided.
            if isinstance(pretrained_embeddings, np.ndarray):
                if fixed_words:
                    # Frozen embeddings: a constant is never updated by the optimizer
                    W_in = tf.constant(pretrained_embeddings,
                                       name="word_embedding",
                                       dtype=tf.float32)
                else:
                    # Trainable embeddings initialized from the pretrained matrix
                    W_in = tf.get_variable(
                        "word_embedding",
                        shape=[self.vocab_size, self.embedding_size],
                        initializer=tf.constant_initializer(
                            pretrained_embeddings))
            else:
                W_in = None

            # Initialize the word embedding
            self.w_embed = W.Word_Embedding(self.embedding_size,
                                            self.vocab_size,
                                            self.num_sampled,
                                            W_in=W_in,
                                            freqs=self.freqs,
                                            power=self.power)
            # Initialize the Topic-Document Mixture
            self.mixture = M.EmbedMixture(self.num_unique_documents,
                                          self.num_topics, self.embedding_size)

            # Builds the graph and returns variables within it
            handles = self._build_graph()

            for handle in handles:
                tf.add_to_collection(Lda2vec.RESTORE_KEY, handle)

            # Add Word Embedding Variables to collection
            tf.add_to_collection(Lda2vec.RESTORE_KEY, self.w_embed.embedding)
            tf.add_to_collection(Lda2vec.RESTORE_KEY, self.w_embed.nce_weights)
            tf.add_to_collection(Lda2vec.RESTORE_KEY, self.w_embed.nce_biases)

            # Add Doc Mixture Variables to collection
            tf.add_to_collection(Lda2vec.RESTORE_KEY,
                                 self.mixture.doc_embedding)
            tf.add_to_collection(Lda2vec.RESTORE_KEY,
                                 self.mixture.topic_embedding)

            (self.x, self.y, self.docs, self.step, self.switch_loss,
             self.word_context, self.doc_context, self.loss_word2vec,
             self.fraction, self.loss_lda, self.loss, self.loss_avgs_op,
             self.optimizer, self.merged) = handles

        else:
            meta_graph = logdir + '/model.ckpt'
            tf.train.import_meta_graph(meta_graph + '.meta').restore(
                self.sesh, meta_graph)
            handles = self.sesh.graph.get_collection(Lda2vec.RESTORE_KEY)

            (self.x, self.y, self.docs, self.step, self.switch_loss,
             self.word_context, self.doc_context, self.loss_word2vec,
             self.fraction, self.loss_lda, self.loss, self.loss_avgs_op,
             self.optimizer, self.merged, embedding, nce_weights, nce_biases,
             doc_embedding, topic_embedding) = handles

            self.w_embed = W.Word_Embedding(self.embedding_size,
                                            self.vocab_size,
                                            self.num_sampled,
                                            W_in=embedding,
                                            freqs=self.freqs,
                                            power=self.power,
                                            nce_w_in=nce_weights,
                                            nce_b_in=nce_biases)

            # Initialize the Topic-Document Mixture
            self.mixture = M.EmbedMixture(self.num_unique_documents,
                                          self.num_topics,
                                          self.embedding_size,
                                          W_in=doc_embedding,
                                          factors_in=topic_embedding)
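
A usage sketch for this variant (illustrative, not part of the original source). It assumes glove_matrix is a float32 NumPy array of shape (n_vocab, embedding_size) produced by your own preprocessing; with fixed_words=True the pretrained vectors are loaded as constants and left untrained.

# Hypothetical training run with frozen pretrained word vectors.
model = Lda2vec(num_unique_documents=n_docs,
                vocab_size=n_vocab,
                num_topics=20,
                freqs=token_freqs,
                pretrained_embeddings=glove_matrix,
                fixed_words=True)

# Later, reload the trained graph; pass the timestamped logdir created by the first run.
restored = Lda2vec(num_unique_documents=n_docs,
                   vocab_size=n_vocab,
                   num_topics=20,
                   logdir='logdir_YYMMDD_HHMM',
                   restore=True)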