def sent_encode(self, user_word_emb, scope=None):
    # b_sz, ststp, wtstp, emb_sz = tf.unstack(tf.shape(user_word_emb))
    with tf.variable_scope(scope or "sent_encode"):
        # flatten to sentence level: (batch * sentences, maxword, emb_size)
        sent_word = tf.reshape(user_word_emb, [-1, self.cfg.maxword, self.cfg.emb_size])
        sent_wNUm = tf.reshape(self.ph_wNum, [-1])  # word count per sentence
        if self.cfg.ute_sent_encode == "cnn":
            sent_emb = cnn_layer(sent_word, self.cfg.emb_size,
                                 filter_sizes=self.cfg.filter_sizes,
                                 num_filters=self.cfg.num_filters)
        elif self.cfg.ute_sent_encode == 'bigru':
            sent_emb = biGRU(sent_word, self.cfg.word_hidden, sent_wNUm)
            # word-level attention pooling over the biRNN outputs
            sent_emb = mask_attention(sent_emb, self.cfg.maxword, self.cfg.word_hidden * 2,
                                      self.cfg.atten_size, sent_wNUm)
        elif self.cfg.ute_sent_encode == 'bilstm':
            sent_emb = biLSTM(sent_word, self.cfg.word_hidden, sent_wNUm)
            sent_emb = mask_attention(sent_emb, self.cfg.maxword, self.cfg.word_hidden * 2,
                                      self.cfg.atten_size, sent_wNUm)
        else:
            raise ValueError("no such sent encode %s" % self.cfg.ute_sent_encode)
        return sent_emb
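# Illustrative sketch only: the repo's mask_attention helper is not shown in this excerpt.
# The signature is inferred from the word-level calls above, i.e.
# mask_attention(inputs, max_len, hidden_dim, atten_size, seq_len): additive (context-vector)
# attention pooling with padded time steps masked out by sequence length. The doc-level call in
# doc_embbedding below passes fewer arguments, so the real helper presumably has defaults.
# Assumes `import tensorflow as tf` (TF 1.x), as in the surrounding code.
def mask_attention(inputs, max_len, hidden_dim, atten_size, seq_len, scope=None):
    # hidden_dim (= the last dimension of `inputs`) is kept only to mirror the call sites above
    with tf.variable_scope(scope or "mask_attention"):
        # u = tanh(W x + b); score = u . context; softmax over valid time steps only
        u = tf.layers.dense(inputs, atten_size, activation=tf.nn.tanh)    # (batch, max_len, atten_size)
        context = tf.get_variable("context", [atten_size])
        scores = tf.tensordot(u, context, axes=1)                         # (batch, max_len)
        mask = tf.sequence_mask(seq_len, maxlen=max_len, dtype=tf.float32)
        scores = scores + (1.0 - mask) * (-1e30)                          # padded steps get ~zero weight
        alpha = tf.nn.softmax(scores)                                     # (batch, max_len)
        # weighted sum of the inputs -> (batch, last dim of inputs)
        return tf.reduce_sum(inputs * tf.expand_dims(alpha, -1), axis=1)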
def doc_embbedding(self, sent_topic_emb, seqLen, scope=None):
    with tf.variable_scope(scope or "sent_topic_embedding"):
        if self.cfg.doc_encode == "bigru":
            birnn_sent = biGRU(sent_topic_emb, self.cfg.sent_hidden, seqLen)
        elif self.cfg.doc_encode == 'bilstm':
            birnn_sent = biLSTM(sent_topic_emb, self.cfg.sent_hidden, seqLen)
        else:
            raise ValueError("no such encoder %s" % self.cfg.doc_encode)
        doc_emb = mask_attention(birnn_sent, self.cfg.atten_size, seqLen)  # (b_sz, sent_hidden * 2)
        return doc_emb
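# Illustrative sketches only: these are not the repo's actual biGRU/biLSTM helpers. The signature
# is inferred from the call sites above, (inputs, hidden_size, seq_len), returning the concatenated
# forward/backward outputs of shape (batch, time, 2 * hidden_size).
# Assumes `import tensorflow as tf` (TF 1.x), as in the surrounding code.
def biGRU(inputs, hidden_size, seq_len, scope=None):
    with tf.variable_scope(scope or "biGRU"):
        fw = tf.nn.rnn_cell.GRUCell(hidden_size)
        bw = tf.nn.rnn_cell.GRUCell(hidden_size)
        (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
            fw, bw, inputs, sequence_length=seq_len, dtype=tf.float32)
        return tf.concat([out_fw, out_bw], axis=-1)

def biLSTM(inputs, hidden_size, seq_len, scope=None):
    with tf.variable_scope(scope or "biLSTM"):
        fw = tf.nn.rnn_cell.LSTMCell(hidden_size)
        bw = tf.nn.rnn_cell.LSTMCell(hidden_size)
        (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
            fw, bw, inputs, sequence_length=seq_len, dtype=tf.float32)
        return tf.concat([out_fw, out_bw], axis=-1)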
def sent_encode(self, user_word_emb, scope=None):
    with tf.variable_scope(scope or "sent_encode"):
        # flatten to sentence level: (batch * sentences, maxword, emb_size)
        sent_word = tf.reshape(user_word_emb, [-1, self.cfg.maxword, self.cfg.emb_size])
        sent_wNUm = tf.reshape(self.ph_wNum, [-1])  # word count per sentence
        if self.cfg.han_sent_encode == 'bigru':
            sent_emb = biGRU(sent_word, self.cfg.word_hidden, sent_wNUm)
        elif self.cfg.han_sent_encode == 'bilstm':
            sent_emb = biLSTM(sent_word, self.cfg.word_hidden, sent_wNUm)
        else:
            raise ValueError("no such sent encode %s" % self.cfg.han_sent_encode)
        # word-level attention mechanism over the biRNN outputs
        sent_emb = mask_attention(sent_emb, self.cfg.maxword, self.cfg.word_hidden * 2,
                                  self.cfg.atten_size, sent_wNUm)
        return sent_emb
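# Illustrative sketch only: cnn_layer, referenced by the "cnn" branch of the ute sent encoder
# above, is not shown in this excerpt. It is assumed here to be a Kim-style text CNN with the
# keyword arguments used at the call site (filter_sizes, num_filters): convolve over words with
# several filter widths and max-pool over time, giving (batch, len(filter_sizes) * num_filters).
# Assumes `import tensorflow as tf` (TF 1.x), as in the surrounding code.
def cnn_layer(inputs, emb_size, filter_sizes, num_filters, scope=None):
    with tf.variable_scope(scope or "cnn_layer"):
        inputs_4d = tf.expand_dims(inputs, -1)  # (batch, maxword, emb_size, 1)
        pooled = []
        for fsz in filter_sizes:
            with tf.variable_scope("conv_maxpool_%d" % fsz):
                W = tf.get_variable("W", [fsz, emb_size, 1, num_filters])
                b = tf.get_variable("b", [num_filters], initializer=tf.zeros_initializer())
                conv = tf.nn.conv2d(inputs_4d, W, strides=[1, 1, 1, 1], padding="VALID")
                h = tf.nn.relu(tf.nn.bias_add(conv, b))  # (batch, maxword - fsz + 1, 1, num_filters)
                pooled.append(tf.reduce_max(h, axis=1))  # max over time -> (batch, 1, num_filters)
        # concatenate all filter widths -> (batch, len(filter_sizes) * num_filters)
        return tf.reshape(tf.concat(pooled, axis=-1), [-1, len(filter_sizes) * num_filters])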