def test_unused_field_can_be_none(
        self, use_edges, use_nodes, use_globals, none_field):
    """Verifies a NodeBlock still works when a field it does not read is None.

    The input graph is requested from `self._get_input_graph` with
    `none_field` in the list passed to it. The block's node output is then
    compared against the concatenation of the enabled inputs multiplied by
    `self._scale`.

    Args:
      use_edges: Whether the block consumes both received and sent edges.
      use_nodes: Whether the block consumes the nodes.
      use_globals: Whether the block consumes the globals.
      none_field: Name of the field handed to `self._get_input_graph`.
    """
    input_graph = self._get_input_graph([none_field])
    output_graph = blocks.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_edges,
        use_sent_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)(input_graph)

    # Rebuild by hand, in the same order, what the node model is fed.
    expected_pieces = []
    if use_edges:
        for aggregator_cls in (blocks.ReceivedEdgesToNodesAggregator,
                               blocks.SentEdgesToNodesAggregator):
            aggregate = aggregator_cls(tf.unsorted_segment_sum)
            expected_pieces.append(aggregate(input_graph))
    if use_nodes:
        expected_pieces.append(input_graph.nodes)
    if use_globals:
        expected_pieces.append(blocks.broadcast_globals_to_nodes(input_graph))
    concatenated_inputs = tf.concat(expected_pieces, axis=-1)

    # Fields other than the nodes must come through unchanged.
    self.assertEqual(input_graph.edges, output_graph.edges)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
        actual_nodes, concatenated_inputs_value = sess.run(
            (output_graph.nodes, concatenated_inputs))

    self.assertNDArrayNear(
        concatenated_inputs_value * self._scale, actual_nodes, err=1e-4)
def _build(self, graph):
    """Builds a SpringMassSimulator.

    Args:
      graph: A graphs.GraphsTuple with:
        - edges: one row per edge, holding [spring_constant, rest_length].
        - nodes: one row per node, holding [x, y, v_x, v_y, is_fixed].
        - globals: one row per graph, containing the gravitational constant.

    Returns:
      A graphs.GraphsTuple of the same shape as `graph`, but where:
        - edges: Holds the force [f_x, f_y] acting on each edge.
        - nodes: Holds positions and velocities after applying one step of
          Euler integration.
    """
    receivers = blocks.broadcast_receiver_nodes_to_edges(graph)
    senders = blocks.broadcast_sender_nodes_to_edges(graph)

    # Keep the trailing axis when slicing so each feature stays a column.
    spring_constant = graph.edges[..., 0:1]
    rest_length = graph.edges[..., 1:2]
    edge_forces = hookes_law(receivers, senders, spring_constant, rest_length)

    # The aggregator reads the edges, so the forces must be stored first.
    graph_with_forces = graph.replace(edges=edge_forces)
    node_forces = self._aggregator(graph_with_forces)
    gravity = blocks.broadcast_globals_to_nodes(graph_with_forces)

    integrated_nodes = euler_integration(
        graph_with_forces.nodes, node_forces + gravity, self._step_size)
    return graph_with_forces.replace(nodes=integrated_nodes)
def test_output_values(self,
                       use_received_edges,
                       use_sent_edges,
                       use_nodes,
                       use_globals,
                       received_edges_reducer,
                       sent_edges_reducer):
    """Checks NodeBlock node outputs against a hand-built reference.

    The reference concatenates, in order, the aggregated received edges, the
    aggregated sent edges, the nodes and the broadcast globals (each only
    when its flag is set) and multiplies the result by `self._scale`.

    Args:
      use_received_edges: Whether received edges feed the node model.
      use_sent_edges: Whether sent edges feed the node model.
      use_nodes: Whether nodes feed the node model.
      use_globals: Whether globals feed the node model.
      received_edges_reducer: Reducer used to aggregate received edges.
      sent_edges_reducer: Reducer used to aggregate sent edges.
    """
    input_graph = self._get_input_graph()
    output_graph = blocks.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals,
        received_edges_reducer=received_edges_reducer,
        sent_edges_reducer=sent_edges_reducer)(input_graph)

    # Ordered (enabled, builder) pairs; builders are lazy so a disabled
    # input is never constructed.
    input_builders = (
        (use_received_edges,
         lambda: blocks.ReceivedEdgesToNodesAggregator(
             received_edges_reducer)(input_graph)),
        (use_sent_edges,
         lambda: blocks.SentEdgesToNodesAggregator(
             sent_edges_reducer)(input_graph)),
        (use_nodes, lambda: input_graph.nodes),
        (use_globals,
         lambda: blocks.broadcast_globals_to_nodes(input_graph)),
    )
    expected_pieces = [build() for enabled, build in input_builders if enabled]
    concatenated_inputs = tf.concat(expected_pieces, axis=-1)

    # A NodeBlock must leave edges and globals untouched.
    self.assertEqual(input_graph.edges, output_graph.edges)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
        output_graph_value, concatenated_inputs_value = sess.run(
            (output_graph, concatenated_inputs))

    self.assertNDArrayNear(
        concatenated_inputs_value * self._scale,
        output_graph_value.nodes,
        err=1e-4)
