def test_copying_a_component(self):
    """
    Copies a Flatten Component and checks that the original and its copy work side by side
    inside one container, each flattening a 2x2 FloatBox to shape (4,).
    """
    space = FloatBox(shape=(2, 2), add_batch_rank=False)
    original = Flatten()
    duplicate = original.copy(scope="flatten-copy")

    component_to_test = Component(
        original, duplicate,
        inputs=["input1", "input2"],
        outputs=["output1", "output2"],
        connections=[
            ["input1", ["flatten", "input"]],
            ["input2", ["flatten-copy", "input"]],
            [["flatten", "output"], "output1"],
            [["flatten-copy", "output"], "output2"]
        ]
    )
    test = ComponentTest(component=component_to_test, input_spaces=dict(input1=space, input2=space))

    feed = dict(
        input1=np.array([[0.5, 2.0], [1.0, 2.0]]),
        input2=np.array([[1.0, 2.0], [3.0, 4.0]])
    )
    # Each out-Socket holds the row-major flattening of its matching input.
    expected = dict(
        output1=np.array([0.5, 2.0, 1.0, 2.0]),
        output2=np.array([1.0, 2.0, 3.0, 4.0])
    )
    for out_name in ("output1", "output2"):
        test.test(out_socket_names=out_name, inputs=feed, expected_outputs=expected[out_name])
def test_1to1_to_2to1_component_with_constant_input_value(self):
    """
    Chains a 1-to-1 and a 2-to-1 Component inside the core and feeds one of the 2-to-1's
    inputs with a constant value, so that the constant sits inside the core rather than
    at its border.
    """
    core = Component(scope="container")
    comp_a = Dummy1to1(scope="A")
    comp_b = Dummy2to1(scope="B")
    core.add_component(comp_a, connections=CONNECT_INS)
    core.add_component(comp_b, connections=CONNECT_OUTS)
    # Constant 1.1 -> B/input1; A's output -> B/input2.
    core.connect(1.1, (comp_b, "input1"))
    core.connect((comp_a, "output"), (comp_b, "input2"))

    test = ComponentTest(component=core, input_spaces=dict(input=float))

    # Expected output: (input + 1.0) + 1.1
    for value, result in ((78.4, 80.5), (-5.2, -3.1)):
        test.test(out_socket_names="output", inputs=value, expected_outputs=result)
def test_sync_socket(self):
    """
    Syncs the variables of one MyCompWithVars into another one (which holds a Synchronizable)
    and checks the target's variables before and after the sync.
    """
    # Source of the values: only pushes its variables out.
    sync_from = MyCompWithVars(scope="sync-from")
    # Target: starts with different initial values and receives the Synchronizable.
    sync_to = MyCompWithVars(initializer1=8.0, initializer2=7.0, scope="sync-to")
    sync_to.add_component(Synchronizable(), connections=CONNECT_ALL)

    # Container holding both Components (only `sync_to` contains a Synchronizable).
    component_to_test = Component(name="dummy-comp")
    component_to_test.define_outputs("do_the_sync")
    component_to_test.add_components(sync_from, sync_to)

    # Wire sync_from's variables into sync_to's values and expose the sync op.
    component_to_test.connect((sync_from, "_variables"), (sync_to, "_values"))
    component_to_test.connect((sync_to, "sync"), "do_the_sync")

    test = ComponentTest(component=component_to_test)
    shape = sync_from.space.shape
    first_name = "sync-to/" + VARIABLE_NAMES[0]
    second_name = "sync-to/" + VARIABLE_NAMES[1]

    # Before the sync: sync_to still holds its own initial values.
    test.variable_test(
        sync_to.get_variables(VARIABLE_NAMES),
        {
            first_name: np.full(shape=shape, fill_value=8.0),
            second_name: np.full(shape=shape, fill_value=7.0)
        }
    )

    # Trigger the sync (no inputs, no output value expected).
    test.test(out_socket_names="do_the_sync", inputs=None, expected_outputs=None)

    # After the sync: sync_to is expected to hold 0.0 / 1.0
    # (presumably sync_from's default initializers -- confirm in MyCompWithVars).
    test.variable_test(
        sync_to.get_variables(VARIABLE_NAMES),
        {
            first_name: np.zeros(shape=shape),
            second_name: np.ones(shape=shape)
        }
    )
def __init__(self, name="model", action_space=None, summary_spec=None):
    """
    Args:
        name (str): The name of this GraphBuilder; also used as the name of the meta-graph's
            core Component.
        action_space (Optional[Space]): The action Space information to be passed into calls
            to each Components' `when_input_complete` methods.
        summary_spec (Optional[dict]): A specification dict that defines which summaries we
            would like to create in the graph and register with each Component. It is
            normalized through `parse_summary_spec` before being stored.
    """
    self.logger = logging.getLogger(__name__)
    # The name of this model. The core Component created below gets this same name.
    self.name = name
    self.action_space = action_space
    self.summary_spec = parse_summary_spec(summary_spec)

    # All components assigned to each device, for debugging and analysis.
    self.device_component_assignments = {}
    # Devices usable for the graph and the default device (both unset here).
    self.available_devices = None
    self.default_device = None

    # Counter for recursive steps.
    self.build_steps = 0

    # Empty core Component into which everything will be assembled by an Algo.
    self.core_component = Component(name=self.name, is_core=True)

    # Lookup of all combinations that are possible for a given set of in-Socket names
    # (inside a call to `self.call`).
    self.input_combinations = {}

    # Registries needed in order to build the Graph from core:
    # key=DataOpRecord; value=set of required DataOpRecords OR leftmost in-Sockets to calculate the key's op.
    self.op_record_registry = {}
    # Only for core's in-Sockets.
    # key=in-Socket name; value=DataOp (e.g. tf.placeholder) that goes into this socket.
    self.in_socket_registry = {}
    # key=out-Socket name; value=set of in-Socket names needed to calculate the out-Socket's op output.
    self.out_socket_registry = {}
    # Maps an out-Socket name+in-Socket/Space-combination to an actual DataOp to fetch from our Graph.
    # NOTE(review): key presumably (out-Socket name, tuple of in-Socket names, shape-combination),
    # value the DataOp to fetch -- confirm against the code that fills this registry.
    self.call_registry = {}
