Beispiel #1
0
def get_default_value(types: TypeCollection, port_data):
    """Return the default value for a particular port.

    If the port does not specify a default value (the ``default_value`` key
    is missing or is None), the default value of the port's data type is
    rendered with ``types.render_value`` and returned instead. An explicitly
    specified value is returned as-is, without rendering.

    :param types: type collection providing ``default_value`` and
        ``render_value`` for the port's data type
    :param port_data: port definition; ``data_type`` is required,
        ``default_value`` is optional
    """
    specified_default = port_data.get('default_value')
    if specified_default is not None:
        return specified_default

    data_type = port_data['data_type']
    return types.render_value(data_type, types.default_value(data_type))
Beispiel #2
0
    def __init__(self, project_config_file):
        """Create an empty model bound to ``project_config_file``.

        Nothing is read from disk here; every registry starts out empty.
        """
        # project configuration and loaded components
        self._project_config_file = project_config_file
        self._project_config = {}
        self._components = {}

        # registries filled in by plugins / callers
        self._plugins = {}
        self._port_types = {}
        self._port_impl_lookup = {}
        self._signal_types = {}
        self._functions = {}

        # data type registry
        self._types = TypeCollection()
        self._defined_types = {}

        # port registry
        self._ports = {}

        # warning categories that are printed
        self._print_warnings = ['unconnected_signals']
Beispiel #3
0
def render_alias_typedef(type_collection: TypeCollection, type_name):
    """Return the C ``typedef`` declaration for the alias type ``type_name``.

    The aliased type is obtained from ``type_collection.resolve(type_name)``.

    Plain string formatting is used rather than a Mustache template: the
    output is C source, and a Mustache ``{{ }}`` tag HTML-escapes its value
    (``&``, ``<``, ``>`` and ``"``), which would corrupt the declaration.
    """
    aliased = type_collection.resolve(type_name)
    return 'typedef {} {};'.format(aliased, type_name)
Beispiel #4
0
def impl_data_lookup(types: TypeCollection, port_data):
    """Build the default implementation data for a port.

    The port's ``port_type`` is looked up in the module-level
    ``port_type_data`` table; None is returned when it is not there.
    Otherwise the ``default_impl`` entry provides builder callables: the
    ``common`` builder and the one selected by
    ``types.passed_by(data_type)``. Both are called with
    ``(types, port_data)`` and their results are merged, with keys from the
    pass-by builder overriding those from ``common``.
    """
    port_type = port_data['port_type']
    if port_type in port_type_data:
        builders = port_type_data[port_type]['default_impl']
        passing_method = types.passed_by(port_data['data_type'])
        # noinspection PyCallingNonCallable
        return {
            **builders['common'](types, port_data),
            **builders[passing_method](types, port_data)
        }

    return None
Beispiel #5
0
def lookup_member(types: TypeCollection, data_type, member_list):
    """Return the type reached by following ``member_list`` from ``data_type``.

    Each element of ``member_list`` names a field (for a struct) or a member
    (for a union) of the type reached so far. An empty ``member_list``
    returns ``data_type`` unchanged. A name that is not a field/member
    propagates the error raised by the ``fields``/``members`` lookup.

    :raises Exception: if a member is accessed on a type that is neither a
        struct nor a union
    """
    if not member_list:
        return data_type

    current_type = data_type
    for member_name in member_list:
        type_data = types.get(current_type)
        kind = type_data['type']
        if kind == TypeCollection.STRUCT:
            current_type = type_data['fields'][member_name]
        elif kind == TypeCollection.UNION:
            current_type = type_data['members'][member_name]
        else:
            raise Exception(
                'Trying to access member of non-struct type {}'.format(
                    current_type))

    return current_type
Beispiel #6
0
def struct_formatter(types: TypeCollection, type_name, type_data, struct_value,
                     context):
    """Format a struct value as a C initializer.

    A string ``struct_value`` is returned verbatim. Otherwise every field
    listed in ``type_data['fields']`` becomes ``.field = value``. The value
    is rendered by ``types.render_value`` in ``'initialization'`` context.
    Fields absent from ``struct_value`` are passed to ``render_value`` as
    None.

    In ``'initialization'`` context the result is a brace initializer
    ``{ ... }``. In any other context it is a compound literal
    ``(type_name) { ... }``.
    """
    if type(struct_value) is str:
        return struct_value

    assignments = []
    for name, field_type in type_data['fields'].items():
        rendered = types.render_value(field_type,
                                      struct_value.get(name, None),
                                      'initialization')
        assignments.append('.{} = {}'.format(name, rendered))
    body = ', '.join(assignments)

    if context != 'initialization':
        return '({}) {{ {} }}'.format(type_name, body)
    return '{{ {} }}'.format(body)
