# NOTE(review): this is a module-level copy of the `prop_in_cut` closure that is
# also defined inside `_inference` in this file. At this level the free names
# `self` and `node_vec` are not bound by any visible enclosing scope, so calling
# it would presumably raise NameError unless they exist as module globals —
# confirm whether this copy is used or is dead code.
def prop_in_cut():
    """Run `self._prop_step_inter` message-passing steps over partition 0.

    `node_vec` is split by `self._partition_idx` into two partitions. Partition 0
    (the nodes propagated under the "cut" variable scope) is updated; partition 1
    is passed through unchanged. The two are recombined with `tf.dynamic_stitch`
    using `self._stitch_idx`.

    Returns:
        The stitched tensor of updated partition-0 states and untouched
        partition-1 states.
    """
    with tf.variable_scope("cut"):
        node_vec_part = tf.dynamic_partition(
            node_vec, self._partition_idx, 2)
        # States per step: index -1 holds the initial state and index tt holds
        # the state after step tt, so reading [tt - 1] at tt == 0 yields the
        # initial state. The final state lands at index -2, which requires
        # self._prop_step_inter >= 1 (otherwise [-2] raises IndexError).
        node_vec_cut = [None] * (self._prop_step_inter + 1)
        node_vec_cut[-1] = node_vec_part[0]
        for tt in xrange(self._prop_step_inter):
            # pull messages
            node_vec_sum = [None] * self._num_edgetype
            for ee in xrange(self._num_edgetype):
                # gather the sending-node states for edge type ee
                node_active = tf.gather(
                    node_vec_cut[tt - 1], self._send_idx_cut[ee])
                # NOTE(review): no else branch — an unrecognised msg_type
                # leaves None in node_vec_sum and tf.concat below fails.
                if self._msg_type == "msg_embedding":
                    # compute msg using embedding alone
                    node_vec_sum[ee] = node_active
                elif self._msg_type == "msg_mlp":
                    # compute msg using a MLP
                    node_vec_sum[ee] = self._MLP_prop[ee](
                        node_active)[-1]
            # aggregate messages
            concat_msg = tf.concat(node_vec_sum, axis=0)
            # One extra segment is requested and its row is dropped via
            # message[:-1, :] below — presumably a dummy/padding receiver;
            # confirm against `aggregate`.
            message = aggregate(concat_msg,
                                self._receive_idx_cut,
                                self._num_node_cut_var + 1,
                                method=self._aggregate)
            # update hidden states: MLP on [message, state] when
            # self._update_type == "MLP", otherwise the two-argument
            # self._update_func (the original comment called this a GRU)
            if self._update_type == "MLP":
                node_vec_cut[tt] = self._update_func(
                    tf.concat([
                        message[:-1, :], node_vec_cut[tt - 1]
                    ], axis=1))[-1]
            else:
                node_vec_cut[tt] = self._update_func(
                    message[:-1, :], node_vec_cut[tt - 1])
        return tf.dynamic_stitch(
            self._stitch_idx, [node_vec_cut[-2], node_vec_part[1]])
