Пример #1
0
    def __init__(self, layer_type, name=None, **kwargs):
        """Validate the identifiers and set up an unconnected layer description.

        Args:
            layer_type: The type this layer should have once instantiated.
            name: Optional layer name. When it is omitted (or any falsy
                value), ``layer_type`` is used as the name instead.
            **kwargs: Additional layer parameters, kept as ``layer_kwargs``.

        Raises:
            NetworkValidationError: If ``layer_type`` is not a valid layer
                name, or if ``name`` is given and is not a valid layer name.
        """
        if not is_valid_layer_name(layer_type):
            raise NetworkValidationError(
                "Invalid layer_type: '{}'".format(layer_type))
        # A missing name is fine (we fall back to layer_type below); only a
        # supplied-but-invalid one is rejected.
        if name is not None and not is_valid_layer_name(name):
            raise NetworkValidationError(
                "Invalid name for layer: '{}'".format(name))
        super(LayerDetails, self).__init__(name or layer_type)

        #: The type this layer should have when later being instantiated.
        self.layer_type = layer_type

        #: Incoming connections, each an
        #: ``(incoming_layer, output_name, input_name)`` tuple of type
        #: ``tuple[LayerDetails, str, str]``.
        self.incoming = []

        #: Outgoing connections, each an
        #: ``(output_name, input_name, outgoing_layer)`` tuple of type
        #: ``tuple[str, str, LayerDetails]``.
        self.outgoing = []

        #: Dictionary of additional parameters for this layer.
        self.layer_kwargs = kwargs

        # NOTE(review): presumably a guard flag set while this layer is being
        # visited during graph traversal — confirm against the traversal code.
        self._traversing = False
Пример #2
0
    def __init__(self, layer_type, name=None, **kwargs):
        """Validate the identifiers and set up an unconnected layer description.

        Args:
            layer_type: The type this layer should have once instantiated.
            name: Optional layer name. When it is omitted (or any falsy
                value), ``layer_type`` is used as the name instead.
            **kwargs: Additional layer parameters, kept as ``layer_kwargs``.

        Raises:
            NetworkValidationError: If ``layer_type`` is not a valid layer
                name, or if ``name`` is given and is not a valid layer name.
        """
        if not is_valid_layer_name(layer_type):
            raise NetworkValidationError(
                "Invalid layer_type: '{}'".format(layer_type))
        # A missing name is fine (we fall back to layer_type below); only a
        # supplied-but-invalid one is rejected.
        if name is not None and not is_valid_layer_name(name):
            raise NetworkValidationError(
                "Invalid name for layer: '{}'".format(name))
        super(LayerDetails, self).__init__(name or layer_type)

        #: The type this layer should have when later being instantiated.
        self.layer_type = layer_type

        #: Incoming connections, each an
        #: ``(incoming_layer, output_name, input_name)`` tuple of type
        #: ``tuple[LayerDetails, str, str]``.
        self.incoming = []

        #: Outgoing connections, each an
        #: ``(output_name, input_name, outgoing_layer)`` tuple of type
        #: ``tuple[str, str, LayerDetails]``.
        self.outgoing = []

        #: Dictionary of additional parameters for this layer.
        self.layer_kwargs = kwargs

        # NOTE(review): presumably a guard flag set while this layer is being
        # visited during graph traversal — confirm against the traversal code.
        self._traversing = False
Пример #3
0
def get_regex_for_reference(reference):
    """
    Return a corresponding regex for refs like: 'FooLayer', 'I*_bias',
    or even '*layer*'.

    A reference that is itself a valid layer name matches only that exact
    name. Otherwise every ``*`` matches any (possibly empty) run of the
    characters ``[_a-zA-Z0-9]``. The pattern is anchored with ``^``/``$``.

    Args:
        reference: A layer name, optionally containing ``*`` wildcards.

    Returns:
        A compiled regular expression object.

    Raises:
        AssertionError: If ``reference`` is not a valid layer name and is
            also not valid after replacing every ``*`` with ``_``.
    """
    if is_valid_layer_name(reference):
        # Escape so any regex metacharacter a valid name might contain is
        # matched literally; identifier characters match the same text
        # whether escaped or not.
        return re.compile('^' + re.escape(reference) + '$')

    # Raised explicitly instead of via ``assert`` so the check is not
    # stripped under ``python -O``; AssertionError is kept so existing
    # callers that catch it keep working.
    if not is_valid_layer_name(reference.replace('*', '_')):
        raise AssertionError(
            "{} is not a valid layer reference.".format(reference))

    # Escape the literal pieces between wildcards, then join them with the
    # wildcard character class (equivalent to replacing each '*').
    literal_parts = [re.escape(part) for part in reference.split('*')]
    return re.compile('^' + '[_a-zA-Z0-9]*'.join(literal_parts) + '$')
