def build_baseline_network(self):
    '''
        @brief:
            Build the baseline network, and fetch the baseline variable
            list together with the ops used to train it.
        @attribute: (all set on self)
            1. graph description of the initial species
                @self._node_info
                @self.init_node_info (same object as self._node_info)
            2. the network and its dropout handles
                @self.baseline_network
                @self.dropout_mask_shape
                @self.dropout_mask_placeholder
            3. loss, variables and training ops
                @self.vf_loss: result of baseline_network.get_vf_loss()
                @self.tvars: the network's trainable variable list
                @self.optimizer: Adam with learning rate self.args.lr
                @self.grads: symbolic gradients of vf_loss w.r.t. tvars
                @self.gradient_placeholder: one float32 placeholder per
                    trainable variable, with that variable's shape
                @self.update_op: applies the gradients fed through
                    self.gradient_placeholder to self.tvars
    '''
    # step 1: describe the graph of the initial species. The xml tree that
    # model_gen returns as third element is not needed to build the network.
    adj_matrix, node_attr, _ = model_gen.get_initial_settings()
    self._node_info = gen_gnn_param.gen_gnn_param(
        self.args.task_name,
        adj_matrix,
        node_attr,
        gnn_node_option=self.args.gnn_node_option,
        root_connection_option=self.args.root_connection_option,
        gnn_output_option=self.args.gnn_output_option,
        gnn_embedding_option='parameter'
    )
    self._node_info = gnn_util.construct_ob_size_dict(self._node_info, 64)
    self._node_info = \
        gnn_util.get_inverse_type_offset(self._node_info, 'node')
    self._node_info = \
        gnn_util.get_inverse_type_offset(self._node_info, 'output')
    self._node_info = gnn_util.get_receive_send_idx(self._node_info)
    self.init_node_info = self._node_info

    # step 2: build the network itself
    self.baseline_network = pruning_network.GGNN(
        args=self.args,
        session=self.session,
        name_scope=self.name_scope + '_baseline',
        initial_node_info=self.init_node_info,
        bayesian_op=self.args.bayesian_pruning
    )
    self.dropout_mask_shape = self.baseline_network.dropout_mask_shape
    self.dropout_mask_placeholder = self.baseline_network.test_dropout_mask

    # step 3: fetch the loss and the trainable variables
    self.vf_loss = self.baseline_network.get_vf_loss()
    self.tvars = self.baseline_network._trainable_var_list

    # step 4: define the gradient and the update op. The update op consumes
    # gradients fed through placeholders instead of self.grads directly.
    self.optimizer = tf.train.AdamOptimizer(self.args.lr)
    self.grads = tf.gradients(self.vf_loss, self.tvars)
    self.gradient_placeholder = [
        tf.placeholder(tf.float32, shape=variable.get_shape())
        for variable in self.tvars
    ]
    self.update_op = self.optimizer.apply_gradients(
        zip(self.gradient_placeholder, self.tvars)
    )