def _inference(self):
    """Build the forward graph: propagate node states and predict class scores.

    Each of ``self._num_pass`` passes runs two message-passing phases:

    1. ``self._prop_step_intra`` steps inside every cluster independently
       (node states are split row-wise by ``self._cluster_size_var``);
    2. ``self._prop_step_inter`` steps over partition 0 of
       ``self._partition_idx`` (the "cut"), skipped via ``tf.cond`` when that
       partition is empty.

    The final node states are concatenated with the node features and fed to
    ``self._MLP_out``.

    Side effects: overwrites ``self._node_feat``; sets ``self._node_vec[-1]``
    (initial state) and ``self._node_vec[pp]`` for every pass ``pp``; sets
    ``self._logits`` and the ``"pred_logits"``, ``"val_pred_logits"`` and
    ``"test_pred_logits"`` entries of ``self._ops``.

    Raises:
        ValueError: if ``self._msg_type`` is neither ``"msg_embedding"`` nor
            ``"msg_mlp"``. Previously this left ``None`` in the message list
            and failed later inside ``tf.concat``.
    """

    def _propagate(state, send_idx, receive_idx, num_segments, num_steps):
        """Run ``num_steps`` rounds of message passing and return the last state.

        Shared by the per-cluster and the cut phase; must be called inside the
        desired ``tf.variable_scope``.

        Args:
            state: initial node-state tensor of this (sub)graph.
            send_idx: indexable by edge type ``ee``; ``send_idx[ee]`` gives the
                rows of ``state`` that send messages along that edge type.
            receive_idx: receiver indices for the concatenated messages of all
                edge types, passed straight to ``aggregate``.
            num_segments: segment count for ``aggregate``. Callers pass one more
                than the (sub)graph size, and the extra last row is dropped.
            num_steps: number of propagation steps; with 0 steps ``state`` is
                returned unchanged (previously this raised IndexError).
        """
        for _ in xrange(num_steps):
            # pull messages: one block per edge type
            msgs = []
            for ee in xrange(self._num_edgetype):
                node_active = tf.gather(state, send_idx[ee])
                if self._msg_type == "msg_embedding":
                    # compute msg using embedding alone
                    msgs.append(node_active)
                elif self._msg_type == "msg_mlp":
                    # compute msg using the edge-type-specific MLP
                    msgs.append(self._MLP_prop[ee](node_active)[-1])
                else:
                    raise ValueError(
                        "Unsupported msg_type: {!r}".format(self._msg_type))

            # aggregate messages
            message = aggregate(tf.concat(msgs, axis=0),
                                receive_idx,
                                num_segments,
                                method=self._aggregate)
            # Drop the extra last segment row — presumably a dummy/padding
            # receiver; NOTE(review): confirm against `aggregate`.
            message = message[:-1, :]

            # update hidden states: MLP on [message, state], or the two-argument
            # update function (e.g. a recurrent cell) otherwise
            if self._update_type == "MLP":
                state = self._update_func(
                    tf.concat([message, state], axis=1))[-1]
            else:
                state = self._update_func(message, state)
        return state

    with tf.variable_scope("inference"):
        input_feat = self._MLP_feat(self._node_feat)[-1]
        if self._dataset_name in ("nell", "diel"):
            self._node_feat = input_feat
            self._node_vec[-1] = self._node_embedding
        else:
            self._node_feat = tf.sparse_tensor_to_dense(
                self._node_feat, validate_indices=False)
            self._node_vec[-1] = input_feat

        for pp in xrange(self._num_pass):
            with tf.variable_scope("pass_{}".format(pp)):
                ### parallel synchronous propagation within clusters
                # At pp == 0, self._node_vec[pp - 1] is the initial state set above.
                cluster_init = tf.split(self._node_vec[pp - 1],
                                        self._cluster_size_var,
                                        axis=0)
                cluster_final = []
                for ii in xrange(self._num_cluster):
                    with tf.variable_scope("cluster_{}".format(ii)):
                        cluster_final.append(
                            _propagate(cluster_init[ii],
                                       self._send_idx_cluster[ii],
                                       self._receive_idx_cluster[ii],
                                       self._cluster_size_var[ii] + 1,
                                       self._prop_step_intra))

                ### update node representation
                node_vec = tf.concat(cluster_final, axis=0)

                # dynamic_partition below uses 2 partitions, so ids are 0 or 1.
                # Assuming self._partition_idx has one id per node, partition 0
                # is empty exactly when the ids sum to self._num_nodes.
                is_cut_empty = tf.equal(tf.reduce_sum(self._partition_idx),
                                        self._num_nodes)

                ### synchronous propagation within cut
                def prop_in_cut():
                    with tf.variable_scope("cut"):
                        node_vec_part = tf.dynamic_partition(
                            node_vec, self._partition_idx, 2)
                        cut_final = _propagate(node_vec_part[0],
                                               self._send_idx_cut,
                                               self._receive_idx_cut,
                                               self._num_node_cut_var + 1,
                                               self._prop_step_inter)
                        # recombine updated partition 0 with untouched partition 1
                        return tf.dynamic_stitch(
                            self._stitch_idx, [cut_final, node_vec_part[1]])

                def no_prop():
                    return node_vec

                # Update final representation (both branches are traced now,
                # so the closures' reference to node_vec is not late-bound)
                self._node_vec[pp] = tf.cond(is_cut_empty, no_prop,
                                             prop_in_cut)

                logger.info("Propagation pass = {}".format(pp))

        output_feat = tf.concat([self._node_vec[-2], self._node_feat], axis=1)
        self._logits = self._MLP_out(output_feat, self._dropout_rate)[-1]
        self._ops["pred_logits"] = tf.nn.softmax(self._logits)
        self._ops["val_pred_logits"] = tf.gather(self._ops["pred_logits"],
                                                 self._val_idx)
        self._ops["test_pred_logits"] = tf.gather(self._ops["pred_logits"],
                                                  self._test_idx)