def __init__(self, scope: str, model: GraphModel, reg_param) -> None:
    """Builds the TensorFlow graph for a graph-network value/policy model.

    Creates input placeholders, an encoder followed by two dense graph
    transforms, per-node outputs for a chosen action, and a masked softmax
    policy over nodes.

    Args:
      scope: Prefix used for every variable scope opened here and for the
        names of the metric-op collections.
      model: Description of the problem. This constructor calls its
        `get_global_output_size()` and `placeholder_graph()` methods and
        reads its `hidden_edge_dimension`, `hidden_node_dimension`,
        `hidden_global_dimension` and `action_dimension` attributes.
      reg_param: Passed as `scale` to `tf.contrib.layers.l2_regularizer`.
    """
    # Process parameters
    self.scope = scope
    self.model = model
    self.n_out = self.model.get_global_output_size()

    # Configure regularization: one shared L2 regularizer, exposed through
    # the keyword dictionaries that the sub-modules accept.
    self.regularizer = tf.contrib.layers.l2_regularizer(scale=reg_param)
    self.reg_linear = {"w": self.regularizer, "b": self.regularizer}
    # self.reg_embed = {}
    self.reg_embed = {"embeddings": self.regularizer}

    # Set up input tensors
    with tf.variable_scope(self.scope + "/state_input"):
        self.input_graphs = utils_tf.placeholders_from_data_dicts(
            [self.model.placeholder_graph()],
            force_dynamic_num_graphs=True,
            name="local_state")

    with tf.variable_scope(self.scope + "/ground_truth"):
        # Reinforcement learning inputs (all rank-1 with a dynamic length;
        # presumably one entry per graph in the batch -- TODO confirm).
        self.true_action = tf.placeholder(tf.int32,
                                          shape=(None, ),
                                          name="action")
        self.n_objects = tf.placeholder(tf.int32,
                                        shape=(None, ),
                                        name="n_objects")
        self.target_q = tf.placeholder(tf.float32,
                                       shape=(None, ),
                                       name="target_q")
        self.target_value = tf.placeholder(tf.float32,
                                           shape=(None, ),
                                           name="target_value")

    with tf.variable_scope(self.scope):
        # Separately embed different categorical variables as dense vectors
        self.encoder_module = GraphEncoder(model,
                                           name="encoder",
                                           regularizers=self.reg_embed)
        self.embedded_graphs = self.encoder_module(self.input_graphs)

        # Apply an intermediate transformation to pass information between
        # neighboring nodes
        self.intermediate_graphs = DenseGraphTransform(
            model.hidden_edge_dimension,
            model.hidden_node_dimension,
            model.hidden_global_dimension,
            name="intermediate",
            regularizer=self.regularizer)(self.embedded_graphs)

        # Then apply a final transformation to produce a global output and
        # node-level evaluations
        self.output_graphs = DenseGraphTransform(
            model.action_dimension,
            1,
            1,
            node_activation=None,
            global_activation=None,
            name="output",
            regularizer=self.regularizer)(self.intermediate_graphs)

    with tf.variable_scope(self.scope + "/outputs"):
        # If given a true action, get the corresponding output.
        # The exclusive cumulative sum of n_node gives, for each graph, the
        # offset of its first node in the concatenated node tensor.
        self.graph_indices = tf.math.cumsum(self.output_graphs.n_node,
                                            exclusive=True,
                                            name="starting_node_index")
        # true_action is added to that offset, i.e. it is used as a node
        # position within its own graph.
        self.true_indices = self.graph_indices + self.true_action
        # Flattened to rank 1 after the gather.
        self.chosen_node_outputs = tf.reshape(
            tf.gather(self.output_graphs.nodes,
                      self.true_indices,
                      name="chosen_action_outputs"), [-1])

        # In case we need a policy output, build the following tensors:
        # 1) a learned stochastic policy for all possible actions,
        # 2) the individual probability of the chosen action
        # 3) the log of that individual probability.

        # First, get each node's index
        node_indices = tf.range(tf.shape(self.output_graphs.nodes)[0])
        # Then, get the index of each graph's first action
        # (presumably each graph lists n_objects object nodes before its
        # action nodes -- TODO confirm against the graph builder).
        first_action_indices = self.graph_indices + self.n_objects
        # Broadcast action indices to nodes and compare to node indices.
        # The indices are temporarily stored as the graph's globals so the
        # globals-to-nodes broadcast can be reused for this.
        first_action_broadcast = blocks.broadcast_globals_to_nodes(
            self.output_graphs.replace(
                globals=tf.reshape(first_action_indices, [-1, 1])))
        action_mask = tf.greater_equal(
            node_indices, tf.reshape(first_action_broadcast, [-1]))

        # Zero out the objects and apply softmax to the actions (treat
        # action-nodes as logits)
        exp_or_zero = self.output_graphs.replace(nodes=tf.where(
            action_mask, tf.math.exp(self.output_graphs.nodes),
            tf.zeros_like(self.output_graphs.nodes)))

        # Sum the node values so that the global for each graph is the
        # softmax denominator. No node reducer is passed, so this relies on
        # GlobalBlock's default (presumably a segment sum -- TODO confirm).
        sum_nodes = blocks.GlobalBlock(lambda: tf.identity,
                                       use_edges=False,
                                       use_globals=False)
        softmax_graph = sum_nodes(exp_or_zero)

        # Then divide each node's value by that denominator, or set to 1
        # where denominator is 0
        def node_value_to_prob(node_inputs):
            """Divides column 0 by column 1, mapping non-positive results to 1.

            Presumably column 0 is the node's masked exponential and column 1
            the broadcast per-graph sum -- TODO confirm the NodeBlock input
            ordering.
            """
            p = tf.div_no_nan(node_inputs[:, 0], node_inputs[:, 1])
            # NOTE: `p > 0` is also false for a zero numerator, so masked
            # nodes are mapped to 1 as well, not only zero denominators.
            return tf.where(p > 0, p, tf.ones_like(p))

        policy_graph = blocks.NodeBlock(
            lambda: node_value_to_prob,
            use_received_edges=False,
            use_sent_edges=False)(softmax_graph)
        self.policy = policy_graph.nodes
        self.p_chosen = tf.gather(self.policy,
                                  self.true_indices,
                                  name="p_true_action")
        self.log_p_chosen = tf.log(self.p_chosen, name="logp_true_action")

    # Configure metrics for training and display
    self.TRAIN_METRIC_OPS = self.scope + "/TRAIN_METRIC_OPS"
    self.VAL_METRIC_OPS = self.scope + "/VAL_METRIC_OPS"
    # NOTE: the collection is read without a scope filter, so this sums every
    # regularization loss registered so far, not only those under self.scope.
    self.reg_term = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    tf.summary.scalar('reg_loss', self.reg_term)
# Recurrent application: feed the network's output back in as its next input
# for a fixed number of passes.
num_recurrent_passes = 3
previous_graphs = input_graphs
for unused_pass in range(num_recurrent_passes):
    previous_graphs = graph_network(previous_graphs)
    # Debug output: first entry of the node tensor.
    print(previous_graphs.nodes[0])
output_graphs = previous_graphs
tvars = graph_network.trainable_variables
print('')

# Broadcast examples: each call below yields a copy of `graphs_tuple` whose
# nodes (or edges) are replaced by the result of one broadcast op.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

# Aggregation example: rebuild the tuple from the same data dict, then reduce
# the edges into the globals with the chosen reducer.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
reducer = tf.math.unsorted_segment_sum
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))