class OrderedGatedGrnnRec(MTAMRec_model):
    """Short-term intent model: ordered gated-GNN features fed into a graph RNN.

    The dense behaviour-sequence embedding is enriched with neighbourhood
    features from ``ordered_gated_GNN``. Compared with ``gated_GNN``, this call
    additionally receives in/out edge-id embeddings and in/out adjacency masks.
    The raw embedding and the graph features are then concatenated and encoded
    by ``GraphRNN.simple_grnn_net``.
    """

    def build_model(self):
        """Build the model graph.

        Sets ``self.gru_net_ins``, ``self.gated_gnn_model``,
        ``self.short_term_intent_temp``, ``self.short_term_intent`` and
        ``self.predict_behavior_emb``, then calls ``self.output()``.
        """
        self.gru_net_ins = GraphRNN()
        self.gated_gnn_model = ordered_gated_GNN()

        with tf.variable_scope('user_behavior_emb'):
            user_behavior_list_embedding = self.behavior_list_embedding_dense

        with tf.variable_scope('neighbor_emb', reuse=tf.AUTO_REUSE):
            # Original author's note on the result shape:
            # batch_size, max_len, num_units * 2 (not verified here).
            structure_emb = self.gated_gnn_model.generate_graph_emb(
                init_emb=user_behavior_list_embedding,
                now_batch_size=self.now_bacth_data_size,
                num_units=self.num_units,
                adj_in=self.adj_in,
                adj_out=self.adj_out,
                eid_emb_in=self.in_eid_embedding,
                eid_emb_out=self.out_eid_embedding,
                mask_adj_in=self.mask_adj_in,
                mask_adj_out=self.mask_adj_out,
                step=self.FLAGS.graph_step,
            )

        with tf.variable_scope('ShortTermIntentEncoder'):
            # Concatenate raw behaviour embeddings with the graph features
            # along axis 2 (presumably the feature axis; TODO confirm).
            grnn_inputs = tf.concat(
                [user_behavior_list_embedding, structure_emb], axis=2)
            self.short_term_intent_temp = self.gru_net_ins.simple_grnn_net(
                hidden_units=self.num_units,
                input_data=grnn_inputs,
                input_length=tf.add(self.seq_length, -1))
            # Select one encoder output per sequence, at position
            # mask_index - 1. width matches hidden_units above.
            self.short_term_intent = gather_indexes(
                batch_size=self.now_bacth_data_size,
                seq_length=self.max_len,
                width=self.num_units,
                sequence_tensor=self.short_term_intent_temp,
                positions=self.mask_index - 1)
            # NOTE(review): the original source lost its indentation. Confirm
            # whether layer_norm ran inside this scope, because any variables
            # it creates are named after the enclosing scope.
            self.predict_behavior_emb = layer_norm(self.short_term_intent)
        self.output()
class GatedGrnnRec(MTAMRec_model):
    """Short-term intent model built from stacked gated-GNN + graph-RNN rounds.

    Each round does three things:
    1. Computes neighbourhood features with ``gated_GNN`` from the current
       sequence representation.
    2. Concatenates those features with that representation.
    3. Re-encodes the result with ``GraphRNN.simple_grnn_net``.

    The encoder output of one round is the input representation of the next.
    """

    # Number of stacked GNN + GRNN rounds. Round i creates its variables under
    # 'neighbor_emb_<i>' and 'ShortTermIntentEncoder_<i>', so the default of 2
    # reproduces the original (hard-coded) variable names exactly.
    num_grnn_layers = 2

    def build_model(self):
        """Build the model graph.

        Sets ``self.gru_net_ins``, ``self.gated_gnn_model``,
        ``self.short_term_intent`` and ``self.predict_behavior_emb``, then
        calls ``self.output()``.

        Raises:
            ValueError: if ``num_grnn_layers`` is less than 1. With no rounds,
                the gathered tensor would be the raw embedding, whose width
                need not equal ``num_units``.
        """
        if self.num_grnn_layers < 1:
            raise ValueError(
                'num_grnn_layers must be >= 1, got %r' % (self.num_grnn_layers,))

        self.gru_net_ins = GraphRNN()
        self.gated_gnn_model = gated_GNN()

        with tf.variable_scope('user_behavior_emb'):
            user_behavior_list_embedding = self.behavior_list_embedding_dense

        for i in range(self.num_grnn_layers):
            with tf.variable_scope('neighbor_emb_' + str(i), reuse=tf.AUTO_REUSE):
                # Original author's note on the result shape:
                # batch_size, max_len, num_units * 2 (not verified here).
                structure_emb = self.gated_gnn_model.generate_graph_emb(
                    init_emb=user_behavior_list_embedding,
                    now_batch_size=self.now_bacth_data_size,
                    num_units=self.num_units,
                    adj_in=self.adj_in,
                    adj_out=self.adj_out,
                    step=self.FLAGS.graph_step,
                )
            with tf.variable_scope('ShortTermIntentEncoder_' + str(i), reuse=tf.AUTO_REUSE):
                grnn_inputs = tf.concat(
                    [user_behavior_list_embedding, structure_emb], axis=2)
                # The encoder output replaces the representation fed to the
                # next round.
                user_behavior_list_embedding = self.gru_net_ins.simple_grnn_net(
                    hidden_units=self.num_units,
                    input_data=grnn_inputs,
                    input_length=tf.add(self.seq_length, -1))

        # NOTE(review): the original source lost its indentation. The gather
        # and layer_norm are reconstructed as running once after the loop.
        # Confirm against existing checkpoints if layer_norm creates variables.
        # width=num_units matches the last round's hidden_units.
        self.short_term_intent = gather_indexes(
            batch_size=self.now_bacth_data_size,
            seq_length=self.max_len,
            width=self.num_units,
            sequence_tensor=user_behavior_list_embedding,
            positions=self.mask_index - 1)
        self.predict_behavior_emb = layer_norm(self.short_term_intent)
        self.output()