def get_feed_dict(self, new_species):
    '''
        @brief:
            Build the feed dict that describes the graph structure of one
            species (batch size 1).
        @input:
            @new_species: a dict. new_species['adj_matrix'] and
                new_species['node_attr'] are read; other keys (e.g.
                'xml_str') are not needed here.
        @output:
            @feed_dict: maps the placeholders held on self (receive idx,
                inverse node type idx, batch size, per-edge-type send idx,
                per-node-type node idx and input parameters, number of
                nodes) to the values generated for this species.
    '''
    adj_matrix = new_species['adj_matrix']
    node_attr = new_species['node_attr']

    # step 1: generate the node information of this species
    node_info = gen_gnn_param.gen_gnn_param(
        self.args.task_name,
        adj_matrix,
        node_attr,
        gnn_node_option=self.args.gnn_node_option,
        root_connection_option=self.args.root_connection_option,
        gnn_output_option=self.args.gnn_output_option,
        gnn_embedding_option='parameter'
    )
    node_info = gnn_util.construct_ob_size_dict(node_info, 64)
    node_info = gnn_util.get_inverse_type_offset(node_info, 'node')
    node_info = gnn_util.get_inverse_type_offset(node_info, 'output')
    node_info = gnn_util.get_receive_send_idx(node_info)

    # step 2: construct the graph inputs. An all-zero observation is passed
    # in; the first returned element and the last three are discarded.
    dummy_obs = np.zeros([1, 6 * node_info['num_nodes'] + 6])
    graph_inputs = graph_data_util.construct_graph_input_feeddict(
        node_info, dummy_obs, -1, -1, -1, -1, -1, -1, -1
    )
    (_, graph_parameters, receive_idx, send_idx,
     node_type_idx, inverse_node_type_idx, _, _, _) = graph_inputs

    # step 3: fill in the feed dict
    feed_dict = {
        self._receive_idx: receive_idx,
        self._inverse_node_type_idx: inverse_node_type_idx,
        self._batch_size_int: 1,
    }

    # the send idx is registered per edge type
    for i_edge in node_info['edge_type_list']:
        feed_dict[self._send_idx[i_edge]] = send_idx[i_edge]

    # the node type idx and the input parameters are registered per node type
    for i_node_type in node_info['node_type_dict']:
        feed_dict[self._node_type_idx[i_node_type]] = \
            node_type_idx[i_node_type]
        feed_dict[self._input_parameters[i_node_type]] = \
            graph_parameters[i_node_type]

    feed_dict[self._num_nodes_ph] = adj_matrix.shape[0]

    return feed_dict
def _parse_mujoco_template(self):
    '''
        @brief:
            Construct self._node_info, the dict holding the node
            information of the graph.
        @attribute:
            1. general information about the graph
                @self._node_info['tree']
                @self._node_info['debug_info']
                @self._node_info['relation_matrix']
            2. information about the input and the output
                @self._node_info['input_dict']:
                    self._node_info['input_dict'][id_of_node] is a list
                    of ob positions
                @self._node_info['output_list']
            3. information about the nodes
                @self._node_info['node_type_dict']:
                    self._node_info['node_type_dict']['body'] is a list
                    of node ids
                @self._node_info['num_nodes']
            4. information about the edges
                @self._node_info['edge_type_list']: the list of edge ids
                @self._node_info['num_edges']
                @self._node_info['num_edge_type']
            5. information about the indices
                @self._node_info['node_in_graph_list']:
                    the order of the nodes when placed by type
                    ('joint', 'body')
                @self._node_info['inverse_node_list']:
                    the inverse of 'node_in_graph_list'
                @self._node_info['receive_idx']
                @self._node_info['receive_idx_raw']
                @self._node_info['send_idx']
            6. information about the embedding size and the ob size
                @self._node_info['para_size_dict']
                @self._node_info['ob_size_dict'], e.g.
                    self._node_info['ob_size_dict']['root'] = 10
                    self._node_info['ob_size_dict']['joint'] = 6
    '''
    # step 0: obtain the raw node info, either generated from the adjacency
    # matrix and node attributes ('evo' tasks) or parsed from the mujoco xml
    is_evo_task = 'evo' in self.args.task
    gnn_options = dict(
        gnn_node_option=self._gnn_node_option,
        root_connection_option=self._root_connection_option,
        gnn_output_option=self._gnn_output_option,
        gnn_embedding_option=self._gnn_embedding_option
    )
    if not is_evo_task:
        self._node_info = mujoco_parser.parse_mujoco_graph(
            self._task_name, **gnn_options
        )
    else:
        self._node_info = gen_gnn_param.gen_gnn_param(
            self._task_name, self.adj_matrix, self.node_attr, **gnn_options
        )

    # step 1: make sure the input and output sizes match the node info
    gnn_util.io_size_check(
        self._input_size, self._output_size,
        self._node_info, self._is_baseline
    )

    # step 2: construct the ob size dict for each node type
    self._node_info = gnn_util.construct_ob_size_dict(
        self._node_info, self._input_feat_dim
    )

    # step 3 & 4: the inverse offsets, first for the nodes (used to
    # construct the gather idx), then for the outputs (used to gather the
    # output idx)
    for offset_kind in ('node', 'output'):
        self._node_info = gnn_util.get_inverse_type_offset(
            self._node_info, offset_kind
        )

    # step 5: register the existing edges, get the receive and send index
    self._node_info = gnn_util.get_receive_send_idx(self._node_info)