# Code example #1 (コード例 #1)
                    def prop_in_cut():
                        """Run synchronous message passing over the cut nodes.

                        Splits `node_vec` into cut (partition 0) and non-cut
                        (partition 1) nodes, propagates only the cut nodes for
                        `self._prop_step_inter` steps, and stitches the result
                        back into full node order. Non-cut nodes pass through
                        unchanged.
                        """
                        with tf.variable_scope("cut"):
                            parts = tf.dynamic_partition(
                                node_vec, self._partition_idx, 2)
                            # Slot -1 holds the initial cut-node states, so
                            # reading index `step - 1` at step == 0 picks them
                            # up without a special case.
                            cut_states = [None] * (self._prop_step_inter + 1)
                            cut_states[-1] = parts[0]

                            for step in xrange(self._prop_step_inter):
                                # One message slot per edge type.
                                msgs = [None] * self._num_edgetype

                                for etype in xrange(self._num_edgetype):
                                    # Gather sender states from the previous
                                    # propagation step.
                                    senders = tf.gather(
                                        cut_states[step - 1],
                                        self._send_idx_cut[etype])

                                    if self._msg_type == "msg_embedding":
                                        # Raw sender state is the message.
                                        msgs[etype] = senders
                                    elif self._msg_type == "msg_mlp":
                                        # Message is an MLP of the sender
                                        # state; [-1] is the last layer.
                                        msgs[etype] = self._MLP_prop[etype](
                                            senders)[-1]

                                # Aggregate per-receiver; the +1 extra segment
                                # is a dummy sink whose row is dropped below
                                # via the [:-1, :] slice.
                                agg = aggregate(tf.concat(msgs, axis=0),
                                                self._receive_idx_cut,
                                                self._num_node_cut_var + 1,
                                                method=self._aggregate)

                                # Update hidden states (MLP on concat, or a
                                # GRU-style two-argument update function).
                                if self._update_type == "MLP":
                                    joined = tf.concat(
                                        [agg[:-1, :], cut_states[step - 1]],
                                        axis=1)
                                    cut_states[step] = self._update_func(
                                        joined)[-1]
                                else:
                                    cut_states[step] = self._update_func(
                                        agg[:-1, :], cut_states[step - 1])

                            # [-2] is the last step actually written (index
                            # prop_step_inter - 1); merge back with untouched
                            # non-cut nodes.
                            return tf.dynamic_stitch(
                                self._stitch_idx,
                                [cut_states[-2], parts[1]])
