def __init__(self, layer_type, name=None, **kwargs):
    """Describe a layer that is yet to be instantiated.

    Args:
        layer_type: Type the layer should have once it is instantiated.
        name: Optional layer name. When omitted (or falsy) the layer is
            named after ``layer_type``.
        **kwargs: Extra parameters stored verbatim in ``layer_kwargs``.

    Raises:
        NetworkValidationError: If ``layer_type`` is not a valid layer name,
            or if ``name`` is given and is not a valid layer name.
    """
    # Both identifiers are checked before the base initialiser runs;
    # layer_type is checked first, so it wins when both are invalid.
    if not is_valid_layer_name(layer_type):
        raise NetworkValidationError(
            "Invalid layer_type: '{}'".format(layer_type))
    if name is not None and not is_valid_layer_name(name):
        raise NetworkValidationError(
            "Invalid name for layer: '{}'".format(name))

    effective_name = name or layer_type
    super(LayerDetails, self).__init__(effective_name)

    #: The type this layer should have when later being instantiated.
    self.layer_type = layer_type

    #: All incoming connections, including input/output names. Each entry
    #: has the form ``(incoming_layer, output_name, input_name)`` and the
    #: type ``tuple[LayerDetails, str, str]``.
    self.incoming = []

    #: All outgoing connections, including input/output names. Each entry
    #: has the form ``(output_name, input_name, outgoing_layer)`` and the
    #: type ``tuple[str, str, LayerDetails]``.
    self.outgoing = []

    #: Dictionary of additional parameters for this layer.
    self.layer_kwargs = kwargs

    # NOTE(review): presumably a re-entrancy/cycle guard used while walking
    # the layer graph; its readers are outside this view.
    self._traversing = False
def get_regex_for_reference(reference):
    """Return a compiled regex that matches layer names against a reference.

    A reference is either a plain layer name such as ``'FooLayer'``, which
    is matched exactly, or a wildcard pattern such as ``'I*_bias'`` or
    ``'*layer*'``. In a wildcard pattern each ``*`` stands for any (possibly
    empty) run of ``[_a-zA-Z0-9]`` characters.

    Args:
        reference (str): Layer name or wildcard pattern.

    Returns:
        A compiled pattern anchored with ``^`` and ``$``.

    Raises:
        AssertionError: If ``reference`` is not a valid layer name, and is
            also not one after replacing every ``*`` with ``_``.
    """
    if is_valid_layer_name(reference):
        return re.compile('^' + reference + '$')

    # Raise explicitly rather than using ``assert``: assert statements are
    # removed under ``python -O``, which would let arbitrary text (including
    # regex metacharacters) reach ``re.compile`` below. AssertionError and
    # the message are kept so existing callers see the same exception.
    if not is_valid_layer_name(reference.replace('*', '_')):
        raise AssertionError(
            "{} is not a valid layer reference.".format(reference))

    return re.compile('^' + reference.replace('*', '[_a-zA-Z0-9]*') + '$')
def validate_architecture(architecture):
    """Check that an architecture description is well-formed.

    The following checks are performed, in this order:

    * every key is a string and every value is a mapping holding a string
      ``'@type'``; ``'@outgoing_connections'``, if present, must be a list,
      tuple or dict;
    * every key is a valid layer name;
    * every connection returned by ``collect_all_connections`` ends at a
      layer defined in ``architecture``;
    * exactly one layer has ``@type`` ``'Input'``, and it is named
      ``'Input'``;
    * no connection ends at ``'Input'``.

    Args:
        architecture (dict): Mapping from layer name to layer description.

    Returns:
        bool: ``True``; failures are reported by raising.

    Raises:
        NetworkValidationError: If any of the checks above fails.
    """
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping

    # schema
    for name, layer in architecture.items():
        if not isinstance(name, string_types):
            raise NetworkValidationError('Non-string name {}'.format(name))
        # Without this check a non-mapping description crashes with a raw
        # TypeError below (e.g. a list containing '@type' passes the
        # membership test and then fails on layer['@type']).
        if not isinstance(layer, Mapping):
            raise NetworkValidationError(
                'Invalid layer description for "{}": {}'.format(
                    name, type(layer)))
        if '@type' not in layer:
            raise NetworkValidationError(
                'Missing @type for "{}"'.format(name))
        if not isinstance(layer['@type'], string_types):
            raise NetworkValidationError('Invalid @type for "{}": {}'.format(
                name, type(layer['@type'])))
        if '@outgoing_connections' in layer and not isinstance(
                layer['@outgoing_connections'], (list, tuple, dict)):
            raise NetworkValidationError(
                'Invalid @outgoing_connections for "{}"'.format(name))

    # layer naming
    for name in architecture:
        if not is_valid_layer_name(name):
            raise NetworkValidationError(
                "Invalid layer name: '{}'".format(name))

    # all outgoing connections are present
    connections = collect_all_connections(architecture)
    end_layers = {c.end_layer for c in connections}
    undefined_end_layers = end_layers.difference(architecture)
    if undefined_end_layers:
        raise NetworkValidationError(
            'Could not find end layer(s) "{}"'.format(undefined_end_layers))

    # has exactly one Input and it's called Input. Collecting every layer of
    # type 'Input' also rejects additional Input-typed layers under other
    # names, which the previous name-only check let through.
    input_layers = [layer_name for layer_name, layer in architecture.items()
                    if layer['@type'] == 'Input']
    if input_layers != ['Input']:
        raise NetworkValidationError(
            'Needs exactly one Input that is called "Input"')

    # no connections to Input
    if 'Input' in end_layers:
        raise NetworkValidationError(
            'Input can not have incoming connections!')

    # TODO: check if connected
    # TODO: check for cycles
    return True