def test_connecting_1to2_to_2to1(self):
    """
    Chains a Component with a 1-to-2 graph_fn and one with a 2-to-1 graph_fn inside the
    core, connects them and pushes a value through both.
    """
    core = Component(scope="container")
    splitter = Dummy1to2(scope="comp1")  # outs = in, in + 1
    merger = Dummy2to1(scope="comp2")  # out = in1 + in2
    core.add_component(splitter, connections=CONNECT_INS)
    core.add_component(merger, connections=CONNECT_OUTS)
    core.connect(splitter, merger)

    test = ComponentTest(component=core, input_spaces=dict(input=float))

    # Expected output: input + (input + 1.0)
    cases = [(100.9, np.float32(202.8)), (-5.1, np.float32(-9.2))]
    for value, result in cases:
        test.test(out_socket_names="output", inputs=value, expected_outputs=result)
def test_connecting_two_1to1_components(self):
    """
    Chains two Components with 1-to-1 graph_fns inside the core, connects them and pushes
    a value through both.
    """
    core = Component(scope="container")
    first = Dummy1to1(scope="comp1")
    second = Dummy1to1(scope="comp2")
    core.add_component(first, connections=CONNECT_INS)
    core.add_component(second, connections=CONNECT_OUTS)
    core.connect(first, second)

    test = ComponentTest(component=core, input_spaces=dict(input=float))

    # Expected output: input + 1.0 + 1.0
    for value, result in [(1.1, 3.1), (-5.1, -3.1)]:
        test.test(out_socket_names="output", inputs=value, expected_outputs=result)
def test_connecting_in1_and_1to1_to_1to1_no_labels(self):
    """
    Adds 4 sub-components to the container: A, B, C (1-to-1 graph_fns) and D (2-to-1 graph_fn),
    wired on op-level as follows:
    in1 -> A (like preprocessor in DQN)
    in2 -> A
    A -> B (like policy in DQN)
    A -> C (like target policy in DQN)
    B -> Din1 (like loss func in DQN: q_vals_s)
    C -> Din2 (q_vals_sp)
    """
    container = Component(inputs=["input1", "input2"], outputs=["output"], scope="container")
    a = Dummy1to1(scope="A")
    b = Dummy1to1(scope="B")
    c = Dummy1to1(scope="C")
    d = Dummy2to1(scope="D")

    # Throw in the sub-components.
    container.add_components(a, b, c, d)

    # Connect them on detailed op-level (see docstring for connection details).
    in1_through_a = a("input1")  # send input1 through Component "a"
    in2_through_a = a("input2")  # same with input2

    # "Manually" split the 2 ops coming out of a via b and c.
    b_out = b(in1_through_a)
    c_out = c(in2_through_a)

    # Merge b_out and c_out again into D (in1 and in2 Sockets).
    final = d(b_out, c_out)
    container.define_outputs("output", final)

    test = ComponentTest(component=container, input_spaces=dict(input1=float, input2=float))

    # Push both inputs through the graph to receive the correct (single-op) output calculation.
    # Each Dummy1to1 adds 1.0 (default constant) and Dummy2to1 sums its two inputs, so:
    # ((1.1 + 1.0) + 1.0) + ((0.5 + 1.0) + 1.0) = 3.1 + 2.5 = 5.6
    test.test(out_socket_names="output", inputs=dict(input1=np.array(1.1), input2=np.array(0.5)),
              expected_outputs=5.6)
def test_in_to_1to1_to_out_sock_then_to_other_1to1_then_to_other_out_sock(self):
    """
    Adds two 1-to-1 Components (A and B) to the core. A is connected from "input" to
    "output", while B is connected FROM(!) the core's "output" to "output_b".
    Connection schemas like this occur e.g. in our one-step optimizers.
    """
    core = Component(scope="container")
    core.define_outputs("output_b")
    comp_a = Dummy1to1(scope="A")
    comp_b = Dummy1to1(scope="B", constant_value=1.5)
    # CONNECT_ALL creates the core's "input" and "output" sockets for A.
    core.add_component(comp_a, connections=CONNECT_ALL)
    core.add_component(comp_b)
    core.connect("output", (comp_b, "input"))
    core.connect((comp_b, "output"), "output_b")

    test = ComponentTest(component=core, input_spaces=dict(input=float))

    # "output"   = in + 1.0
    # "output_b" = (in + 1.0) + 1.5
    checks = (
        ("output", 100.4, 101.4),
        ("output_b", -56.2, -53.7),
    )
    for socket_name, value, result in checks:
        test.test(out_socket_names=socket_name, inputs=value,
                  expected_outputs=np.array(result, dtype=np.float32))
