def create_colocated_agents(agent_config, num_agents, max_attempts=10):
    """
    Creates a specified number of co-located RayAgent workers.

    Args:
        agent_config (dict): Agent spec for the worker agents.
        num_agents (int): Number of worker agents to create.
        max_attempts (Optional[int]): Max number of attempts to create the co-located agents.
            An error is raised if creation was not successful within this number of attempts.

    Returns:
        list: List of created agents.

    Raises:
        YARLError: If not enough agents could be created within the specified number of attempts.
    """
    agents = []
    attempt = 1
    while len(agents) < num_agents and attempt <= max_attempts:
        ray_agents = [RayAgent.remote(agent_config) for _ in range(attempt * num_agents)]
        local_agents, _ = split_local_non_local_agents(ray_agents)
        agents.extend(local_agents)
        attempt += 1

    if len(agents) < num_agents:
        raise YARLError("Could not create the specified number ({}) of agents.".format(num_agents))

    return agents[:num_agents]
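# Usage sketch (illustrative only): assumes a running/initialized Ray instance;
# `worker_spec` is a hypothetical agent spec whose exact fields depend on the
# RayAgent being used.
#
#     import ray
#     ray.init()
#     worker_spec = dict(type="random")
#     workers = create_colocated_agents(agent_config=worker_spec, num_agents=4)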
def translate_space(space):
    """
    Translates OpenAI Gym spaces into YARL Space classes.

    Args:
        space (gym.spaces.Space): The OpenAI Gym Space to be translated.

    Returns:
        Space: The translated YARL Space.
    """
    if isinstance(space, gym.spaces.Discrete):
        return IntBox(space.n)
    elif isinstance(space, gym.spaces.MultiBinary):
        return BoolBox(shape=(space.n,))
    elif isinstance(space, gym.spaces.MultiDiscrete):
        return IntBox(low=np.zeros((space.nvec.ndim,), dtype("uint8", "np")), high=space.nvec)
    elif isinstance(space, gym.spaces.Box):
        return FloatBox(low=space.low, high=space.high)
    elif isinstance(space, gym.spaces.Tuple):
        return Tuple(*[OpenAIGymEnv.translate_space(s) for s in space.spaces])
    elif isinstance(space, gym.spaces.Dict):
        return Dict({k: OpenAIGymEnv.translate_space(v) for k, v in space.spaces.items()})
    else:
        raise YARLError("Unknown OpenAI Gym Space class for state_space!")
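# Quick illustration of the mapping above (assumes `gym` is installed):
#
#     translate_space(gym.spaces.Discrete(4))                          # -> IntBox(4)
#     translate_space(gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)))  # -> FloatBox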
def get_tensor_variable(self, name, is_input_feed=False, add_batch_rank=None, **kwargs):
    # Fall back to this Space's own batch-rank setting if not explicitly given.
    add_batch_rank = self.has_batch_rank if add_batch_rank is None else add_batch_rank
    # False -> no batch rank; True -> unknown batch size (None); int -> fixed batch size.
    batch_rank = () if add_batch_rank is False else \
        (None,) if add_batch_rank is True else (add_batch_rank,)
    shape = tuple(batch_rank + self.shape)

    if get_backend() == "tf":
        import tensorflow as tf
        # TODO: re-evaluate the cutting of a leading '/_?' (tf doesn't like it)
        name = re.sub(r'^/_?', "", name)
        if is_input_feed:
            return tf.placeholder(dtype=dtype(self.dtype), shape=shape, name=name)
        else:
            # TODO: what about initializer spec?
            yarl_initializer = Initializer.from_spec(
                shape=shape, specification=kwargs.pop("initializer", None)
            )
            return tf.get_variable(
                name, shape=shape, dtype=dtype(self.dtype),
                initializer=yarl_initializer.initializer, **kwargs
            )
    else:
        raise YARLError("ERROR: Pytorch not supported yet!")
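# Usage sketch (illustrative, TensorFlow backend): create an input placeholder
# from a Space. The FloatBox constructor call is an assumption based on how
# Spaces are used elsewhere in this code.
#
#     my_space = FloatBox(shape=(4,), add_batch_rank=True)
#     states = my_space.get_tensor_variable("states", is_input_feed=True)
#     # -> tf.placeholder of shape (None, 4) with the Space's dtype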
def get_activation_function(activation_function=None, *other_parameters):
    """
    Returns an activation function (callable) to use in a NN layer.

    Args:
        activation_function (Optional[callable,str]): The activation function to lookup. Could be given as:
            - already a callable (return just that)
            - a lookup key (str)
            - None: Use linear activation.
        other_parameters (any): Possible extra parameter(s) used for some of the activation functions.

    Returns:
        callable: The backend-dependent activation function.
    """
    if activation_function is None or callable(activation_function):
        return activation_function
    elif activation_function == "linear":
        return tf.identity
    # Rectifier linear unit (ReLU): 0 if x < 0 else x
    elif activation_function == "relu":
        return tf.nn.relu
    # Exponential linear: exp(x) - 1 if x < 0 else x
    elif activation_function == "elu":
        return tf.nn.elu
    # Sigmoid: 1 / (1 + exp(-x))
    elif activation_function == "sigmoid":
        return tf.sigmoid
    # Scaled exponential linear unit: scale * [alpha * (exp(x) - 1) if < 0 else x]
    # https://arxiv.org/pdf/1706.02515.pdf
    elif activation_function == "selu":
        return tf.nn.selu
    # Swish function: x * sigmoid(x)
    # https://arxiv.org/abs/1710.05941
    elif activation_function == "swish":
        return lambda x: x * tf.sigmoid(x=x)
    # Leaky ReLU: x * [alpha if x < 0 else 1.0]
    elif activation_function in ["lrelu", "leaky_relu"]:
        alpha = other_parameters[0] if len(other_parameters) > 0 else 0.2
        return partial(tf.nn.leaky_relu, alpha=alpha)
    # Concatenated ReLU:
    elif activation_function == "crelu":
        return tf.nn.crelu
    # Softmax function:
    elif activation_function == "softmax":
        return tf.nn.softmax
    # Softplus function:
    elif activation_function == "softplus":
        return tf.nn.softplus
    # Softsign function:
    elif activation_function == "softsign":
        return tf.nn.softsign
    # tanh activation function:
    elif activation_function == "tanh":
        return tf.nn.tanh
    else:
        raise YARLError("ERROR: Unknown activation_function '{}'!".format(activation_function))
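# Example lookups (TensorFlow backend, which this lookup table targets):
#
#     relu = get_activation_function("relu")         # -> tf.nn.relu
#     leaky = get_activation_function("lrelu", 0.1)  # -> leaky ReLU with alpha=0.1
#     y = leaky(x)                                   # `x` being some tf.Tensor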
def sanity_check_meta_graph(self, component=None):
    """
    Checks whether all the `component`'s and its sub-components' in-Sockets are simply connected in the
    meta-graph and raises detailed error messages if not. A connection to an in-Socket is ok if ...
    a) it's coming from another Socket or
    b) it's coming from a Space object

    Args:
        component (Component): The Component to analyze for incoming connections.
    """
    component = component or self.core_component

    if self.logger.level <= logging.INFO:
        component_print_out(component)

    # Check all the Component's in-Sockets for being connected from a Space/Socket.
    for in_sock in component.input_sockets:  # type: Socket
        if len(in_sock.incoming_connections) == 0 and \
                in_sock.name not in component.unconnected_sockets_in_meta_graph:
            raise YARLError(
                "Component '{}' has in-Socket ({}) without any incoming connections! If this is "
                "intended before the build process, you have to add the Socket's name to the "
                "Component's `unconnected_sockets_in_meta_graph` set. Then this error will be "
                "suppressed for this Component.".format(component.name, in_sock.name)
            )

    # Check all the component's graph_fns for input-completeness.
    for graph_fn in component.graph_fns:  # type: GraphFunction
        for in_sock_rec in graph_fn.input_sockets.values():
            in_sock = in_sock_rec["socket"]
            if len(in_sock.incoming_connections) == 0 and \
                    in_sock.name not in component.unconnected_sockets_in_meta_graph:
                raise YARLError(
                    "GraphFn {}/{} has in-Socket ({}) without any incoming "
                    "connections!".format(component.name, graph_fn.name, in_sock_rec["socket"].name)
                )

    # Recursively call this method on all the sub-component's sub-components.
    for sub_component in component.sub_components.values():
        self.build_steps += 1
        if self.build_steps >= self.MAX_RECURSIVE_CALLS:
            raise YARLError(
                "Error sanity checking graph, reached max recursion steps: {}".format(self.MAX_RECURSIVE_CALLS)
            )
        self.sanity_check_meta_graph(sub_component)
def mapping_func(key, space):
    if isinstance(space, IntBox):
        # Must have global bounds (bounds valid for all axes).
        if space.num_categories is False:
            raise YARLError(
                "ERROR: Cannot flatten categories if one of the IntBox spaces ({}={}) does not "
                "have global bounds (its `num_categories` is False)!".format(key, space)
            )
        return space.num_categories
    # No categories. Keep as is.
    return 1
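# Illustration (assuming IntBox(4) sets num_categories=4): an IntBox with global
# bounds contributes its category count; any other space counts as one slot.
#
#     mapping_func("a", IntBox(4))   # -> 4
#     mapping_func("b", FloatBox())  # -> 1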
def sanity_check_build(self, component=None):
    """
    Checks whether all the `component`'s and its sub-components' in-Sockets and graph_fns are input-complete
    and raises detailed error messages if not. Input-completeness means that ...
    a) all in-Sockets of a Component or
    b) all connected incoming Sockets to a GraphFunction
    ... have their `self.space` field defined (is not None).

    Args:
        component (Component): The Component to analyze for input-completeness.
    """
    component = component or self.core_component

    if self.logger.level <= logging.INFO:
        component_print_out(component)

    # Check all the component's graph_fns for input-completeness.
    for graph_fn in component.graph_fns:
        if graph_fn.input_complete is False:
            # Look for the missing in-Socket and raise an Error.
            for in_sock_name, in_sock_record in graph_fn.input_sockets.items():
                if len(in_sock_record["socket"].op_records) == 0 and \
                        in_sock_name not in component.unconnected_sockets_in_meta_graph:
                    raise YARLError(
                        "in-Socket '{}' of GraphFunction '{}' of Component '{}' does not have "
                        "any incoming ops!".format(in_sock_name, graph_fn.name, component.global_scope)
                    )

    # Check component's sub-components for input-completeness (recursively).
    for sub_component in component.sub_components.values():  # type: Component
        if sub_component.input_complete is False:
            # Look for the missing Socket and raise an Error.
            for in_sock in sub_component.input_sockets:
                if in_sock.space is None:
                    raise YARLError(
                        "Component '{}' is not input-complete. In-Socket '{}' does not "
                        "have any incoming connections.".format(sub_component.global_scope, in_sock.name)
                    )

        # Recursively call this method on all the sub-component's sub-components.
        self.sanity_check_build(sub_component)
def _get_execution_inputs_for_socket(self, socket_name, input_combinations, fetch_list, input_dict, feed_dict):
    """
    Helper (to avoid nested for loop-break) for the loop in get_execution_inputs.

    Args:
        socket_name (str): The name of the (core) out-Socket to process.
        input_combinations (List[str]): The list of in-Socket (names) combinations starting with the
            combinations with the most Socket names, then going towards combinations with only one Socket name.
            Each combination in itself should already be sorted alphabetically on the in-Socket names.
        fetch_list (list): Appends to this list, which ops to actually fetch.
        input_dict (Optional[dict]): Dict specifying the provided inputs for some (core) in-Sockets.
            Passed through directly from the call method.
        feed_dict (dict): The feed_dict we are trying to build. When done, needs to map input ops
            (not Socket names) to data.

    Returns:
        tuple: fetch_list, feed-dict with relevant args.
    """
    if len(input_combinations) > 0:
        # Check all (input+shape)-combinations and if we find one that matches what the user passed in as
        # `input_dict` -> Take that one and move on to the next Socket by returning.
        for input_combination in input_combinations:
            # Get all Space-combinations (in-op) for this input combination
            # (OBSOLETE: not possible anymore: in case an in-Socket has more than one connected incoming Spaces).
            ops = [self.in_socket_registry[c] for c in input_combination]
            # Get the shapes for this op_combination.
            shapes = tuple(get_shape(op) for op in ops)
            key = (socket_name, input_combination, shapes)
            # This is a good combination -> Use the looked up op, return to process next out-Socket.
            if key in self.call_registry:
                fetch_list.append(self.call_registry[key])
                # Add items to feed_dict.
                for in_sock_name, in_op in zip(input_combination, ops):
                    value = input_dict[in_sock_name]
                    # Numpy'ize scalar values (tf sometimes doesn't like Python primitives).
                    if isinstance(value, (float, int, bool)):
                        value = np.array(value)
                    feed_dict[in_op] = value
                return fetch_list, feed_dict
    # No inputs -> Try whether this output socket comes without any inputs.
    else:
        key = (socket_name, (), ())
        if key in self.call_registry:
            fetch_list.append(self.call_registry[key])
            return fetch_list, feed_dict

    required_inputs = [k[1] for k in self.call_registry.keys() if k[0] == socket_name]
    raise YARLError(
        "ERROR: No op found for out-Socket '{}' given the input-combinations: {}! "
        "The following input-combinations are required for '{}':\n"
        "{}".format(socket_name, input_combinations, socket_name, required_inputs)
    )
def _graph_fn_sync(self, values_):
    """
    Generates the op that syncs this Synchronizable's parent's variable values from another Synchronizable
    Component.

    Args:
        values_ (DataOpDict): The dict of variable values (coming from the "_variables"-Socket of any other
            Component) that need to be assigned to this Component's parent's variables.
            The keys in the dict refer to the names of our parent's variables and must match their names.

    Returns:
        DataOp: The op that executes the syncing.
    """
    # Loop through all incoming vars and our own and collect assign ops.
    syncs = list()
    parents_vars = self.parent_component.get_variables(collections=self.collections, custom_scope_separator="-")

    # Sanity checking.
    syncs_from, syncs_to = (values_.items(), parents_vars.items())
    if len(syncs_from) != len(syncs_to):
        raise YARLError("ERROR: Number of Variables to sync must match! "
                        "We have {} syncs_from and {} syncs_to.".format(len(syncs_from), len(syncs_to)))

    for (key_from, var_from), (key_to, var_to) in zip(syncs_from, syncs_to):
        # Sanity checking. TODO: Check the names' ends? Without the global scope?
        #if key_from != key_to:
        #    raise YARLError("ERROR: Variable names for syncing must match in order and name! "
        #                    "Mismatch at from={} and to={}.".format(key_from, key_to))
        if get_shape(var_from) != get_shape(var_to):
            raise YARLError("ERROR: Variable shapes for syncing must match! "
                            "Shape mismatch between from={} ({}) and to={} ({}).".format(
                                key_from, get_shape(var_from), key_to, get_shape(var_to)))
        syncs.append(self.assign_variable(var_to, var_from))

    # Bundle everything into one "sync"-op.
    if get_backend() == "tf":
        with tf.control_dependencies(syncs):
            return tf.no_op()
def __init__(self, spec=None, add_batch_rank=False, **kwargs):
    ContainerSpace.__init__(self, add_batch_rank=add_batch_rank)

    # Allow for any spec or already constructed Space to be passed in as values in the python-dict.
    # Spec may be part of kwargs.
    if spec is None:
        spec = kwargs

    dict_ = dict()
    for key in sorted(spec.keys()):
        # Keys must be strings.
        if not isinstance(key, str):
            raise YARLError("ERROR: No non-str keys allowed in a Dict-Space!")
        # Prohibit reserved characters (for flattened syntax).
        if re.search(r'/|{}\d+{}'.format(FLAT_TUPLE_OPEN, FLAT_TUPLE_CLOSE), key):
            raise YARLError(r"ERROR: Key to Dict must not contain '/' or '{}\d+{}'! Is {}.".format(
                FLAT_TUPLE_OPEN, FLAT_TUPLE_CLOSE, key))
        value = spec[key]
        # Value is already a Space: Copy it (to not affect original Space) and maybe add/remove batch-rank.
        if isinstance(value, Space):
            dict_[key] = value.with_batch_rank(add_batch_rank)
        # Value is a list/tuple -> treat as Tuple space.
        elif isinstance(value, (list, tuple)):
            dict_[key] = Tuple(*value, add_batch_rank=add_batch_rank)
        # Value is a spec (or a spec-dict with "type" field) -> produce via `from_spec`.
        elif (isinstance(value, dict) and "type" in value) or not isinstance(value, dict):
            dict_[key] = Space.from_spec(value, add_batch_rank=add_batch_rank)
        # Value is a simple dict -> recursively construct another Dict Space as a sub-space of this one.
        else:
            dict_[key] = Dict(value, add_batch_rank=add_batch_rank)

    # Removed this restriction. Sometimes, we need empty Variables dicts.
    #if len(dict_) == 0:
    #    raise YARLError("ERROR: Dict() c'tor needs a non-empty spec!")

    dict.__init__(self, dict_)
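# Construction sketch (grounded in the branches above): values may be already
# constructed Spaces, lists/tuples (-> Tuple space), spec-dicts with a "type"
# field, or plain dicts (-> nested Dict space). The key names here are made up.
#
#     my_space = Dict(
#         dict(
#             position=FloatBox(shape=(3,)),       # already a Space
#             inventory=dict(gold=IntBox(100)),    # plain dict -> nested Dict
#         ),
#         add_batch_rank=True,
#     )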
def trace_back_sockets(self, trace_set):
    """
    For a set of given ops, returns a list of all (core) in-Sockets that are required to calculate these ops.

    Args:
        trace_set (Set[Union[DataOpRecords,Socket]]): The set of DataOpRecord/Socket objects to trace back
            till the beginning of the Graph. Socket entries mean we have already reached the beginning of the
            Graph and these will not be traced back any further.

    Returns:
        Set[Socket]: in-Socket objects (from the core Component) that are required to calculate the DataOps
            in `trace_set`.
    """
    # Recursively lookup op in op_record_registry until we hit a Socket.
    new_trace_set = set()
    for op_rec_or_socket in trace_set:
        # We hit a Socket (we reached the beginning of the Graph). Stop tracing further back.
        if isinstance(op_rec_or_socket, Socket):
            if op_rec_or_socket.name not in self.in_socket_registry:
                raise YARLError("ERROR: in-Socket '{}' could not be found in in_socket_registry of "
                                "model!".format(op_rec_or_socket.name))
            new_trace_set.add(op_rec_or_socket)
        # A DataOpRecord: Sanity check that we already have this.
        elif op_rec_or_socket not in self.op_record_registry:
            # Could be a DataOpRecord of a SingleDataOp with constant_value set.
            if not isinstance(op_rec_or_socket.op, SingleDataOp) or op_rec_or_socket.op.constant_value is None:
                raise YARLError("ERROR: DataOpRecord for op '{}' could not be found in op_record_registry of "
                                "model!".format(op_rec_or_socket.op))
        else:
            new_trace_set.update(self.op_record_registry[op_rec_or_socket])

    if all([isinstance(i, Socket) for i in new_trace_set]):
        return new_trace_set
    else:
        return self.trace_back_sockets(new_trace_set)
def check_input_spaces(self, input_spaces, action_space):
    # The Distribution to sample (or pick) actions from.
    # Discrete action space -> Categorical distribution (each action needs a logit from network).
    if isinstance(action_space, IntBox):
        self.distribution = Categorical()
    # Continuous action space -> Normal distribution (each action needs mean and variance from network).
    elif isinstance(action_space, FloatBox):
        self.distribution = Normal()
    else:
        raise YARLError("ERROR: Space of out-Socket `action` is of type {} and not allowed in {} "
                        "Component!".format(type(action_space).__name__, self.name))

    # This defines out-Sockets "sample_stochastic/sample_deterministic/entropy".
    self.add_component(self.distribution, connections=CONNECT_OUTS)

    # Plug-in Adapter Component into Distribution.
    self.connect((self.action_adapter, "parameters"), (self.distribution, "parameters"))
def check_input_spaces(self, input_spaces, action_space):
    self.action_space = action_space.with_batch_rank()
    assert self.action_space.has_batch_rank, "ERROR: `self.action_space` does not have batch rank!"

    if self.epsilon_exploration and self.noise_component:
        # Check again at graph creation? This is currently redundant to the check in __init__.
        raise YARLError("Cannot use both epsilon exploration and a noise component at the same time.")

    if self.epsilon_exploration:
        # Currently, only IntBox is allowed for discrete action spaces.
        assert isinstance(self.action_space, IntBox), "Only IntBox Spaces are currently supported " \
                                                      "for exploration components."
        assert self.action_space.num_categories is not None and self.action_space.num_categories > 0, \
            "ERROR: `action_space` must have `num_categories` defined and > 0!"
    elif self.noise_component:
        assert isinstance(self.action_space, FloatBox), "Only FloatBox Spaces are currently supported " \
                                                        "for noise components."
def push_from_socket(self, socket):
    # Skip this Socket, if it doesn't have a Space (no incoming connection).
    # Assert that it's ok for the component to leave this Socket open.
    if socket.space is None:
        assert socket.name in socket.component.unconnected_sockets_in_meta_graph
        return

    for outgoing in socket.outgoing_connections:
        # Push Socket into Socket.
        if isinstance(outgoing, Socket):
            self.logger.debug("SOCK {}/{} -> {}/{}".format(socket.component.name, socket.name,
                                                           outgoing.component.name, outgoing.name))
            self.push_socket_into_socket(socket, outgoing)
        # Push Socket into GraphFunction.
        elif isinstance(outgoing, GraphFunction):
            self.push_from_graph_fn(outgoing)
        # Error.
        else:
            raise YARLError("ERROR: Outgoing connection ({}) must be of type Socket or "
                            "GraphFunction!".format(outgoing))
def __init__(self, num_iterations, call_component, graph_fn_name, scope="fixed-loop", **kwargs):
    """
    Args:
        num_iterations (int): How often to call the given GraphFn.
        call_component (Component): Component providing the graph_fn to call within the loop.
        graph_fn_name (str): The name of the graph_fn in `call_component`.
    """
    assert num_iterations > 0

    super(FixedLoop, self).__init__(scope=scope, **kwargs)

    self.num_iterations = num_iterations
    self.graph_fn_to_call = None

    flatten_ops = False
    for graph_fn in call_component.graph_fns:
        if graph_fn.name == graph_fn_name:
            self.graph_fn_to_call = graph_fn.get_method()
            flatten_ops = graph_fn.flatten_ops
            break
    if not self.graph_fn_to_call:
        raise YARLError("ERROR: GraphFn '{}' not found in Component '{}'!".format(
            graph_fn_name, call_component.global_scope))

    # TODO: Do we sum up, append to list, ...?
    self.define_inputs("inputs")
    self.define_outputs("fixed_loop_result")
    self.add_component(call_component)
    self.add_graph_fn("inputs", "fixed_loop_result", self._graph_fn_call_loop,
                      flatten_ops={"inputs"} if flatten_ops else None)
def parse_update_spec(update_spec):
    """
    Parses update/learning parameters and inserts default values where necessary.

    Args:
        update_spec (Optional[dict]): Update/Learning spec dict.

    Returns:
        dict: The sanitized update_spec dict.
    """
    # Default spec: used for any keys the given spec leaves out.
    default_spec = dict(
        # Whether to perform calls to `Agent.update()` at all.
        do_updates=True,
        # The unit in which we measure frequency: one of "timesteps", "episodes", "sec".
        #unit="timesteps",  # TODO: not supporting any other than timesteps
        # The number of 'units' to wait before we do any updating at all.
        steps_before_update=0,
        # The frequency with which we update (given in `unit`).
        update_interval=4,
        # The number of consecutive `Agent.update()` calls per update.
        update_steps=1,
        # The batch size with which to update (e.g. when pulling records from a memory).
        batch_size=64,
        sync_interval=128
    )
    update_spec = default_dict(update_spec, default_spec)

    # Assert that the sync interval is a multiple of the update_interval.
    if update_spec["sync_interval"] % update_spec["update_interval"] != 0:
        raise YARLError(
            "ERROR: sync_interval ({}) must be a multiple of update_interval "
            "({})!".format(update_spec["sync_interval"], update_spec["update_interval"])
        )

    return update_spec
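# Example: a partial user spec is filled up with the defaults above.
#
#     parse_update_spec(dict(batch_size=32, update_interval=8))
#     # -> dict(do_updates=True, steps_before_update=0, update_interval=8,
#     #         update_steps=1, batch_size=32, sync_interval=128)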
def run_through_graph_fn(self, graph_fn):
    """
    Pushes all incoming ops through the method of this GraphFunction object. The ops are collected from
    incoming Sockets and optionally flattened and/or split before pushing them through the method, and the
    return values are optionally unflattened.

    Args:
        graph_fn (GraphFunction): The GraphFunction object to run through (its method) with all possible
            in-Socket combinations (only those that have not run yet through the method).
    """
    in_op_records = [in_sock_rec["socket"].op_records for in_sock_rec in graph_fn.input_sockets.values()]
    in_op_records_combinations = list(itertools.product(*in_op_records))

    for in_op_record_combination in in_op_records_combinations:
        # Make sure we call the computation method only once per input-op combination.
        if in_op_record_combination in graph_fn.in_out_records_map:
            continue

        # Replace constant-value Sockets with their SingleDataOp's constant numpy values
        # and the DataOps with their actual ops (`op` property of DataOp).
        actual_call_params = [
            op_rec.op.constant_value if isinstance(op_rec.op, SingleDataOp) and
            op_rec.op.constant_value is not None else op_rec.op
            for op_rec in in_op_record_combination
        ]

        # Build the ops from this input-combination.
        # Flatten input items.
        if graph_fn.flatten_ops is not False:
            flattened_ops = graph_fn.flatten_input_ops(*actual_call_params)
            # Split into SingleDataOps?
            if graph_fn.split_ops:
                call_params = split_flattened_input_ops(graph_fn.add_auto_key_as_first_param, *flattened_ops)
                # There is some splitting to do. Call graph_fn many times (one for each split).
                if isinstance(call_params, FlattenedDataOp):
                    ops = dict()
                    num_return_values = -1
                    for key, params in call_params.items():
                        ops[key] = force_tuple(graph_fn.method(*params))
                        if num_return_values >= 0 and num_return_values != len(ops[key]):
                            raise YARLError("Different split-runs through {} do not return the same number of "
                                            "values!".format(graph_fn.name))
                        num_return_values = len(ops[key])

                    # Un-split the results dict into a tuple of `num_return_values` slots.
                    un_split_ops = list()
                    for i in range(num_return_values):
                        dict_with_singles = FlattenedDataOp()
                        for key in call_params.keys():
                            dict_with_singles[key] = ops[key][i]
                        un_split_ops.append(dict_with_singles)
                    ops = tuple(un_split_ops)
                # No splitting to do: Pass everything as-is.
                else:
                    ops = graph_fn.method(*call_params)
            else:
                ops = graph_fn.method(*flattened_ops)
        # Just pass in everything as-is.
        else:
            ops = graph_fn.method(*actual_call_params)

        # OBSOLETE: always must un-flatten all return values. Otherwise, we would allow Dict Spaces
        # with '/' keys in them, which is not allowed.
        #if graph_fn.unflatten_ops:
        ops = graph_fn.unflatten_output_ops(*force_tuple(ops))

        # Make sure everything coming from a computation is always a tuple (for out-Socket indexing).
        ops = force_tuple(ops)

        # Make sure the number of returned ops matches the number of outgoing Sockets from this graph_fn.
        assert len(ops) == len(graph_fn.output_sockets), \
            "ERROR: Number of returned values of graph_fn '{}/{}' ({}) does not match number of out-Sockets ({}) " \
            "of this GraphFunction!".format(graph_fn.component.name, graph_fn.name, len(ops),
                                            len(graph_fn.output_sockets))

        # ops are now the raw graph_fn output: Need to convert it back to records.
        new_label_set = set()
        for rec in in_op_record_combination:  # type: DataOpRecord
            new_label_set.update(rec.labels)
        op_records = convert_ops_to_op_records(ops, labels=new_label_set)

        graph_fn.in_out_records_map[in_op_record_combination] = op_records

        # Move graph_fn results into next Socket(s).
        for i, (socket, op_rec) in enumerate(zip(graph_fn.output_sockets, op_records)):
            self.logger.debug("GraphFn {}/{} -> return-slot {} -> {} -> Socket {}/{}".format(
                graph_fn.component.name, graph_fn.name, i, ops, socket.component.name, socket.name))
            # Store op_rec in the respective outgoing Socket (and make sure Spaces match).
            space = get_space_from_op(op_rec.op)
            if len(socket.op_records) > 0:
                sanity_check_space = get_space_from_op(next(iter(socket.op_records)).op)
                assert space == sanity_check_space, \
                    "ERROR: Newly calculated output op of graph_fn '{}' has different Space than existing one " \
                    "({} vs {})!".format(graph_fn.name, space, sanity_check_space)
            else:
                socket.space = space
            socket.op_records.add(op_rec)

            self.op_record_registry[op_rec] = set(in_op_record_combination)
            # Make sure all op_records do not contain SingleDataOps with constant_values. Any
            # in-Socket-connected constant values need to be converted to actual ops during a graph_fn call.
            assert not isinstance(op_rec.op, SingleDataOp), \
                "ERROR: graph_fn '{}' returned a SingleDataOp with constant_value set to '{}'! " \
                "This is not allowed. All graph_fns must return actual (non-constant) ops.".format(
                    graph_fn.name, op_rec.op.constant_value)
def get_graph_markup(component, level=0, draw_graph_fns=False):
    """
    Returns graph markup to be used for YARL meta-graph plotting.
    Uses the [mermaid](https://github.com/knsv/mermaid) markup language.

    Args:
        component (Component): Component to generate meta-graph markup for.
        level (int): Indentation level. If >= 1, return this component as sub-component.
        draw_graph_fns (bool): Include graph fns in plot.

    Returns:
        str: Meta-graph markup string.
    """
    # Print (sub)graph declaration.
    if level >= 1:
        markup = " " * 4 * level + "subgraph {}\n".format(component.name)
    elif level == 0:
        markup = "graph TD\n"
        markup += "classDef input_socket fill:#9ff,stroke:#333,stroke-width:2px;\n"
        markup += "classDef output_socket fill:#f9f,stroke:#333,stroke-width:2px;\n"
        markup += "classDef space fill:#999,stroke:#333,stroke-width:2px;\n"
        markup += "classDef graph_fn fill:#ff9,stroke:#333,stroke-width:2px;\n"
        markup += "\n"
    else:
        raise YARLError("Invalid component indentation level {}".format(level))

    all_sockets = list()
    all_graph_fns = list()

    # Add input socket nodes with the following markup: socket_HASH("INPUT SOCKET NAME")
    markup_input_sockets = list()
    for input_socket in component.input_sockets:
        markup += " " * 4 * (level + 1) + "socket_{hash}(\"{name}\")\n".format(
            hash=hash(input_socket), name=input_socket.name)
        markup_input_sockets.append("socket_{hash}".format(hash=hash(input_socket)))
        all_sockets.append(input_socket)

    # Add output socket nodes with the following markup: socket_HASH("OUTPUT SOCKET NAME")
    markup_output_sockets = list()
    for output_socket in component.output_sockets:
        markup += " " * 4 * (level + 1) + "socket_{hash}(\"{name}\")\n".format(
            hash=hash(output_socket), name=output_socket.name)
        markup_output_sockets.append("socket_{hash}".format(hash=hash(output_socket)))
        all_sockets.append(output_socket)

    markup += "\n"

    # Add graph function nodes with the following markup: graphfn_HASH("GRAPH FN NAME")
    markup_graph_fns = list()
    for graph_fn in component.graph_fns:
        markup += " " * 4 * (level + 1) + "graphfn_{hash}(\"{name}\")\n".format(
            hash=hash(graph_fn), name=graph_fn.name)
        markup_graph_fns.append("graphfn_{hash}".format(hash=hash(graph_fn)))
        all_graph_fns.append(graph_fn)

    # Collect connections by looping through all incoming connections.
    # All outgoing connections should be incoming connections of another socket, so we don't need to loop
    # through them.
    connections = list()
    markup_spaces = list()
    for socket in all_sockets:
        for incoming_connection in socket.incoming_connections:
            if isinstance(incoming_connection, Socket):
                connections.append((
                    "socket_{}".format(hash(incoming_connection)),
                    "socket_{}".format(hash(socket)),
                    None
                ))
            elif isinstance(incoming_connection, Space):
                # Add spaces to markup (we only know about them because of their connections).
                markup += " " * 4 * (level + 1) + "space_{hash}(\"{name}\")\n".format(
                    hash=hash(incoming_connection), name=str(incoming_connection))
                markup_spaces.append("space_{hash}".format(hash=hash(incoming_connection)))
                connections.append((
                    "space_{}".format(hash(incoming_connection)),
                    "socket_{}".format(hash(socket)),
                    None
                ))
            elif isinstance(incoming_connection, GraphFunction):
                connections.append((
                    "graphfn_{}".format(hash(incoming_connection)),
                    "socket_{}".format(hash(socket)),
                    None
                ))

    # Collect graph fn connections by looping through all input sockets of the graph fns.
    # All output sockets should have been covered by the above collection of incoming connections to the sockets.
    for graph_fn in all_graph_fns:
        for input_socket_name, input_socket_dict in graph_fn.input_sockets.items():
            input_socket = input_socket_dict['socket']
            if isinstance(input_socket, Socket):
                connections.append((
                    "socket_{}".format(hash(input_socket)),
                    "graphfn_{}".format(hash(graph_fn)),
                    None
                ))
            else:
                raise ValueError("Not a valid input socket: {} ({})".format(input_socket, type(input_socket)))

    # Add style class `input_socket` to the input sockets.
    if markup_input_sockets:
        markup += " " * 4 * (level + 1) + "class {} input_socket;\n".format(','.join(markup_input_sockets))

    # Add style class `output_socket` to the output sockets.
    if markup_output_sockets:
        markup += " " * 4 * (level + 1) + "class {} output_socket;\n".format(','.join(markup_output_sockets))

    # Add style class `space` to the spaces.
    if markup_spaces:
        markup += " " * 4 * (level + 1) + "class {} space;\n".format(','.join(markup_spaces))

    # Add style class `graph_fn` to the graph fns.
    if markup_graph_fns:
        markup += " " * 4 * (level + 1) + "class {} graph_fn;\n".format(','.join(markup_graph_fns))

    markup += "\n"

    # Add sub-components.
    for sub_component_name, sub_component in component.sub_components.items():
        markup += get_graph_markup(sub_component, level=level + 1, draw_graph_fns=draw_graph_fns)

    # Subgraphs (level >= 1) require an end statement.
    if level >= 1:
        markup += " " * 4 * level + "end\n"

    markup += "\n"

    # Connections are inserted after the graph.
    for connection in connections:
        if connection[2]:
            # Labeled connection.
            markup += " " * 4 * level + "{}--{}-->{}\n".format(connection[0], connection[2], connection[1])
        else:
            # Unlabeled connection.
            markup += " " * 4 * level + "{}-->{}\n".format(connection[0], connection[1])

    markup += "\n"

    return markup
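# Usage sketch (illustrative; `core_component` stands in for whatever root
# Component the model exposes): generate the markup, then paste it into any
# mermaid renderer to view the meta-graph.
#
#     markup = get_graph_markup(core_component, draw_graph_fns=True)
#     print(markup)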
def __init__(self, non_explore_behavior="max-likelihood", epsilon_spec=None, noise_spec=None,
             scope="exploration", **kwargs):
    """
    Args:
        non_explore_behavior (str): One of:
            - "max-likelihood": When not exploring, pick an action deterministically (max-likelihood) from
              the Policy's distribution.
            - "sample": When not exploring, pick an action stochastically according to the Policy's
              distribution.
            - "random": When not exploring, pick an action randomly.
        epsilon_spec (any): The spec or Component object itself to construct an EpsilonExploration Component.
        noise_spec (dict): The specification dict for a noise generator that adds noise to the NN's output.
    """
    super(Exploration, self).__init__(scope=scope, **kwargs)

    self.action_space = None
    self.non_explore_behavior = non_explore_behavior

    # Define our interface.
    self.define_inputs("sample_deterministic", "sample_stochastic", "time_step")
    self.define_outputs("action", "do_explore", "noise")

    self.epsilon_exploration = None
    self.noise_component = None

    # Don't allow both an epsilon and a noise component at the same time.
    if epsilon_spec and noise_spec:
        raise YARLError("Cannot use both epsilon exploration and a noise component at the same time.")

    # Add the epsilon component.
    if epsilon_spec:
        self.epsilon_exploration = EpsilonExploration.from_spec(epsilon_spec)
        self.add_component(self.epsilon_exploration)
        self.connect("time_step", (self.epsilon_exploration, "time_step"))
        self.connect((self.epsilon_exploration, "do_explore"), "do_explore")

        # Add our own graph_fn and connect its output to the "action" Socket.
        self.add_graph_fn(
            inputs=[(self.epsilon_exploration, "do_explore"), "sample_deterministic", "sample_stochastic"],
            outputs="action",
            method=self._graph_fn_pick
        )
    # Add the noise component.
    elif noise_spec:
        # Currently, no noise component uses the time_step variable.
        self.unconnected_sockets_in_meta_graph.add("time_step")

        self.noise_component = NoiseComponent.from_spec(noise_spec)
        self.add_component(self.noise_component)
        self.connect((self.noise_component, "noise"), "noise")

        self.add_graph_fn(
            inputs=[(self.noise_component, "noise"), "sample_deterministic", "sample_stochastic"],
            outputs="action",
            method=self._graph_fn_add_noise
        )
    # Don't explore at all.
    else:
        if self.non_explore_behavior == "max-likelihood":
            self.connect("sample_deterministic", "action")
        else:
            self.connect("sample_stochastic", "action")
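# Construction sketch (grounded in the branches above): epsilon exploration and
# noise components are mutually exclusive. The spec dicts below are hypothetical
# placeholders; only the fact that `EpsilonExploration.from_spec()` /
# `NoiseComponent.from_spec()` accept a spec is taken from the code.
#
#     exploration = Exploration(epsilon_spec=my_epsilon_spec)    # epsilon-based
#     exploration = Exploration(noise_spec=my_noise_spec)        # noise-based
#     exploration = Exploration(non_explore_behavior="sample")   # no exploration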
def get_execution_inputs(self, output_socket_names, inputs=None):
    """
    Fetches graph inputs for execution.

    Args:
        output_socket_names (Union[str,List[str]]): A name or a list of names of the out-Sockets to fetch
            from our core component.
        inputs (Optional[dict,data]): Dict specifying the provided inputs for some in-Sockets.
            Depending on these given inputs, the correct backend-ops can be selected within the given
            (out)-Sockets. Alternatively, can pass in data directly (not as a dict), but only if there is
            only one in-Socket in the Model or only one of the in-Sockets is needed for the given out-Sockets.

    Returns:
        tuple: fetch-list, feed-dict with relevant args.
    """
    output_socket_names = force_list(output_socket_names)

    # Sanity check out-Socket names.
    for out_sock_name in output_socket_names:
        if out_sock_name not in self.out_socket_registry:
            raise YARLError("ERROR: Out-Socket '{}' not found in Model! Make sure you are fetching by the \n"
                            "correct out-Socket name.".format(out_sock_name))

    only_input_socket_name = None  # the name of the only in-Socket possible here
    # Some input is given.
    if inputs is not None:
        # Get only in-Socket ..
        if len(self.core_component.input_sockets) == 1:
            only_input_socket_name = self.core_component.input_sockets[0].name
        # .. or only in-Socket for single(!), given out-Socket.
        elif len(output_socket_names) == 1 and \
                len(self.out_socket_registry[output_socket_names[0]]) == 1:
            only_input_socket_name = next(iter(self.out_socket_registry[output_socket_names[0]]))

        # Check whether data is given directly.
        if not isinstance(inputs, dict):
            if only_input_socket_name is None:
                raise YARLError("ERROR: Input data (`inputs`) given directly (not as dict) AND more than one \n"
                                "in-Socket in Model OR more than one in-Socket needed for given out-Sockets "
                                "'{}'!".format(output_socket_names))
            inputs = {only_input_socket_name: inputs}
        # Is a dict: Check whether it's an in-Socket name dict (leave as is) or a
        # data dict (add in-Socket name as key).
        else:
            # We have more than one necessary in-Sockets (leave as is) OR
            # the only necessary in-Socket name is not a key of the dict -> wrap it.
            if only_input_socket_name is not None and only_input_socket_name not in inputs:
                inputs = {only_input_socket_name: inputs}

        # Try all possible input combinations to see whether we got an op for that.
        # Input Socket names will be sorted alphabetically and combined from short sequences up to longer ones.
        # Example: inputs={A: ..., B: ..., C: ...}
        #   input_combinations=[ABC, AB, AC, BC, A, B, C]
        # These combinations have been memoized for fast lookup.
        key = tuple(sorted(inputs.keys()))
        input_combinations = self.input_combinations.get(key)
        if not input_combinations:
            raise YARLError("ERROR: At least one of the given in-Socket names {} seems to be non-existent "
                            "in Model!".format(key))
    # No input given (maybe an out-Socket that doesn't require input).
    else:
        input_combinations = list()

    # Go through each (core) out-Socket name and collect the correct ops to go into the fetch_list.
    fetch_list = list()
    feed_dict = dict()
    for out_socket_name in output_socket_names:
        # Update with relevant ops.
        fetch_list, feed_dict = self._get_execution_inputs_for_socket(
            out_socket_name, input_combinations, fetch_list, inputs, feed_dict)

    return fetch_list, feed_dict
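# Call sketch (illustrative; the socket names "act" and "state" are hypothetical):
#
#     fetch_list, feed_dict = model.get_execution_inputs("act", inputs=dict(state=my_state))
#     results = session.run(fetch_list, feed_dict=feed_dict)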