# Code example #2 (コード例 #2)
    def _inference(self):
        """Build the inference graph for the node-classification model.

        Embeds node features, runs ``self._num_pass`` rounds of synchronous
        message passing — first independently within each cluster, then over
        the inter-cluster cut — and produces softmax prediction ops.

        Side effects: rewrites ``self._node_feat``, fills ``self._node_vec``
        and ``self._logits``, and registers ``"pred_logits"``,
        ``"val_pred_logits"`` and ``"test_pred_logits"`` in ``self._ops``.
        Returns nothing.
        """
        with tf.variable_scope("inference"):
            # Project raw features through an MLP; [-1] takes the last layer.
            input_feat = self._MLP_feat(self._node_feat)[-1]

            if self._dataset_name == "nell" or self._dataset_name == "diel":
                # These datasets propagate from learned node embeddings.
                self._node_feat = input_feat
                self._node_vec[-1] = self._node_embedding
            else:
                # Densify sparse features; propagation starts from the MLP
                # output. Slot -1 of self._node_vec holds the initial state
                # so that index `pp - 1` at pp == 0 reads it.
                self._node_feat = tf.sparse_tensor_to_dense(
                    self._node_feat, validate_indices=False)
                self._node_vec[-1] = input_feat

            for pp in xrange(self._num_pass):
                with tf.variable_scope("pass_{}".format(pp)):
                    ### parallel synchronous propagation within clusters
                    # One state list per cluster; the extra (+1) slot at
                    # index -1 stores the initial state (same trick as above).
                    node_vec_cluster = [[None] * (self._prop_step_intra + 1)
                                        for _ in xrange(self._num_cluster)]
                    node_vec_cluster_init = tf.split(self._node_vec[pp - 1],
                                                     self._cluster_size_var,
                                                     axis=0)

                    for ii in xrange(self._num_cluster):
                        with tf.variable_scope("cluster_{}".format(ii)):
                            # node representation (initial state in slot -1)
                            node_vec_cluster[ii][-1] = node_vec_cluster_init[
                                ii]

                            for tt in xrange(self._prop_step_intra):
                                # pull messages (one slot per edge type)
                                node_vec_sum = [None] * self._num_edgetype

                                for ee in xrange(self._num_edgetype):
                                    # gather sender states from the previous
                                    # step (tt - 1 is -1 at tt == 0 → initial)
                                    node_active = tf.gather(
                                        node_vec_cluster[ii][tt - 1],
                                        self._send_idx_cluster[ii][ee])

                                    if self._msg_type == "msg_embedding":
                                        # compute msg using embedding alone
                                        node_vec_sum[ee] = node_active
                                    elif self._msg_type == "msg_mlp":
                                        # compute msg using a MLP
                                        node_vec_sum[ee] = self._MLP_prop[ee](
                                            node_active)[-1]

                                # aggregate messages; the +1 extra segment is
                                # a dummy sink whose row is dropped below via
                                # the message[:-1, :] slice
                                concat_msg = tf.concat(node_vec_sum, axis=0)
                                message = aggregate(
                                    concat_msg,
                                    self._receive_idx_cluster[ii],
                                    self._cluster_size_var[ii] + 1,
                                    method=self._aggregate)

                                # update hidden states
                                if self._update_type == "MLP":
                                    node_vec_cluster[ii][
                                        tt] = self._update_func(
                                            tf.concat([
                                                message[:-1, :],
                                                node_vec_cluster[ii][tt - 1]
                                            ],
                                                      axis=1))[-1]
                                else:
                                    # two-argument (e.g. GRU-style) update
                                    node_vec_cluster[ii][
                                        tt] = self._update_func(
                                            message[:-1, :],
                                            node_vec_cluster[ii][tt - 1])

                    ### update node representation
                    # [-2] is the last step actually written (index
                    # prop_step_intra - 1); concat restores full node order.
                    node_vec = tf.concat([xx[-2] for xx in node_vec_cluster],
                                         axis=0)

                    # partition_idx takes values {0, 1} (see dynamic_partition
                    # below), so sum == num_nodes means every node is labeled
                    # 1, i.e. the cut (partition 0) is empty.
                    is_cut_empty = tf.equal(tf.reduce_sum(self._partition_idx),
                                            self._num_nodes)

                    ### synchronous propagation within cut
                    def prop_in_cut():
                        # Same message-passing scheme as the intra-cluster
                        # loop above, applied only to cut nodes (partition 0);
                        # non-cut nodes (partition 1) pass through unchanged
                        # and are stitched back into full node order.
                        with tf.variable_scope("cut"):
                            node_vec_part = tf.dynamic_partition(
                                node_vec, self._partition_idx, 2)
                            # slot -1 holds the initial cut-node states
                            node_vec_cut = [None] * (self._prop_step_inter + 1)
                            node_vec_cut[-1] = node_vec_part[0]

                            for tt in xrange(self._prop_step_inter):
                                # pull messages
                                node_vec_sum = [None] * self._num_edgetype

                                for ee in xrange(self._num_edgetype):
                                    # partition
                                    node_active = tf.gather(
                                        node_vec_cut[tt - 1],
                                        self._send_idx_cut[ee])

                                    if self._msg_type == "msg_embedding":
                                        # compute msg using embedding alone
                                        node_vec_sum[ee] = node_active
                                    elif self._msg_type == "msg_mlp":
                                        # compute msg using a MLP
                                        node_vec_sum[ee] = self._MLP_prop[ee](
                                            node_active)[-1]

                                # aggregate messages (dummy +1 segment again)
                                concat_msg = tf.concat(node_vec_sum, axis=0)
                                message = aggregate(concat_msg,
                                                    self._receive_idx_cut,
                                                    self._num_node_cut_var + 1,
                                                    method=self._aggregate)

                                # update hidden states via GRU
                                if self._update_type == "MLP":
                                    node_vec_cut[tt] = self._update_func(
                                        tf.concat([
                                            message[:-1, :],
                                            node_vec_cut[tt - 1]
                                        ],
                                                  axis=1))[-1]
                                else:
                                    node_vec_cut[tt] = self._update_func(
                                        message[:-1, :], node_vec_cut[tt - 1])

                            return tf.dynamic_stitch(
                                self._stitch_idx,
                                [node_vec_cut[-2], node_vec_part[1]])

                    def no_prop():
                        # Identity branch for tf.cond when the cut is empty.
                        return node_vec

                    # Update final representation. tf.cond builds both branch
                    # graphs but executes only one at run time.
                    self._node_vec[pp] = tf.cond(is_cut_empty, no_prop,
                                                 prop_in_cut)

                    logger.info("Propagation pass = {}".format(pp))

            # Skip connection: concatenate the final propagated state with
            # the (possibly densified) input features.
            # NOTE(review): [-2] presumes len(self._node_vec) ==
            # self._num_pass + 1, making -2 the last pass — TODO confirm
            # against where self._node_vec is allocated.
            output_feat = tf.concat([self._node_vec[-2], self._node_feat],
                                    axis=1)
            self._logits = self._MLP_out(output_feat, self._dropout_rate)[-1]

            self._ops["pred_logits"] = tf.nn.softmax(self._logits)
            self._ops["val_pred_logits"] = tf.gather(self._ops["pred_logits"],
                                                     self._val_idx)
            self._ops["test_pred_logits"] = tf.gather(self._ops["pred_logits"],
                                                      self._test_idx)