def test_sync_socket_between_2_identical_comps_that_have_vars_only_in_their_sub_comps(self):
    """
    Similar to the Policy scenario, where the Policy Component owns a NeuralNetwork (which has
    vars) and has to be synced with other Policies.
    """
    # Build 2x: a custom Component (with vars) that holds another Component (with vars).
    # Then sync from the first into the second.
    comp1 = MyCompWithVars(scope="comp1")
    comp1.add_component(MyCompWithVars(scope="sub-comp1-with-vars"), connections=CONNECT_ALL)

    comp2_writable = MyCompWithVars(initializer1=3.0, initializer2=4.2, scope="comp2")
    sub_comp2 = MyCompWithVars(initializer1=5.0, initializer2=6.2, scope="sub-comp2-with-vars")
    comp2_writable.add_components(sub_comp2, Synchronizable(), connections=CONNECT_ALL)

    container = Component(comp1, comp2_writable, scope="container")
    container.define_outputs("do_the_sync")
    container.connect((comp1, "_variables"), (comp2_writable, "_values"))
    container.connect((comp2_writable, "sync"), (container, "do_the_sync"))

    test = ComponentTest(component=container)

    var_names = [
        "container/comp2/variable_to_sync1",
        "container/comp2/variable_to_sync2",
        "container/comp2/sub-comp2-with-vars/variable_to_sync1",
        "container/comp2/sub-comp2-with-vars/variable_to_sync2"
    ]

    def expected_values(*fill_values):
        # Map each variable name (in order) to a float32 array of comp1's Space shape,
        # filled with the matching value.
        shape = comp1.space.shape
        return {
            var_name: np.full(shape=shape, fill_value=fill, dtype=np.float32)
            for var_name, fill in zip(var_names, fill_values)
        }

    # Before the sync: comp2 and its sub-comp still hold their own initial values.
    test.variable_test(comp2_writable.get_variables(var_names), expected_values(3.0, 4.2, 5.0, 6.2))

    # Now sync and re-check.
    test.test(out_socket_names="do_the_sync", inputs=None, expected_outputs=None)

    # After the sync: 0.0 / 1.0 everywhere (presumably comp1's default initializers -- confirm).
    test.variable_test(comp2_writable.get_variables(var_names), expected_values(0.0, 1.0, 0.0, 1.0))
def test_exploration_with_discrete_action_space(self):
    # 2x2 action-pick, each composite action with 5 categories.
    action_space = IntBox(5, shape=(2, 2), add_batch_rank=True)
    # The distribution that feeds into the Exploration object.
    distribution = Categorical()
    action_adapter = ActionAdapter()
    # 13: Any flat nn-output should be ok.
    nn_output_space = FloatBox(shape=(13,), add_batch_rank=True)
    # NOTE(review): epsilon presumably decays linearly from 1.0 to 0.1 over timesteps 0..10000.
    epsilon_spec = dict(decay='linear_decay', from_=1.0, to_=0.1, start_timestep=0, num_timesteps=10000)
    exploration = Exploration.from_spec(dict(epsilon_spec=epsilon_spec))

    # The Component to test.
    component_to_test = Component(scope="categorical-plus-exploration")
    component_to_test.define_inputs("nn_output", "time_step")
    component_to_test.define_outputs("action")
    component_to_test.add_components(action_adapter, distribution, exploration)

    wiring = [
        ("nn_output", [action_adapter, "nn_output"]),
        ([action_adapter, "parameters"], [distribution, "parameters"]),
        ([distribution, "sample_deterministic"], [exploration, "sample_deterministic"]),
        ([distribution, "sample_stochastic"], [exploration, "sample_stochastic"]),
        ("time_step", [exploration, "time_step"]),
        ([exploration, "action"], "action"),
    ]
    for source, target in wiring:
        component_to_test.connect(source, target)

    test = ComponentTest(component=component_to_test,
                         input_spaces=dict(nn_output=nn_output_space, time_step=int),
                         action_space=action_space)

    # Fake output from the last NN layer: a batch of 2 rows, each 13 values wide.
    fake_nn_output = np.array([
        [100.0, 50.0, 25.0, 12.5, 6.25, 200.0, 100.0, 50.0, 25.0, 12.5, 1.0, 1.0, 25.0],
        [123.4, 34.7, 98.2, 1.2, 120.0, 200.0, 200.0, 0.00009, 10.0, 300.0, 0.567, 0.678, 0.789]
    ])
    inputs = dict(nn_output=fake_nn_output, time_step=10000)
    expected = np.array([[[3, 1], [3, 2]], [[1, 1], [3, 2]]])
    test.test(out_socket_names="action", inputs=inputs, expected_outputs=expected)
def test_exploration_with_continuous_action_space(self):
    # 2x2 continuous action-pick (FloatBox).
    action_space = FloatBox(shape=(2, 2), add_batch_rank=True)
    # The distribution that feeds into the Exploration object.
    distribution = Normal()
    action_adapter = ActionAdapter()
    # 13: Any flat nn-output should be ok.
    nn_output_space = FloatBox(shape=(13,), add_batch_rank=True)
    noise_spec = dict(type='gaussian_noise', mean=10.0, sd=2.0)
    exploration = Exploration.from_spec(dict(noise_spec=noise_spec))

    # The Component to test.
    component_to_test = Component(scope="continuous-plus-noise")
    component_to_test.define_inputs("nn_output")
    component_to_test.define_outputs("action", "noise")
    component_to_test.add_components(action_adapter, distribution, exploration)

    # No "time_step" connection: currently no noise component uses it.
    wiring = [
        ("nn_output", [action_adapter, "nn_output"]),
        ([action_adapter, "parameters"], [distribution, "parameters"]),
        ([distribution, "sample_deterministic"], [exploration, "sample_deterministic"]),
        ([distribution, "sample_stochastic"], [exploration, "sample_stochastic"]),
        ([exploration, "action"], "action"),
        ([exploration, "noise"], "noise"),
    ]
    for source, target in wiring:
        component_to_test.connect(source, target)

    test = ComponentTest(component=component_to_test, input_spaces=dict(nn_output=nn_output_space),
                         action_space=action_space)

    # Collect the noise outputs of many runs to compare their moments with the noise spec.
    collected = []

    def collect_outs(component_test, outs):
        collected.append(outs)

    for _ in range_(1000):
        test.test(out_socket_names="noise", fn_test=collect_outs)

    # NOTE(review): places=1 tolerates only ~0.05 of deviation; depending on the size of each
    # noise output, 1000 draws with sd=2.0 may be close to that tolerance -- check for flakiness.
    self.assertAlmostEqual(10.0, np.mean(collected), places=1)
    self.assertAlmostEqual(2.0, np.std(collected), places=1)