Beispiel #7
0
def union_formatter(types: TypeCollection, type_name, type_data, union_value,
                    context):
    """Format a union value as a C initializer.

    A string ``union_value`` is returned verbatim. Otherwise
    ``union_value`` must hold exactly one ``member: value`` pair. The pair
    becomes ``.member = value``, with the value rendered by
    ``types.render_value`` in ``'initialization'`` context.

    In ``'initialization'`` context the result is a brace initializer
    ``{ ... }``. In any other context it is a compound literal
    ``(type_name) { ... }``.

    :raises Exception: if ``union_value`` does not contain exactly one member
    """
    if type(union_value) is str:
        return union_value

    if len(union_value) != 1:
        raise Exception('Only a single union member can be assigned')

    [(member_name, member_value)] = union_value.items()
    initializer = '.{} = {}'.format(
        member_name,
        types.render_value(type_data['members'][member_name], member_value,
                           'initialization'))

    if context != 'initialization':
        return '({}) {{ {} }}'.format(type_name, initializer)
    return '{{ {} }}'.format(initializer)
Beispiel #8
0
class Runtime:
    """Project model and C code generator for a component-based runtime.

    Holds the project configuration, the configuration of every loaded
    component, the registered port types and signal types, and a
    ``TypeCollection`` with the known data types. It generates component
    source/header files (:meth:`update_component`) and the runtime
    source/header files (:meth:`generate_runtime`). Plugins registered with
    :meth:`add_plugin` are notified of each stage through :meth:`raise_event`.
    """

    def __init__(self, project_config_file):
        self._project_config_file = project_config_file
        # plugin name -> plugin
        self._plugins = {}
        self._defined_types = {}
        self._project_config = {}
        # component name -> component configuration
        self._components = {}
        self._types = TypeCollection()
        # port type name -> port type description
        # ('def_attributes', 'provides', 'consumes', ...)
        self._port_types = {}
        # port type name -> callable(types, port_data) returning function data
        self._port_impl_lookup = {}
        # signal type name -> SignalType
        self._signal_types = {}
        self._functions = {}

        # '<component>/<port>' -> processed port definition
        self._ports = {}

        # warning categories that are printed during generation
        self._print_warnings = ['unconnected_signals']

    def add_plugin(self, plugin: RuntimePlugin):
        """Register ``plugin`` under ``plugin.name`` and bind it to this runtime.

        A plugin registered under an already used name replaces the old one.
        """
        self._plugins[plugin.name] = plugin
        plugin.bind(self)

    def load(self, load_components=True):
        """Load the project configuration file.

        Raises the ``init``, ``load_project_config`` and
        ``project_config_loaded`` events. A missing ``settings`` section is
        replaced with defaults.

        :param load_components: when True, also load the configuration of
            every component listed under ``components``
        :raises Exception: if a plugin listed in ``required_plugins`` has not
            been added
        """
        self.raise_event('init')

        # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
        with open(self._project_config_file, "r", encoding="utf-8") as file:
            project_config = json.load(file)

        self.raise_event('load_project_config', project_config)

        if 'settings' not in project_config:
            project_config['settings'] = {
                'name': 'Project Name',
                'components_folder': 'components',
                'required_plugins': []
            }

        print('Loaded configuration for {}'.format(
            project_config['settings']['name']))

        self._project_config = project_config

        # hand-written settings may omit the plugin list
        for plugin_name in self.settings.get('required_plugins', []):
            if plugin_name not in self._plugins:
                raise Exception(
                    'Project requires {} plugin, which is not loaded'.format(
                        plugin_name))

        if load_components:
            for component_name in project_config['components']:
                self.load_component_config(component_name)

        self.raise_event('project_config_loaded', project_config)

    def add_port_type(self, port_type_name, data, lookup):
        """Register a port type.

        :param data: port type description
        :param lookup: callable ``(types, port_data)`` returning the default
            function data for ports of this type
        """
        self._port_types[port_type_name] = data
        self._port_impl_lookup[port_type_name] = lookup

    def process_port_def(self, component_name, port_name, port):
        """Normalize a port definition according to its port type.

        For a registered port type with ``def_attributes``, the result holds
        ``port_type``, then the type's ``static`` attributes, then the port's
        own attributes as returned by the ``copy`` helper (called with the
        ``required`` and ``optional`` attribute lists). If any lookup raises
        KeyError (for example, the port type is not registered), the original
        definition is returned unchanged.

        :raises Exception: wrapping any other error raised while building the
            definition
        """
        port_type = port['port_type']

        # 'port_type' is re-added explicitly below, so it is kept out of what
        # copy() sees. A filtered copy is used instead of deleting the key, so
        # the fallback `return port` never hands back a definition that lost
        # its port type.
        port_attributes = {
            key: value
            for key, value in port.items() if key != 'port_type'
        }

        try:
            attributes = self._port_types[port_type]['def_attributes']
            return {
                'port_type': port_type,
                **attributes['static'],
                **copy(port_attributes, attributes['required'],
                       attributes['optional'])
            }

        except KeyError:
            return port

        except Exception as e:
            raise Exception(
                'Port {}/{} ({}) has unexpected attribute set: {}'.format(
                    component_name, port_name, port_type, e)) from e

    def load_component_config(self, component_name):
        """Read ``<components_folder>/<component_name>/config.json`` and add it.

        The project configuration is loaded first (without components) if it
        has not been loaded yet.
        """
        if not self._project_config:
            self.load(False)

        component_config_file = os.path.join(
            self.settings['components_folder'], component_name, 'config.json')
        with open(component_config_file, "r", encoding="utf-8") as file:
            component_config = json.load(file)
            self.add_component(component_name, component_config)

    def add_component(self, component_name, component_config):
        """Register a component configuration and process its ports.

        Raises ``load_component_config`` first. Each port definition is
        replaced in ``component_config['ports']`` by its processed form
        (see :meth:`process_port_def`) and indexed as
        ``'<component>/<port>'``.
        """
        self.raise_event('load_component_config', component_name,
                         component_config)
        self._components[component_name] = component_config

        if not component_config['ports']:
            print('Warning: {} has no ports'.format(component_name))

        for port_name, port_data in component_config['ports'].items():
            processed_port = self.process_port_def(component_name, port_name,
                                                   port_data)
            component_config['ports'][port_name] = processed_port

            short_name = '{}/{}'.format(component_name, port_name)
            self._ports[short_name] = processed_port

    def _normalize_type_name(self, type_name):
        """Return ``type_name`` or, if it is unknown, a stripped form of it.

        Names not found in the type collection have every ``'const '``,
        ``'*'`` and space removed (e.g. ``'const foo_t *'`` -> ``'foo_t'``).
        """
        try:
            self._types.get(type_name)
        except KeyError:
            type_name = type_name.replace('const ', '').replace('*', '') \
                .replace(' ', '')

        return type_name

    def _get_type_includes(self, type_name):
        """Return the header(s) needed for a type or a list of types.

        For a single name, returns the ``defined_in`` value of an external
        definition, or None for any other kind of type. For a list of names,
        returns the de-duplicated headers of all of them in first-seen order.
        """
        if type(type_name) is list:
            includes = []
            for tn in type_name:
                inc = self._get_type_includes(tn)
                if type(inc) is list:
                    for i in inc:
                        if i not in includes:
                            includes.append(i)
                elif inc:
                    if inc not in includes:
                        includes.append(inc)
            return includes

        type_name = self._normalize_type_name(type_name)

        resolved_type_data = self._types[type_name]
        if resolved_type_data['type'] == TypeCollection.EXTERNAL_DEF:
            return resolved_type_data['defined_in']
        return None

    def _collect_type_dependencies(self, type_name):
        """Return the names of the types ``type_name`` directly depends on.

        An alias depends on the aliased type. A struct or union depends on
        the (normalized) types of its fields or members. Any other kind of
        type has no dependencies. Transitive dependencies are resolved by
        :meth:`_sort_types_by_dependency`.
        """
        type_name = self._normalize_type_name(type_name)
        type_data = self._types.get(type_name)
        kind = type_data['type']

        if kind == TypeCollection.ALIAS:
            return [type_data['aliases']]

        if kind == TypeCollection.STRUCT:
            member_types = type_data['fields'].values()
        elif kind == TypeCollection.UNION:
            member_types = type_data['members'].values()
        else:
            return []

        # Each field/member type must be defined before the struct/union that
        # uses it, so the field/member types themselves are the dependencies.
        return [self._normalize_type_name(t) for t in member_types]

    def _sort_types_by_dependency(self, type_names, visited_types=None):
        """Order type names so that every type follows its dependencies.

        :param type_names: a single type name, or any iterable of type names
            (list, tuple, dict keys view, ...)
        :param visited_types: names already handled; shared across the
            recursion so each name appears once and cycles terminate
        :return: list of type names, dependencies first
        """
        if visited_types is None:
            visited_types = []

        # A bare string is one type name, not an iterable of characters.
        if isinstance(type_names, str):
            type_names = [type_names]

        ordered = []

        for t in type_names:
            if t in visited_types:
                continue

            visited_types.append(t)
            for d in self._collect_type_dependencies(t):
                ordered += self._sort_types_by_dependency(d, visited_types)

            ordered.append(t)

        return ordered

    def update_component(self, component_name):
        """Generate the files of a single component.

        Plugins fill in the generation context during the
        ``create_component_ports`` and ``before_generating_component``
        events. The collected functions, declarations and the component's own
        types are then rendered into ``<component>.c`` and ``<component>.h``,
        and the component configuration is serialized into ``config.json``.
        ``generating_component`` is raised before returning.

        :return: dict mapping each output file path to its content
        """
        component_folder = os.path.join(self.settings['components_folder'],
                                        component_name)
        source_file = os.path.join(component_folder, component_name + '.c')
        header_file = os.path.join(component_folder, component_name + '.h')
        config_file = os.path.join(component_folder, 'config.json')

        context = {
            'runtime': self,
            'component_folder': component_folder,
            'functions': {},
            'declarations': [],
            'files': {
                config_file: '',
                source_file: '',
                header_file: ''
            },
            'folders': [component_name]
        }
        self.raise_event('create_component_ports', component_name,
                         self._components[component_name], context)

        self.raise_event('before_generating_component', component_name,
                         context)

        funcs = context['functions'].values()
        function_headers = [fn.get_header() for fn in funcs]
        functions = [fn.get_function() for fn in funcs]

        includes = {'"{}.h"'.format(component_name), '"utils.h"'}
        for f in funcs:
            includes.update(f.includes)

        defined_type_names = list(
            self._components[component_name].get('types', {}).keys())

        # typedefs have to be emitted after the types they use
        sorted_types = self._sort_types_by_dependency(defined_type_names)

        type_includes = self._get_type_includes(sorted_types)
        typedefs = [self._types.generate_typedef(t) for t in sorted_types]

        ctx = {
            'includes': list_to_chevron_list(sorted(includes), 'header'),
            'component_name': component_name,
            'guard_def': to_underscore(component_name).upper(),
            'variables': context['declarations'],
            'types': typedefs,
            'type_includes': list_to_chevron_list(sorted(type_includes),
                                                  'header'),
            'functions': functions,
            'function_headers': function_headers
        }

        context['files'][config_file] = self.dump_component_config(
            component_name)
        context['files'][source_file] = chevron.render(source_template, ctx)
        context['files'][header_file] = chevron.render(
            component_header_template, ctx)

        self.raise_event('generating_component', component_name, context)

        return context['files']

    def add_signal_type(self, name, signal_type: SignalType):
        """Register ``signal_type`` under ``name``."""
        self._signal_types[name] = signal_type

    def create_function_for_port(self,
                                 component_name,
                                 port_name,
                                 function_data=None):
        """Create the FunctionDescriptor that implements a port.

        :param function_data: implementation data for the port; when not
            given (or empty) it is obtained from the port type's
            implementation lookup
        """
        if not function_data:
            short_name = "{}/{}".format(component_name, port_name)
            function_data = self._get_function_data(short_name)

        fn_name = function_data['func_name_pattern'].format(
            component_name, port_name)
        function = FunctionDescriptor.create(self._types, fn_name,
                                             function_data)
        function.add_input_assert(function_data.get('asserts', []))

        return function

    def get_port(self, short_name):
        """Return the processed port definition for ``'<component>/<port>'``."""
        return self._ports[short_name]

    def _get_function_data(self, short_name):
        """Return the default function data for a port via its type's lookup."""
        port_data = self.get_port(short_name)
        port_type = port_data['port_type']

        return self._port_impl_lookup[port_type](self._types, port_data)

    def get_port_type_data(self, short_name):
        """Return the port type description of the given port."""
        port_data = self.get_port(short_name)
        port_type = port_data['port_type']

        return self._port_types[port_type]

    def generate_runtime(self, filename, context=None):
        """Generate the runtime source and header files.

        Connects the ports listed under ``runtime/port_connections`` in the
        project configuration, generates the resulting signals and renders
        ``<filename>.c`` and ``<filename>.h``. Raises the
        ``before_generating_runtime`` and ``after_generating_runtime`` events.

        :param filename: output path without extension
        :param context: optional entries overriding the default generation
            context (shallow merge)
        :return: dict mapping output file names to their generated content
        :raises Exception: if a provider/consumer pair has no, or more than
            one, common signal type; if a port that consumes a signal type
            only once is connected again; or if a signal type that does not
            allow multiple consumers gets a second one
        """
        source_file_name = filename + '.c'
        header_file_name = filename + '.h'

        default_context = {
            'runtime': self,
            'files': {
                source_file_name: '',
                header_file_name: ''
            },
            'functions': {},
            'declarations': [],
            'exported_function_declarations': [],
            'signals': {}
        }

        if context is None:
            context = default_context
        else:
            context = {**default_context, **context}

        def create_if_weak(port_ref):
            """Generate the function for ``port_ref``, but only if its default
            implementation is weak; otherwise return None."""
            function_data = self._get_function_data(port_ref['short_name'])
            if 'weak' in function_data.get('attributes', []):
                return self.create_function_for_port(
                    port_ref['component'], port_ref['port'], function_data)
            return None

        def create_signal_connection(provider_short_name, consumer_short_name,
                                     attributes, signal_name, signal_type,
                                     consumer_attributes):
            """Create a new ``signal_type`` signal with its first consumer."""
            signal = signal_type.create_connection(context, signal_name,
                                                   provider_short_name,
                                                   attributes)
            signal.add_consumer(consumer_short_name, consumer_attributes)
            return signal

        for connection in self._project_config['runtime']['port_connections']:
            provider_ref = connection['provider']
            provider_short_name = provider_ref['short_name']

            provider_port_type_data = self.get_port_type_data(
                provider_short_name)
            provided_signal_types = provider_port_type_data['provides']

            if provider_short_name not in context['functions']:
                context['functions'][provider_short_name] = create_if_weak(
                    provider_ref)

            # every connection key except the port references configures
            # the signal
            provider_attributes = {
                key: connection[key]
                for key in connection
                if key not in ['provider', 'consumer', 'consumers']
            }

            # signals created for this provider, keyed by signal type name
            provider_signals = context['signals'].setdefault(
                provider_short_name, {})

            for consumer_ref in connection['consumers']:
                consumer_short_name = consumer_ref['short_name']

                consumer_port_type_data = self.get_port_type_data(
                    consumer_short_name)
                consumed_signal_types = consumer_port_type_data['consumes']
                inferred_signal_type = provided_signal_types.intersection(
                    consumed_signal_types)

                if len(inferred_signal_type) == 0:
                    raise Exception('Incompatible ports: {} and {}'.format(
                        provider_short_name, consumer_short_name))
                elif len(inferred_signal_type) > 1:
                    raise Exception(
                        'Connection type can not be inferred for {} and {}'.
                        format(provider_short_name, consumer_short_name))

                signal_type_name = inferred_signal_type.pop()
                signal_type = self._signal_types[signal_type_name]

                if consumer_short_name not in context['functions']:
                    context['functions'][consumer_short_name] = create_if_weak(
                        consumer_ref)
                elif consumed_signal_types[signal_type_name] == 'single':
                    # this port already consumes some signal; only ports whose
                    # type allows it (e.g. a runnable called by several
                    # events or calls) may consume more than one
                    raise Exception(
                        '{} cannot consume multiple signals'.format(
                            consumer_short_name))

                signal_name = '{}_{}' \
                    .format(provider_short_name, signal_type_name) \
                    .replace('/', '_')

                consumer_attributes = consumer_ref.get('attributes', {})

                # Membership is tested explicitly so that a KeyError raised
                # while creating or extending a signal propagates instead of
                # being mistaken for "no signal of this type yet".
                if signal_type_name not in provider_signals:
                    new_signal = create_signal_connection(
                        provider_short_name, consumer_short_name,
                        provider_attributes, signal_name, signal_type,
                        consumer_attributes)

                    if signal_type.consumers == 'multiple_signals':
                        provider_signals[signal_type_name] = [new_signal]
                    else:
                        provider_signals[signal_type_name] = new_signal
                    continue

                signals_of_current_type = provider_signals[signal_type_name]
                if type(signals_of_current_type) is list:
                    if signal_type.consumers == 'multiple_signals':
                        # create a new, numbered signal in all cases
                        signal_name += str(len(signals_of_current_type))

                        new_signal = create_signal_connection(
                            provider_short_name, consumer_short_name,
                            provider_attributes, signal_name, signal_type,
                            consumer_attributes)

                        signals_of_current_type.append(new_signal)
                    else:
                        signals_of_current_type.add_consumer(
                            consumer_short_name, consumer_attributes)
                elif signal_type.consumers == 'multiple':
                    signals_of_current_type.add_consumer(
                        consumer_short_name, consumer_attributes)
                else:
                    raise Exception(
                        'Multiple consumers not allowed for {} signal (provided by {})'
                        .format(signal_type_name, provider_short_name))

        if 'unconnected_signals' in self._print_warnings:
            all_unconnected = set(self._ports.keys()).difference(
                context['functions'])
            # sorted for reproducible output
            for unconnected in sorted(all_unconnected):
                print('Warning: {} port is not connected'.format(unconnected))

        for signals in context['signals'].values():
            for signal in signals.values():
                if type(signal) is list:
                    for s in signal:
                        s.generate()
                else:
                    signal.generate()

        self.raise_event('before_generating_runtime', context)

        type_names = []
        for c in self._components.values():
            type_names += c.get('types', {}).keys()

        output_filename = os.path.basename(filename)
        includes = {'"{}.h"'.format(output_filename), '"utils.h"'}

        # ports without a generated function are stored as None
        for f in context['functions'].values():
            if f:
                type_names += f.referenced_types()
                includes.update(f.includes)

        sorted_types = self._sort_types_by_dependency(type_names)

        type_includes = self._get_type_includes(sorted_types)
        typedefs = [self._types.generate_typedef(t) for t in sorted_types]

        template_data = {
            'output_filename':
            output_filename,
            'includes':
            list_to_chevron_list(sorted(includes), 'header'),
            'components': [{
                'name': name,
                'guard_def': to_underscore(name).upper()
            } for name in self._components if name != 'Runtime'],  # TODO
            'types':
            typedefs,
            'type_includes':
            list_to_chevron_list(sorted(type_includes), 'header'),
            'function_declarations': [
                context['functions'][func_name].get_header()
                for func_name in context['exported_function_declarations']
            ],
            'functions': [
                func.get_function() for func in context['functions'].values()
                if func
            ],
            'variables':
            context['declarations']
        }

        context['files'][source_file_name] = chevron.render(
            source_template, template_data)
        context['files'][header_file_name] = chevron.render(
            runtime_header_template, template_data)

        self.raise_event('after_generating_runtime', context)

        return context['files']

    def raise_event(self, event_name, *args):
        """Call ``handle(event_name, args)`` on every registered plugin.

        The positional arguments are passed to each plugin as one tuple.
        If a plugin raises, the plugin and event names are printed and the
        exception is re-raised.
        """
        for plugin_name, plugin in self._plugins.items():
            try:
                plugin.handle(event_name, args)
            except Exception:
                print('Error while processing {}/{}'.format(
                    plugin_name, event_name))
                raise

    @property
    def functions(self):
        return self._functions

    @property
    def types(self):
        """The ``TypeCollection`` holding the known data types."""
        return self._types

    @property
    def port_types(self):
        """Registered port types, keyed by port type name."""
        return self._port_types

    @property
    def settings(self):
        """The ``settings`` section of the project configuration.

        Raises KeyError before the project configuration has been loaded.
        """
        return self._project_config['settings']

    def dump_component_config(self, component_name):
        """Serialize a component configuration to indented JSON.

        Plugins receive a shallow copy in the ``save_component_config``
        event; nested values are shared with the live configuration.
        """
        config = self._components[component_name].copy()
        self.raise_event('save_component_config', config)
        return json.dumps(config, indent=4)

    def dump_project_config(self):
        """Serialize the project configuration to indented JSON.

        Plugins receive a shallow copy in the ``save_project_config`` event;
        nested values are shared with the live configuration.
        """
        config = self._project_config.copy()
        self.raise_event('save_project_config', config)
        return json.dumps(config, indent=4)
Beispiel #9
0
def get_default_value_formatted(types: TypeCollection, port_data):
    """Render the default value of the port's data type.

    The value always comes from ``types.default_value`` for
    ``port_data['data_type']`` and is passed through ``types.render_value``.
    Any ``default_value`` entry in ``port_data`` is not consulted.
    """
    data_type = port_data['data_type']
    return types.render_value(data_type, types.default_value(data_type))