def build_baseline_network(self):
        '''
            @brief:
                Build the baseline GGNN from the initial species settings and
                register its training ops on ``self``.
            @side effects:
                Sets ``self._node_info`` and ``self.init_node_info`` to the
                processed node-info dict, creates ``self.baseline_network``,
                and stores the dropout-mask shape/placeholder, the loss from
                ``get_vf_loss()``, the trainable variable list, an Adam
                optimizer, the symbolic gradients, one float32 gradient
                placeholder per trainable variable, and the op that applies
                those fed-in gradients (``self.update_op``).
        '''
        adj_matrix, node_attr, xml_tree = model_gen.get_initial_settings()
        # NOTE(review): the serialized XML is never used in this method; the
        # call is kept only so behaviour matches the original exactly.
        _unused_xml = etree.tostring(xml_tree, pretty_print=True)

        # Turn the raw graph into node info, then enrich it step by step.
        self._node_info = gen_gnn_param.gen_gnn_param(
            self.args.task_name, adj_matrix, node_attr,
            gnn_node_option=self.args.gnn_node_option,
            root_connection_option=self.args.root_connection_option,
            gnn_output_option=self.args.gnn_output_option,
            gnn_embedding_option='parameter',
        )
        self._node_info = gnn_util.construct_ob_size_dict(self._node_info, 64)
        for offset_kind in ('node', 'output'):
            self._node_info = gnn_util.get_inverse_type_offset(
                self._node_info, offset_kind)
        self._node_info = gnn_util.get_receive_send_idx(self._node_info)
        self.init_node_info = self._node_info

        self.baseline_network = pruning_network.GGNN(
            args=self.args, session=self.session,
            name_scope=self.name_scope + '_baseline',
            initial_node_info=self.init_node_info,
            bayesian_op=self.args.bayesian_pruning,
        )
        network = self.baseline_network
        self.dropout_mask_shape = network.dropout_mask_shape
        self.dropout_mask_placeholder = network.test_dropout_mask
        self.vf_loss = network.get_vf_loss()
        self.tvars = network._trainable_var_list

        # Gradients are computed symbolically here but applied through
        # placeholders, so callers feed gradient values into update_op.
        self.optimizer = tf.train.AdamOptimizer(self.args.lr)
        self.grads = tf.gradients(self.vf_loss, self.tvars)
        self.gradient_placeholder = [
            tf.placeholder(tf.float32, shape=variable.get_shape())
            for variable in self.tvars
        ]
        self.update_op = self.optimizer.apply_gradients(
            zip(self.gradient_placeholder, self.tvars))
    def get_feed_dict(self, new_species):
        '''
            @brief:
                Build the feed dict that describes the graph structure of
                ``new_species`` for this object's graph placeholders.
            @input:
                new_species: mapping with at least an 'adj_matrix' entry
                    (anything exposing ``.shape``) and a 'node_attr' entry.
            @output:
                feed_dict: dict from placeholders (receive idx, per-edge-type
                    send idx, per-node-type node idx and input parameters,
                    inverse node type idx, batch size, number of nodes) to
                    the values computed for this species.
        '''
        adj_matrix = new_species['adj_matrix']
        node_attr = new_species['node_attr']

        node_info = gen_gnn_param.gen_gnn_param(
            self.args.task_name,
            adj_matrix,
            node_attr,
            gnn_node_option=self.args.gnn_node_option,
            root_connection_option=self.args.root_connection_option,
            gnn_output_option=self.args.gnn_output_option,
            gnn_embedding_option='parameter'
        )
        node_info = gnn_util.construct_ob_size_dict(node_info, 64)
        node_info = gnn_util.get_inverse_type_offset(node_info, 'node')
        node_info = gnn_util.get_inverse_type_offset(node_info, 'output')
        node_info = gnn_util.get_receive_send_idx(node_info)

        # Only the graph-structure outputs are used below. A zero observation
        # and -1 are passed for the remaining arguments.
        # NOTE(review): the 6 * num_nodes + 6 width and the -1 values are kept
        # from the original; confirm they match what
        # construct_graph_input_feeddict expects.
        dummy_obs = np.zeros([1, 6 * node_info['num_nodes'] + 6])
        (_, graph_parameters, receive_idx, send_idx, node_type_idx,
         inverse_node_type_idx, _, _, _) = \
            graph_data_util.construct_graph_input_feeddict(
                node_info, dummy_obs, -1, -1, -1, -1, -1, -1, -1
            )

        feed_dict = {
            self._receive_idx: receive_idx,
            self._inverse_node_type_idx: inverse_node_type_idx,
            self._batch_size_int: 1,
        }
        for edge_type in node_info['edge_type_list']:
            feed_dict[self._send_idx[edge_type]] = send_idx[edge_type]

        # One pass over node types feeds both the node-type index and the
        # per-type input parameters.
        for node_type in node_info['node_type_dict']:
            feed_dict[self._node_type_idx[node_type]] = \
                node_type_idx[node_type]
            feed_dict[self._input_parameters[node_type]] = \
                graph_parameters[node_type]

        feed_dict[self._num_nodes_ph] = adj_matrix.shape[0]
        return feed_dict
    def _parse_mujoco_template(self):
        '''
            @brief:
                Build ``self._node_info``, the dict describing the graph this
                network runs on, and check it against the configured input
                and output sizes.
            @source:
                If 'evo' occurs in ``self.args.task``, the node info comes from
                ``gen_gnn_param.gen_gnn_param`` using ``self.adj_matrix`` and
                ``self.node_attr``. Otherwise it is parsed from the task's
                mujoco graph with ``mujoco_parser.parse_mujoco_graph``.
            @attribute:
                1. general information about the graph
                    @self._node_info['tree']
                    @self._node_info['debug_info']
                    @self._node_info['relation_matrix']

                2. information about input output
                    @self._node_info['input_dict']:
                        self._node_info['input_dict'][id_of_node] is a list of
                        ob positions
                    @self._node_info['output_list']

                3. information about the node
                    @self._node_info['node_type_dict']:
                        self._node_info['node_type_dict']['body'] is a list of
                        node id
                    @self._node_info['num_nodes']

                4. information about the edge
                    @self._node_info['edge_type_list'] = self._edge_type_list
                        the list of edge ids
                    @self._node_info['num_edges']
                    @self._node_info['num_edge_type']

                5. information about the index
                    @self._node_info['node_in_graph_list']
                        The order of nodes if placed by types ('joint', 'body')
                    @self._node_info['inverse_node_list']
                        The inverse of 'node_in_graph_list'
                    @self._node_info['receive_idx'] = receive_idx
                    @self._node_info['receive_idx_raw'] = receive_idx_raw
                    @self._node_info['send_idx'] = send_idx

                6. information about the embedding size and ob size
                    @self._node_info['para_size_dict']
                    @self._node_info['ob_size_dict']
                        self._node_info['ob_size_dict']['root'] = 10
                        self._node_info['ob_size_dict']['joint'] = 6
        '''
        use_generated_graph = 'evo' in self.args.task
        # Both graph builders take the same option keywords.
        graph_options = {
            'gnn_node_option': self._gnn_node_option,
            'root_connection_option': self._root_connection_option,
            'gnn_output_option': self._gnn_output_option,
            'gnn_embedding_option': self._gnn_embedding_option,
        }
        if use_generated_graph:
            self._node_info = gen_gnn_param.gen_gnn_param(
                self._task_name, self.adj_matrix, self.node_attr,
                **graph_options)
        else:
            self._node_info = mujoco_parser.parse_mujoco_graph(
                self._task_name, **graph_options)

        # The configured input/output sizes must agree with the graph.
        gnn_util.io_size_check(self._input_size, self._output_size,
                               self._node_info, self._is_baseline)

        # Enrich the node info in order: per-type ob sizes, inverse offsets
        # for nodes and for outputs, then the receive/send edge indices.
        self._node_info = gnn_util.construct_ob_size_dict(
            self._node_info, self._input_feat_dim)
        for offset_kind in ('node', 'output'):
            self._node_info = gnn_util.get_inverse_type_offset(
                self._node_info, offset_kind)
        self._node_info = gnn_util.get_receive_send_idx(self._node_info)