Пример #4
0
def get_regex_for_reference(reference):
    """
    Return a corresponding regex for refs like: 'FooLayer', 'I*_bias',
    or even '*layer*'.

    A reference that is itself a valid layer name matches only that exact
    name. Otherwise every ``*`` matches any (possibly empty) run of the
    characters ``[_a-zA-Z0-9]``. The pattern is anchored with ``^``/``$``.

    Args:
        reference: A layer name, optionally containing ``*`` wildcards.

    Returns:
        A compiled regular expression object.

    Raises:
        AssertionError: If ``reference`` is not a valid layer name and is
            also not valid after replacing every ``*`` with ``_``.
    """
    if is_valid_layer_name(reference):
        # Escape so any regex metacharacter a valid name might contain is
        # matched literally; identifier characters match the same text
        # whether escaped or not.
        return re.compile('^' + re.escape(reference) + '$')

    # Raised explicitly instead of via ``assert`` so the check is not
    # stripped under ``python -O``; AssertionError is kept so existing
    # callers that catch it keep working.
    if not is_valid_layer_name(reference.replace('*', '_')):
        raise AssertionError(
            "{} is not a valid layer reference.".format(reference))

    # Escape the literal pieces between wildcards, then join them with the
    # wildcard character class (equivalent to replacing each '*').
    literal_parts = [re.escape(part) for part in reference.split('*')]
    return re.compile('^' + '[_a-zA-Z0-9]*'.join(literal_parts) + '$')
Пример #5
0
def validate_architecture(architecture):
    """Check that an architecture description is well-formed.

    ``architecture`` maps layer names to layer descriptions. Each description
    must contain a string ``'@type'`` and may contain
    ``'@outgoing_connections'`` as a list, tuple or dict.

    Connectivity and cycles are NOT checked yet (see TODOs below).

    Args:
        architecture: Mapping from layer name to layer description.

    Returns:
        True if every check passes (invalid input raises instead).

    Raises:
        NetworkValidationError: If a layer name is not a string or not a
            valid layer name, an ``@type`` is missing or not a string,
            ``@outgoing_connections`` has the wrong container type, a
            connection targets an undefined layer, there is not exactly one
            layer of type ``'Input'`` and named ``"Input"``, or the Input
            layer has incoming connections.
    """
    # schema
    for name, layer in architecture.items():
        if not isinstance(name, string_types):
            raise NetworkValidationError('Non-string name {}'.format(name))
        if '@type' not in layer:
            raise NetworkValidationError(
                'Missing @type for "{}"'.format(name))
        if not isinstance(layer['@type'], string_types):
            raise NetworkValidationError('Invalid @type for "{}": {}'.format(
                name, type(layer['@type'])))

        if '@outgoing_connections' in layer and not isinstance(
                layer['@outgoing_connections'], (list, tuple, dict)):
            raise NetworkValidationError(
                'Invalid @outgoing_connections for "{}"'.format(name))

    # layer naming (checked after the schema pass for all layers, so schema
    # errors are reported first)
    for name in architecture:
        if not is_valid_layer_name(name):
            raise NetworkValidationError(
                "Invalid layer name: '{}'".format(name))

    # all outgoing connections are present
    connections = collect_all_connections(architecture)
    end_layers = {c.end_layer for c in connections}
    undefined_end_layers = end_layers.difference(architecture)
    if undefined_end_layers:
        raise NetworkValidationError(
            'Could not find end layer(s) "{}"'.format(undefined_end_layers))

    # has exactly one Input and it's called Input. Every layer of type
    # 'Input' is collected (not just the one named "Input") so that extra
    # Input-typed layers under other names are rejected too; the schema pass
    # above guarantees every layer has an '@type' here.
    input_layers = [layer_name for layer_name, layer in architecture.items()
                    if layer['@type'] == 'Input']
    if input_layers != ['Input']:
        raise NetworkValidationError(
            'Needs exactly one Input that is called "Input"')

    # no connections to Input
    if 'Input' in end_layers:
        raise NetworkValidationError(
            'Input can not have incoming connections!')

    # TODO: check if connected
    # TODO: check for cycles
    return True