class GraphBuilder(Specifiable): """ The graph builder assembles the YARL meta-graph by tracing through components, sockets and connections and creating the underlying computation graph. """ # Break graph building if caught in a loop. MAX_RECURSIVE_CALLS = 100 def __init__(self, name="model", action_space=None, summary_spec=None): """ Args: name (str): The name of this GraphBuilder and of the meta-graph's core component. action_space (Optional[Space]): The action Space information to be passed into calls to each Components' `when_input_complete` methods. summary_spec (Optional[dict]): A specification dict that defines, which summaries we would like to create in the graph and register with each Component. """ # The name of this model. Our core Component gets this name. self.logger = logging.getLogger(__name__) self.name = name self.action_space = action_space self.summary_spec = parse_summary_spec(summary_spec) # All components assigned to each device, for debugging and analysis. self.device_component_assignments = dict() self.available_devices = None self.default_device = None # Counting recursive steps. self.build_steps = 0 # Create an empty core Component into which everything will be assembled by an Algo. self.core_component = Component(name=self.name, is_core=True) # A dict used for lookup of all combinations that are possible for a given set of given in-Socket # names (inside a call to `self.call`). self.input_combinations = dict() # Some registries that we need in order to build the Graph from core: # key=DataOpRecord; value=set of required DataOpRecords OR leftmost in-Sockets to calculate the key's op. self.op_record_registry = dict() # Only for core's in-Sockets. # key=in-Socket name; value=DataOp (e.g. tf.placeholder) that goes into this socket. self.in_socket_registry = dict() # key=out-Socket name; value=set of necessary in-Socket names that we need in order to calculate # the out-Socket's op output. 
self.out_socket_registry = dict() # Maps an out-Socket name+in-Socket/Space-combination to an actual DataOp to fetch from our Graph. self.call_registry = dict() # key=(FixMe: documentation) def build_graph_from_meta_graph(self, available_devices, default_device): """ Builds the actual backend-specific graph from the YARL metagraph. Loops through all our sub-components starting at core and assembles the graph by creating placeholders, following Socket->Socket connections and running through our GraphFunctions to generate DataOps. Args: available_devices (list): Devices which can be used to assign parts of the graph during graph assembly. default_device (str): Default device identifier. """ # Before we start, sanity check the meta graph for obvious flaws. self.sanity_check_meta_graph() # Set devices usable for this graph. self.available_devices = available_devices self.default_device = default_device # Actually build the graph. # Push all spaces to in-Sockets, then call build_component(core) for in_sock in self.core_component.input_sockets: # type: Socket # Skip sockets already connected to constant values. if len(in_sock.op_records) == 0: self.push_space_into_socket(in_sock) space_dict = self.core_component.check_input_completeness() self.build_component(self.core_component, space_dict) # Check whether all our components and graph_fns are now input-complete. self.sanity_check_build() # Memoize possible input combinations for out-Socket `execution` calls. self.memoize_inputs() # Registers actual ops with the different out-Sockets, so we know, which ops to execute for a given # out-Socket/input-feed-data combination. self.register_ops() def sanity_check_meta_graph(self, component=None): """ Checks whether all the `component`'s and its sub-components' in-Sockets are simply connected in the meta-graph and raises detailed error messages if not. A connection to an in-Socket is ok if ... 
a) it's coming from another Socket or b) it's coming from a Space object Args: component (Component): The Component to analyze for incoming connections. """ component = component or self.core_component if self.logger.level <= logging.INFO: component_print_out(component) # Check all the Component's in-Sockets for being connected from a Space/Socket. for in_sock in component.input_sockets: # type: Socket if len(in_sock.incoming_connections) == 0 and \ in_sock.name not in component.unconnected_sockets_in_meta_graph: raise YARLError( "Component '{}' has in-Socket ({}) without any incoming connections! If this is " "intended before the build process, you have to add the Socket's name to the " "Component's `unconnected_sockets_in_meta_graph` set. Then this error will be " "suppressed for this Component.".format( component.name, in_sock.name)) # Check all the component's graph_fns for input-completeness. for graph_fn in component.graph_fns: # type: GraphFunction for in_sock_rec in graph_fn.input_sockets.values(): in_sock = in_sock_rec["socket"] if len(in_sock.incoming_connections) == 0 and \ in_sock.name not in component.unconnected_sockets_in_meta_graph: raise YARLError( "GraphFn {}/{} has in-Socket ({}) without any incoming " "connections!".format(component.name, graph_fn.name, in_sock_rec["socket"].name)) # Recursively call this method on all the sub-component's sub-components. for sub_component in component.sub_components.values(): self.build_steps += 1 if self.build_steps >= self.MAX_RECURSIVE_CALLS: raise YARLError( "Error sanity checking graph, reached max recursion steps: {}" .format(self.MAX_RECURSIVE_CALLS)) self.sanity_check_meta_graph(sub_component) def build_component(self, component, input_spaces): """ Called when a Component has all its incoming Spaces known. Only then can we sanity check the input Spaces and create the Component's variables. Args: component (Component): The Component that now has all its input Spaces defined. 
input_spaces (dict): A dict mapping all in-Socket names of `component` to a Space object. """ assert component.input_complete is True, "ERROR: Component {} is not input complete!".format( component.name) self.logger.debug( "Component {} is input-complete; space-dict={}".format( component.name, input_spaces)) # Component is complete now, allow it to sanity check its inputs and create its variables. component.when_input_complete(input_spaces, self.action_space, self.summary_spec["summaries_regexp"]) # Push forward no-input graph_fns. for graph_fn in component.no_input_graph_fns: self.push_from_graph_fn(graph_fn) # Build all sub-components that have no inputs. for sub_component in component.no_input_sub_components: # Assert input-completeness. input_spaces should be empty. input_spaces = sub_component.check_input_completeness() self.build_component(sub_component, input_spaces) # Loop through all in-Sockets' outgoing connections and push Spaces from them. for in_socket in component.input_sockets: # type: Socket # Push this Socket's information further down. self.push_from_socket(in_socket) # At the very end, build our _variables out-Socket from the special "_variables" graph_fn. variables_graph_fn = [ gf for gf in component.graph_fns if gf.name == "_variables" ][0] self.push_from_graph_fn(variables_graph_fn) def push_from_socket(self, socket): # Skip this Socket, if it doesn't have a Space (no incoming connection). # Assert that it's ok for the component to leave this Socket open. if socket.space is None: assert socket.name in socket.component.unconnected_sockets_in_meta_graph return for outgoing in socket.outgoing_connections: # Push Socket into Socket. if isinstance(outgoing, Socket): print("SOCK {}/{} -> {}/{}".format(socket.component.name, socket.name, outgoing.component.name, outgoing.name)) self.push_socket_into_socket(socket, outgoing) # Push Socket into GraphFunction. elif isinstance(outgoing, GraphFunction): self.push_from_graph_fn(outgoing) # Error. 
else: raise YARLError("ERROR: Outgoing connection ({}) must be of type Socket or GraphFunction!".\ format(outgoing)) def push_space_into_socket(self, socket): """ Stores Space information for a Socket. The Socket must be one of the core Component's in-Socket with the Space already connected to it in `socket.incoming_connections`. Args: socket (Socket): The Socket to receive the Space's information. """ assert socket.component.is_core,\ "ERROR: Can only push a Space into a core's in-Socket (Socket={})!".format(socket) assert socket.space is None, \ "ERROR: Can only push Space into a Socket ({}) that does not have one yet!".format(socket) assert len(socket.incoming_connections) == 1, \ "ERROR: Socket '{}' already has an incoming connection. Cannot add Space to it.".format(socket) # Store the Space as this Socket's. space = socket.incoming_connections[0] self.logger.debug("Space {} -> Socket {}/{}".format( space, socket.component.name, socket.name)) socket.space = space # Create the placeholder and wrap it in a DataOpRecord with no labels. op = space.get_tensor_variable(name=socket.name, is_input_feed=True) op_rec = DataOpRecord(op) socket.op_records.add(op_rec) # Keep track of which input op (e.g. tf.placeholder) goes into this Socket. self.in_socket_registry[socket.name] = op # Remember, that this DataOp goes into a Socket at the very beginning of the Graph (e.g. a # tf.placeholder). self.op_record_registry[op_rec] = {socket} def push_socket_into_socket(self, socket, next_socket): assert socket.space is not None assert next_socket.space is None or socket.space == next_socket.space,\ "ERROR: Socket '{}' already has Space '{}', but incoming connection '{}' has Space '{}'! 
" \ "Incoming Spaces must always be the same.".format(next_socket, next_socket.space, socket, socket.space) was_input_complete = next_socket.component.input_complete self.logger.debug("Socket {}/{} -> Socket {}/{}".format( socket.component.name, socket.name, next_socket.component.name, next_socket.name)) next_socket.space = socket.space # Make sure we filter those op-records that already have at least one label and that do not # have the label of this connection (from `incoming`). socket_labels = socket.labels.get(next_socket, None) # type: set op_records = socket.op_records # With filtering. if socket_labels is not None: filtered_op_records = set() for op_rec in op_records: # type: DataOpRecord # If incoming op has no labels OR it has at least 1 label out of this Socket's # labels for this connection -> Allow op through to this Socket. if len(op_rec.labels) == 0 or len( set.intersection(op_rec.labels, socket_labels)): op_rec.labels.update(socket_labels) filtered_op_records.add(op_rec) op_records = filtered_op_records next_socket.op_records.update(op_records) # Continue with the build logic. self.after_socket_update(next_socket, was_input_complete) def after_socket_update(self, socket, was_input_complete): # The Component of the Socket has already been input-complete. Keep pushing the Socket forward. if was_input_complete is True: self.push_from_socket(socket) else: # Check again for input-completeness. space_dict = socket.component.check_input_completeness() # Component has just become input-complete: Build it. if socket.component.input_complete: self.build_component(socket.component, space_dict) def push_from_graph_fn(self, graph_fn): """ Builds outgoing graph function ops using `socket`'s component's device or the GraphBuilder's default one. Args: graph_fn (GraphFunction): Graph function object to build output ops for. """ # Check for input-completeness of this graph_fn. 
if graph_fn.check_input_completeness(): # We have to specify the device and the variable scope here as we will be running through a # GraphFunction, which may add ops to the graph. assigned_device = graph_fn.component.device or self.default_device self.run_through_graph_fn_with_device_and_scope( graph_fn, assigned_device) # Store assigned names for debugging. if assigned_device not in self.device_component_assignments: self.device_component_assignments[assigned_device] = [ str(graph_fn) ] else: self.device_component_assignments[assigned_device].append( str(graph_fn)) # Keep moving through this graph_fn's out-Sockets (if input-complete). if graph_fn.input_complete: for slot, out_socket in enumerate(graph_fn.output_sockets): self.push_from_socket(out_socket) #, graph_fn, slot) def run_through_graph_fn_with_device_and_scope(self, graph_fn, assigned_device): """ Assigns device to the ops generated by a graph_fn. Args: graph_fn (GraphFunction): GraphFunction to assign device to. assigned_device (str): Device identifier. """ if get_backend() == "tf": if assigned_device not in self.available_devices: self.logger.error( "Assigned device {} for graph_fn {} not in available devices:\n {}" .format(assigned_device, graph_fn, self.available_devices)) # Assign proper device to all ops created in this context manager. with tf.device(assigned_device): # Name ops correctly according to our Component hierarchy. with tf.name_scope(graph_fn.component.global_scope + ( '/' if graph_fn.component.global_scope else "")): self.logger.debug( "Assigning device {} to graph_fn {} (scope {}).". format(assigned_device, graph_fn, graph_fn.component.global_scope)) self.run_through_graph_fn(graph_fn) def run_through_graph_fn(self, graph_fn): """ Pushes all incoming ops through the method of this GraphFunction object. The ops are collected from incoming Sockets and optionally flattened and/or split before pushing them through the method and the return values optionally unflattened. 
Args: graph_fn (GraphFunction): The GraphFunction object to run through (its method) with all possible in-Socket combinations (only those that have not run yet through the method). """ in_op_records = [ in_sock_rec["socket"].op_records for in_sock_rec in graph_fn.input_sockets.values() ] in_op_records_combinations = list(itertools.product(*in_op_records)) for in_op_record_combination in in_op_records_combinations: # Make sure we call the computation method only once per input-op combination. if in_op_record_combination in graph_fn.in_out_records_map: continue # Replace constant-value Sockets with their SingleDataOp's constant numpy values # and the DataOps with their actual ops (`op` property of DataOp). actual_call_params = [ op_rec.op.constant_value if isinstance(op_rec.op, SingleDataOp) and op_rec.op.constant_value is not None else op_rec.op for op_rec in in_op_record_combination ] # Build the ops from this input-combination. # Flatten input items. if graph_fn.flatten_ops is not False: flattened_ops = graph_fn.flatten_input_ops(*actual_call_params) # Split into SingleDataOps? if graph_fn.split_ops: call_params = split_flattened_input_ops( graph_fn.add_auto_key_as_first_param, *flattened_ops) # There is some splitting to do. Call graph_fn many times (one for each split). if isinstance(call_params, FlattenedDataOp): ops = dict() num_return_values = -1 for key, params in call_params.items(): ops[key] = force_tuple(graph_fn.method(*params)) if num_return_values >= 0 and num_return_values != len( ops[key]): raise YARLError( "Different split-runs through {} do not return the same number of " "values!".format(graph_fn.name)) num_return_values = len(ops[key]) # Un-split the results dict into a tuple of `num_return_values` slots. 
un_split_ops = list() for i in range(num_return_values): dict_with_singles = FlattenedDataOp() for key in call_params.keys(): dict_with_singles[key] = ops[key][i] un_split_ops.append(dict_with_singles) ops = tuple(un_split_ops) # No splitting to do: Pass everything as-is. else: ops = graph_fn.method(*call_params) else: ops = graph_fn.method(*flattened_ops) # Just pass in everything as-is. else: ops = graph_fn.method(*actual_call_params) # OBSOLETE: always must un-flatten all return values. Otherwise, we would allow Dict Spaces # with '/' keys in them, which is not allowed. #if graph_fn.unflatten_ops: ops = graph_fn.unflatten_output_ops(*force_tuple(ops)) # Make sure everything coming from a computation is always a tuple (for out-Socket indexing). ops = force_tuple(ops) # Make sure the number of returned ops matches the number of outgoing Sockets from thie graph_fn assert len(ops) == len(graph_fn.output_sockets),\ "ERROR: Number of returned values of graph_fn '{}/{}' ({}) does not match number of out-Sockets ({}) " \ "of this GraphFunction!".format(graph_fn.component.name, graph_fn.name, len(ops), len(graph_fn.output_sockets)) # ops are now the raw graph_fn output: Need to convert it back to records. new_label_set = set() for rec in in_op_record_combination: # type: DataOpRecord new_label_set.update(rec.labels) op_records = convert_ops_to_op_records(ops, labels=new_label_set) graph_fn.in_out_records_map[in_op_record_combination] = op_records # Move graph_fn results into next Socket(s). for i, (socket, op_rec) in enumerate( zip(graph_fn.output_sockets, op_records)): self.logger.debug( "GraphFn {}/{} -> return-slot {} -> {} -> Socket {}/{}". format(graph_fn.component.name, graph_fn.name, i, ops, socket.component.name, socket.name)) # Store op_rec in the respective outgoing Socket (and make sure Spaces match). 
space = get_space_from_op(op_rec.op) if len(socket.op_records) > 0: sanity_check_space = get_space_from_op( next(iter(socket.op_records)).op) assert space == sanity_check_space,\ "ERROR: Newly calculated output op of graph_fn '{}' has different Space than existing one " \ "({} vs {})!".format(graph_fn.name, space, sanity_check_space) else: socket.space = space socket.op_records.add(op_rec) self.op_record_registry[op_rec] = set(in_op_record_combination) # Make sure all op_records do not contain SingleDataOps with constant_values. Any # in-Socket-connected constant values need to be converted to actual ops during a graph_fn call. assert not isinstance(op_rec.op, SingleDataOp), \ "ERROR: graph_fn '{}' returned a SingleDataOp with constant_value set to '{}'! " \ "This is not allowed. All graph_fns must return actual (non-constant) ops.". \ format(graph_fn.name, op_rec.op.constant_value) def memoize_inputs(self): # Memoize possible input-combinations (from all our in-Sockets) # so we don't have to do this every time we get a call to `self.execute`. in_names = sorted( list(map(lambda s: s.name, self.core_component.input_sockets))) input_combinations = all_combinations(in_names, descending_length=True) # Store each combination and its sub-combinations in self.input_combinations. for input_combination in input_combinations: self.input_combinations[tuple(input_combination)] = \ all_combinations(input_combination, descending_length=True) def register_ops(self): # Now use the ready op/socket registries to determine for which out-Socket we need which inputs. # Then we will be able to derive the correct op for any given (out-Socket+in-Socket+in-shape)-combination # passed into the call method. for output_socket in self.core_component.output_sockets: # type: Socket # Create empty out-sock registry entry. 
self.out_socket_registry[output_socket.name] = set() assert len(output_socket.op_records) > 0, "ERROR: There must at least be one op-record for out-Socket " \ "'{}'!".format(output_socket.name) # Loop through this Socket's set of possible ops. for op_rec in output_socket.op_records: # Get all the (core) in-Socket names (alphabetically sorted) that are required for this op. sockets = tuple( sorted(list(self.trace_back_sockets({op_rec})), key=lambda s: s.name)) # If an in-Socket has more than one connected incoming Space: # Get the shape-combinations for these Sockets. # e.g. Sockets=["a", "b"] (and Space1 -> a, Space2 -> a, Space3 -> b) # shape-combinations=[(Space1, Space3), (Space2, Space3)] shapes = [[ i.get_shape(with_batch_rank=True) for i in sock.incoming_connections ] for sock in sockets] shape_combinations = itertools.product(*shapes) for shape_combination in shape_combinations: # Do everything by Socket-name (easier to debug). in_socket_names = tuple([s.name for s in sockets]) # Update our call registry. key = (output_socket.name, in_socket_names, shape_combination) self.call_registry[key] = op_rec.op # .. and the out-socket registry. self.out_socket_registry[output_socket.name].update( set(in_socket_names)) def sanity_check_build(self, component=None): """ Checks whether all the `component`'s and sub-components's in-Sockets and graph_fns are input-complete and raises detailed error messages if not. Input-completeness means that .. a) all in-Sockets of a Component or b) all connected incoming Sockets to a GraphFunction .. have their `self.space` field defined (is not None). Args: component (Component): The Component to analyze for input-completeness. """ component = component or self.core_component if self.logger.level <= logging.INFO: component_print_out(component) # Check all the component's graph_fns for input-completeness. for graph_fn in component.graph_fns: if graph_fn.input_complete is False: # Look for the missing in-Socket and raise an Error. 
for in_sock_name, in_sock_record in graph_fn.input_sockets.items( ): if len(in_sock_record["socket"].op_records) == 0 and \ in_sock_name not in component.unconnected_sockets_in_meta_graph: raise YARLError( "in-Socket '{}' of GraphFunction '{}' of Component '{}' does not have " "any incoming ops!".format(in_sock_name, graph_fn.name, component.global_scope)) # Check component's sub-components for input-completeness (recursively). for sub_component in component.sub_components.values( ): # type: Component if sub_component.input_complete is False: # Look for the missing Socket and raise an Error. for in_sock in sub_component.input_sockets: if in_sock.space is None: raise YARLError("Component '{}' is not input-complete. In-Socket '{}' does not " \ "have any incoming connections.". format(sub_component.global_scope, in_sock.name)) # Recursively call this method on all the sub-component's sub-components. self.sanity_check_build(sub_component) def get_execution_inputs(self, output_socket_names, inputs=None): """ Fetches graph inputs for execution. Args: output_socket_names (Union[str,List[str]]): A name or a list of names of the out-Sockets to fetch from our core component. inputs (Optional[dict,data]): Dict specifying the provided inputs for some in-Sockets. Depending on these given inputs, the correct backend-ops can be selected within the given (out)-Sockets. Alternatively, can pass in data directly (not as a dict), but only if there is only one in-Socket in the Model or only one of the in-Sockets is needed for the given out-Sockets. Returns: tuple: fetch-dict, feed-dict with relevant args. 9 """ output_socket_names = force_list(output_socket_names) # Sanity check out-Socket names. for out_sock_name in output_socket_names: if out_sock_name not in self.out_socket_registry: raise YARLError( "ERROR: Out-Socket '{}' not found in Model! 
Make sure you are fetching by the \n" "correct out-Socket name.".format(out_sock_name)) only_input_socket_name = None # the name of the only in-Socket possible here # Some input is given. if inputs is not None: # Get only in-Socket .. if len(self.core_component.input_sockets) == 1: only_input_socket_name = self.core_component.input_sockets[ 0].name # .. or only in-Socket for single(!), given out-Socket. elif len(output_socket_names) == 1 and \ len(self.out_socket_registry[output_socket_names[0]]) == 1: only_input_socket_name = next( iter(self.out_socket_registry[output_socket_names[0]])) # Check whether data is given directly. if not isinstance(inputs, dict): if only_input_socket_name is None: raise YARLError( "ERROR: Input data (`inputs`) given directly (not as dict) AND more than one \n" "in-Socket in Model OR more than one in-Socket needed for given out-Sockets '{}'!" .format(output_socket_names)) inputs = {only_input_socket_name: inputs} # Is a dict: Check whether it's a in-Socket name dict (leave as is) or a # data dict (add in-Socket name as key). else: # We have more than one necessary in-Sockets (leave as is) OR # the only necessary in-Socket name is not key of the dict -> wrap it. if only_input_socket_name is not None and only_input_socket_name not in inputs: inputs = {only_input_socket_name: inputs} # Try all possible input combinations to see whether we got an op for that. # Input Socket names will be sorted alphabetically and combined from short sequences up to longer ones. # Example: inputs={A: ..., B: ... C: ...} # input_combinations=[ABC, AB, AC, BC, A, B, C] # These combinations have been memoized for fast lookup. key = tuple(sorted(inputs.keys())) input_combinations = self.input_combinations.get(key) if not input_combinations: raise YARLError( "ERROR: At least one of the given in-Socket names {} seems to be non-existent " "in Model!".format(key)) # No input given (maybe an out-Socket that doesn't require input). 
else: input_combinations = list(()) # Go through each (core) out-Socket names and collect the correct ops to go into the fetch_list. fetch_list = list() feed_dict = dict() for out_socket_name in output_socket_names: # Updates with relevant ops fetch_list, feed_dict = self._get_execution_inputs_for_socket( out_socket_name, input_combinations, fetch_list, inputs, feed_dict) return fetch_list, feed_dict def _get_execution_inputs_for_socket(self, socket_name, input_combinations, fetch_list, input_dict, feed_dict): """ Helper (to avoid nested for loop-break) for the loop in get_execution_inputs. Args: socket_name (str): The name of the (core) out-Socket to process. input_combinations (List[str]): The list of in-Socket (names) combinations starting with the combinations with the most Socket names, then going towards combinations with only one Socket name. Each combination in itself should already be sorted alphabetically on the in-Socket names. fetch_list (list): Appends to this list, which ops to actually fetch. input_dict (Optional[dict]): Dict specifying the provided inputs for some (core) in-Sockets. Passed through directly from the call method. feed_dict (dict): The feed_dict we are trying to build. When done, needs to map input ops (not Socket names) to data. Returns: tuple: fetch_list, feed-dict with relevant args. """ if len(input_combinations) > 0: # Check all (input+shape)-combinations and it we find one that matches what the user passed in as # `input_dict` -> Take that one and move on to the next Socket by returning. for input_combination in input_combinations: # Get all Space-combinations (in-op) for this input combination # (OBSOLETE: not possible anymore: in case an in-Socket has more than one connected incoming Spaces). ops = [self.in_socket_registry[c] for c in input_combination] # Get the shapes for this op_combination. 
shapes = tuple(get_shape(op) for op in ops) key = (socket_name, input_combination, shapes) # This is a good combination -> Use the looked up op, return to process next out-Socket. if key in self.call_registry: fetch_list.append(self.call_registry[key]) # Add items to feed_dict. for in_sock_name, in_op in zip(input_combination, ops): value = input_dict[in_sock_name] # Numpy'ize scalar values (tf doesn't sometimes like python primitives). if isinstance(value, (float, int, bool)): value = np.array(value) feed_dict[in_op] = value return fetch_list, feed_dict # No inputs -> Try whether this output socket comes without any inputs. else: key = (socket_name, (), ()) if key in self.call_registry: fetch_list.append(self.call_registry[key]) return fetch_list, feed_dict required_inputs = [ k[1] for k in self.call_registry.keys() if k[0] == socket_name ] raise YARLError( "ERROR: No op found for out-Socket '{}' given the input-combinations: {}! " "The following input-combinations are required for '{}':\n" "{}".format(socket_name, input_combinations, socket_name, required_inputs)) def trace_back_sockets(self, trace_set): """ For a set of given ops, returns a list of all (core) in-Sockets that are required to calculate these ops. Args: trace_set (Set[Union[DataOpRecords,Socket]]): The set of DataOpRecord/Socket objects to trace-back till the beginning of the Graph. Socket entries mean we have already reached the beginning of the Graph and these will no further be traced back. Returns: Set[Socket]: in-Socket objects (from the core Component) that are required to calculate the DataOps in `trace_set`. """ # Recursively lookup op in op_record_registry until we hit a Socket. new_trace_set = set() for op_rec_or_socket in trace_set: # We hit a Socket (we reached the beginning of the Graph). Stop tracing further back. 
if isinstance(op_rec_or_socket, Socket): if op_rec_or_socket.name not in self.in_socket_registry: raise YARLError( "ERROR: in-Socket '{}' could not be found in in_socket_registry of " "model!".format(op_rec_or_socket.name)) new_trace_set.add(op_rec_or_socket) # A DataOpRecord: Sanity check that we already have this. elif op_rec_or_socket not in self.op_record_registry: # Could be a DataOpRecord of a SingleDataOp with constant_value set. if not isinstance( op_rec_or_socket.op, SingleDataOp ) or op_rec_or_socket.op.constant_value is None: raise YARLError( "ERROR: DataOpRecord for op '{}' could not be found in op_record_registry of " "model!".format(op_rec_or_socket.op)) else: new_trace_set.update(self.op_record_registry[op_rec_or_socket]) if all([isinstance(i, Socket) for i in new_trace_set]): return new_trace_set else: return self.trace_back_sockets(new_trace_set) def get_default_model(self): """ Fetches the initially created default container. Returns: Component: The core container component. """ return self.core_component