コード例 #1
0
def setup_logging(parameters=None,
                  coloredlogs=False,
                  logging_file=None,
                  default_setup_file='logging.yaml',
                  default_level=logging.INFO,
                  environmental_variable='LOG_CFG'):
    """Setup logging configuration

    The configuration source is chosen in this order:

    1. ``parameters`` given: ``parameters['parameters']`` is passed to
       ``logging.config.dictConfig``. coloredlogs is installed with the
       ``console`` handler's level and format only if all of these hold:
       ``parameters['colored']`` is true, the configuration defines a
       ``console`` handler, and coloredlogs is importable.
    2. Otherwise, a YAML setup file is used if it exists. Its path is the
       value of ``environmental_variable`` when that variable is set, else
       ``default_setup_file``. The file is loaded with ``yaml.safe_load`` and
       passed to ``logging.config.dictConfig``. If it defines a ``console``
       handler and coloredlogs is importable, coloredlogs is installed with
       that handler's level and format.
    3. Otherwise, a built-in setup is applied to the root logger. It uses
       coloredlogs when ``coloredlogs`` is True and the package is importable.
       Otherwise it adds one stream handler per level (DEBUG, INFO, WARNING,
       ERROR, CRITICAL). In both variants a rotating file handler is added
       when ``logging_file`` is given.

    In every case ``sys.excepthook`` is replaced so that uncaught exceptions
    (other than KeyboardInterrupt) are logged through the root logger.

    Note: case 3 adds handlers to the root logger on every call, so calling
    this function repeatedly accumulates handlers.

    Parameters
    ----------
    parameters : dict
        Logging parameters. The ``'parameters'`` key holds a
        ``logging.config.dictConfig`` dictionary. The optional ``'colored'``
        key enables coloredlogs for its ``console`` handler.
        Default value None

    coloredlogs : bool
        Use coloredlogs in the built-in setup (case 3 above).
        Default value False

    logging_file : str
        Log filename for file based logging in the built-in setup. If none is
        given, no file logging is used.
        Default value None

    default_setup_file : str
        Logging setup file, used when ``environmental_variable`` is not set.
        Default value 'logging.yaml'

    default_level : logging.level
        Root logger level for the built-in per-level stream handler setup.
        Default value logging.INFO

    environmental_variable : str
        Name of the environment variable holding the logging setup filename.
        If set, it overrides ``default_setup_file``.
        Default value 'LOG_CFG'

    Returns
    -------
    nothing

    """

    class LoggerFilter(object):
        """Pass only records whose level is at most ``level``.

        Combined with a handler whose own level is ``level``, the handler
        emits records of exactly that level.
        """

        def __init__(self, level):
            self.__level = level

        def filter(self, log_record):
            return log_record.levelno <= self.__level

    formatters = {
        'simple': "[%(levelname).1s] %(message)s",
        'normal': "%(asctime)s\t[%(name)-20s]\t[%(levelname)-8s]\t%(message)s",
        'extended': "[%(asctime)s] [%(name)s]\t [%(levelname)-8s]\t %(message)s \t(%(filename)s:%(lineno)s)",
        'extended2': "[%(levelname).1s] %(message)s \t(%(filename)s:%(lineno)s)",
        'file_extended': "[%(levelname).1s] [%(asctime)s] %(message)s",
    }

    def import_coloredlogs():
        """Return the coloredlogs module, or None when it is not installed."""
        try:
            import coloredlogs as coloredlogs_module
        except ImportError:
            return None
        return coloredlogs_module

    def add_file_handler(logger):
        """Attach a rotating file handler to *logger* when logging_file is set."""
        if logging_file:
            file_info = logging.handlers.RotatingFileHandler(
                filename=logging_file,
                maxBytes=10485760,  # 10 MiB per file
                backupCount=20,
                encoding='utf8'
            )
            file_info.setLevel(logging.INFO)
            file_info.setFormatter(logging.Formatter(formatters['file_extended']))
            logger.addHandler(file_info)

    def add_plain_handlers():
        """Configure the root logger with one stream handler per level, plus the optional file handler."""
        logger = logging.getLogger()
        logger.setLevel(default_level)

        # Each handler emits only its own level (see LoggerFilter).
        # Errors and criticals also show the source location.
        for level, formatter_name in ((logging.INFO, 'simple'),
                                      (logging.DEBUG, 'simple'),
                                      (logging.WARNING, 'simple'),
                                      (logging.CRITICAL, 'extended2'),
                                      (logging.ERROR, 'extended2')):
            handler = logging.StreamHandler()
            handler.setLevel(level)
            handler.setFormatter(logging.Formatter(formatters[formatter_name]))
            handler.addFilter(LoggerFilter(level))
            logger.addHandler(handler)

        add_file_handler(logger)

    if not parameters:
        # An environment variable, when set, overrides the default setup file.
        logging_parameter_file = os.getenv(environmental_variable, None) or default_setup_file

        if os.path.exists(logging_parameter_file):
            with open(logging_parameter_file, 'rt') as f:
                config = yaml.safe_load(f.read())
            logging.config.dictConfig(config)

            # Colour the console output only if the config defines a console handler.
            console = (config.get('handlers') or {}).get('console')
            coloredlogs_module = import_coloredlogs() if console is not None else None
            if coloredlogs_module is not None:
                coloredlogs_module.install(
                    level=console['level'],
                    fmt=config['formatters'][console['formatter']]['format']
                )

        else:
            coloredlogs_module = import_coloredlogs() if coloredlogs else None
            if coloredlogs_module is not None:
                coloredlogs_module.install(
                    level=logging.INFO,
                    fmt=formatters['simple'],
                    reconfigure=True
                )
                # File logging must not be dropped just because coloredlogs is used.
                add_file_handler(logging.getLogger())

            else:
                # coloredlogs not requested or not installed: plain handlers.
                add_plain_handlers()

    else:
        from dcase_util.containers import DictContainer
        parameters = DictContainer(parameters)
        logging.config.dictConfig(parameters.get('parameters'))
        handlers = parameters.get_path('parameters.handlers') or {}
        if parameters.get('colored', False) and 'console' in handlers:
            coloredlogs_module = import_coloredlogs()
            if coloredlogs_module is not None:
                coloredlogs_module.install(
                    level=parameters.get_path('parameters.handlers.console.level'),
                    fmt=parameters.get_path('parameters.formatters')[
                        parameters.get_path('parameters.handlers.console.formatter')
                    ].get('format')
                )

    # Function to handle uncaught exceptions
    def handle_exception(exc_type, exc_value, exc_traceback):
        if issubclass(exc_type, KeyboardInterrupt):
            # Let Ctrl-C keep its default behaviour.
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return

        # Look up the root logger here. The setup branches above do not all
        # bind a local logger, and a missing binding made this hook raise NameError.
        logging.getLogger().error('Uncaught exception', exc_info=(exc_type, exc_value, exc_traceback))

    sys.excepthook = handle_exception
コード例 #2
0
    def system_meta(self,
                    results,
                    task=None,
                    check_development_dataset=True,
                    check_evaluation_dataset=True):
        """Check system result scores given in the meta data

        Every problem found is appended to ``self.error_log`` as a message.
        Nothing is raised for missing entries.

        Checked paths per task, relative to ``development_dataset`` /
        ``evaluation_dataset``:

        - ``ASC``: ``overall.accuracy`` (development) or ``overall``
          (evaluation), and ``class_wise.<label>.accuracy``.
        - ``SED_event``: ``event_based.overall.{er,f1}`` and
          ``event_based.class_wise.<label>.{er,f1}``.
        - ``SED_segment``: same as ``SED_event`` under ``segment_based``.
        - ``task4`` and any other task: no score checks.

        The number of class-wise entries is compared against
        ``len(self.class_labels)``.

        Parameters
        ----------
        results : dict
            Result meta data

        task : str, optional
            Temporary override for the task parameter given to class constructor.

        check_development_dataset : bool
            Check development dataset results

        check_evaluation_dataset : bool
            Check evaluation dataset results

        Returns
        -------
        self

        """

        if task is None:
            task = self.task

        if results is None:
            self.error_log.append(u'No results section')
            return self

        results = DictContainer(results)

        # Sub-tree holding the scores of each sound event detection task.
        sed_metric_types = {
            'SED_event': 'event_based',
            'SED_segment': 'segment_based',
        }

        def metric_missing(class_data, metric):
            """True when class_data is not a dict or has no non-None value for metric."""
            return not isinstance(class_data, dict) or class_data.get(metric) is None

        def check_sed_overall(dataset, metric_type, message):
            """Report missing overall er / f1 scores under dataset.metric_type.overall."""
            for metric in ('er', 'f1'):
                sub_path = metric_type + '.overall.' + metric
                if results.get_path(dataset + '.' + sub_path) is None:
                    self.error_log.append(
                        u'{message:s} [{sub_path:s}]'.format(
                            message=message, sub_path=sub_path))

        def check_class_wise(path, set_name, metrics, show_metric):
            """Report a missing class-wise section, a wrong class count, or missing per-class metrics."""
            class_wise = results.get_path(path)
            if class_wise is None:
                self.error_log.append(
                    u'No class_wise {set_name:s} results given'.format(set_name=set_name))
                return

            if len(class_wise) != len(self.class_labels):
                self.error_log.append(
                    u'Incorrect number class-wise {set_name:s} results given [{class_wise:d}/{target:d}]'
                    .format(set_name=set_name,
                            class_wise=len(class_wise),
                            target=len(self.class_labels)))

            for class_label, class_data in iteritems(class_wise):
                for metric in metrics:
                    if metric_missing(class_data, metric):
                        if show_metric:
                            item = u'{0} / {1}'.format(class_label, metric)
                        else:
                            item = u'{0}'.format(class_label)
                        self.error_log.append(
                            u'Incorrect class-wise {set_name:s} results given for [{item:s}]'
                            .format(set_name=set_name, item=item))

        if check_development_dataset:
            # Check development dataset results
            if results.get('development_dataset') is None:
                self.error_log.append(u'No development results given')

            elif task == 'ASC':
                if results.get_path('development_dataset.overall.accuracy') is None:
                    self.error_log.append(u'No development overall result given ')
                check_class_wise('development_dataset.class_wise',
                                 u'development', ('accuracy',), False)

            elif task in sed_metric_types:
                metric_type = sed_metric_types[task]
                check_sed_overall('development_dataset', metric_type,
                                  u'No development overall result given')
                check_class_wise('development_dataset.' + metric_type + '.class_wise',
                                 u'development', ('er', 'f1'), True)

        if check_evaluation_dataset:
            # Check evaluation dataset results. Like the development section,
            # a section that is present but null counts as missing.
            if results.get('evaluation_dataset') is None:
                self.error_log.append(u'No evaluation results given')

            elif task == 'ASC':
                if results.get_path('evaluation_dataset.overall') is None:
                    self.error_log.append(u'No evaluation results given')
                check_class_wise('evaluation_dataset.class_wise',
                                 u'evaluation', ('accuracy',), False)

            elif task in sed_metric_types:
                metric_type = sed_metric_types[task]
                check_sed_overall('evaluation_dataset', metric_type,
                                  u'No evaluation results given')
                check_class_wise('evaluation_dataset.' + metric_type + '.class_wise',
                                 u'evaluation', ('er', 'f1'), True)

        return self
コード例 #3
0
    def submission_authors(self,
                           authors,
                           check_email=True,
                           check_affiliation=True,
                           check_affiliation_abbreviation=True,
                           check_affiliation_department=True):
        """Check submission authors

        Every problem found is appended to ``self.error_log`` as a message.
        Nothing is raised for missing fields.

        Parameters
        ----------
        authors : list of dict
            List of authors dicts. Each author may have ``lastname``,
            ``firstname``, ``email`` and ``affiliation``. ``affiliation`` is a
            dict, or a list of dicts, with ``abbreviation``, ``department``,
            ``institute`` and ``location`` fields.

        check_email : bool
            Check that author email exists.

        check_affiliation : bool
            Check author affiliation.

        check_affiliation_abbreviation : bool
            Check that affiliation abbreviation exists.

        check_affiliation_department : bool
            Check that affiliation has department defined.

        Returns
        -------
        self

        """

        if not isinstance(authors, list):
            self.error_log.append(
                u'Authors not given in list format for the submission')
            # Iterating anything else (None, a single author dict, ...) would
            # crash or walk dict keys, so stop after reporting the problem.
            return self

        # Affiliation fields to check, each with the flag that enables it.
        affiliation_checks = (
            ('abbreviation', check_affiliation_abbreviation),
            ('department', check_affiliation_department),
            ('institute', True),
            ('location', True),
        )

        for author in authors:
            author = DictContainer(author)

            last_name = author.get('lastname')
            first_name = author.get('firstname')
            # Missing names are shown as empty strings. Formatting the
            # message must not fail on the very field it reports as missing.
            who = {
                'last_name': last_name if last_name is not None else u'',
                'first_name': first_name if first_name is not None else u'',
            }

            if last_name is None:
                self.error_log.append(
                    u'No lastname given for author ({last_name}, {first_name})'
                    .format(**who))

            if first_name is None:
                self.error_log.append(
                    u'No firstname given for author ({last_name}, {first_name})'
                    .format(**who))

            if check_email and author.get('email') is None:
                self.error_log.append(
                    u'No email given for author ({last_name}, {first_name})'
                    .format(**who))

            if not check_affiliation:
                continue

            affiliations = author.get('affiliation')
            if affiliations is None:
                self.error_log.append(
                    u'No affiliation given for author ({last_name}, {first_name})'
                    .format(**who))
                continue

            # A single affiliation dict is handled like a one-element list.
            if not isinstance(affiliations, list):
                affiliations = [affiliations]

            for entry in affiliations:
                # Human-readable summary of the affiliation for the messages.
                affiliation = ', '.join(filter(None, list(entry.values())))
                for field, enabled in affiliation_checks:
                    if enabled and entry.get(field) is None:
                        self.error_log.append(
                            u'No {field} given ({last_name}, {first_name}, {affiliation})'
                            .format(field=field, affiliation=affiliation, **who))

        return self
コード例 #4
0
ファイル: model.py プロジェクト: stachu86/dcase_util
def model_summary_string(keras_model, mode='keras', show_parameters=True, display=False):
    """Model summary in a formatted string, similar to Keras model summary function.

    Parameters
    ----------
    keras_model : keras model
        Keras model

    mode : str
        Summary mode ['extended', 'extended_wide', 'keras']. In case 'keras', standard Keras summary is returned.
        'extended' lists layer name, type, output shape and parameter count; 'extended_wide' additionally
        lists activation and kernel initialization.
        Default value keras

    show_parameters : bool
        Show model parameter count and input / output shapes
        Default value True

    display : bool
        Display summary immediately, otherwise return string
        Default value False

    Returns
    -------
    str or None
        Model summary, or None when display is True (the summary is printed / rendered instead)

    """

    # Inside Jupyter the summary is rendered as HTML, otherwise as plain text.
    html_mode = is_jupyter()
    if html_mode:
        ui = FancyHTMLStringifier()
    else:
        ui = FancyStringifier()

    output = ''
    output += ui.line('Model summary') + '\n'

    if mode in ('extended', 'extended_wide'):
        layer_name_map = {
            'BatchNormalization': 'BatchNorm',
        }

        layer_type_html_tags = {
            'InputLayer': '<span class="label label-default">{0:s}</span>',
            'Dense': '<span class="label label-primary">{0:s}</span>',
            'TimeDistributed': '<span class="label label-primary">{0:s}</span>',

            'BatchNorm': '<span class="label label-default">{0:s}</span>',
            'Activation': '<span class="label label-default">{0:s}</span>',
            'Dropout': '<span class="label label-default">{0:s}</span>',

            'Flatten': '<span class="label label-success">{0:s}</span>',
            'Reshape': '<span class="label label-success">{0:s}</span>',
            'Permute': '<span class="label label-success">{0:s}</span>',

            'Conv1D': '<span class="label label-warning">{0:s}</span>',
            'Conv2D': '<span class="label label-warning">{0:s}</span>',

            'MaxPooling1D': '<span class="label label-success">{0:s}</span>',
            'MaxPooling2D': '<span class="label label-success">{0:s}</span>',
            'MaxPooling3D': '<span class="label label-success">{0:s}</span>',
            'AveragePooling1D': '<span class="label label-success">{0:s}</span>',
            'AveragePooling2D': '<span class="label label-success">{0:s}</span>',
            'AveragePooling3D': '<span class="label label-success">{0:s}</span>',
            'GlobalMaxPooling1D': '<span class="label label-success">{0:s}</span>',
            'GlobalMaxPooling2D': '<span class="label label-success">{0:s}</span>',
            'GlobalMaxPooling3D': '<span class="label label-success">{0:s}</span>',
            'GlobalAveragePooling1D': '<span class="label label-success">{0:s}</span>',
            'GlobalAveragePooling2D': '<span class="label label-success">{0:s}</span>',
            'GlobalAveragePooling3D': '<span class="label label-success">{0:s}</span>',

            'RNN': '<span class="label label-danger">{0:s}</span>',
            'SimpleRNN': '<span class="label label-danger">{0:s}</span>',
            'GRU': '<span class="label label-danger">{0:s}</span>',
            'CuDNNGRU': '<span class="label label-danger">{0:s}</span>',
            'LSTM': '<span class="label label-danger">{0:s}</span>',
            'CuDNNLSTM': '<span class="label label-danger">{0:s}</span>',
            'Bidirectional': '<span class="label label-danger">{0:s}</span>'
        }

        import tensorflow.keras.backend as keras_backend

        def count_unique_params(weights):
            # Count parameters of each weight once, deduplicating shared weights by object
            # identity. TF2 (eager) variables are unhashable, so a set() of them raises TypeError.
            unique_weights = {id(weight): weight for weight in weights}.values()
            return int(numpy.sum([keras_backend.count_params(weight) for weight in unique_weights]))

        table_data = {
            'layer_type': [],
            'output': [],
            'parameter_count': [],
            'name': [],
            'connected_to': [],
            'activation': [],
            'initialization': []
        }

        row_separators = []
        prev_name = None
        for layer_id, layer in enumerate(keras_model.layers):
            # Keras >= 2.1.3 keeps inbound nodes in the private `_inbound_nodes` attribute, older
            # versions in the public `inbound_nodes`. Detecting the attribute avoids a distutils
            # version comparison (distutils was removed in Python 3.12).
            inbound_nodes = getattr(layer, '_inbound_nodes', None)
            if inbound_nodes is None:
                inbound_nodes = layer.inbound_nodes

            connections = []
            for node in inbound_nodes:
                for inbound_layer, node_index, tensor_index in zip(node.inbound_layers,
                                                                    node.node_indices,
                                                                    node.tensor_indices):
                    connections.append(
                        inbound_layer.name + '[' + str(node_index) + '][' + str(tensor_index) + ']'
                    )

            config = DictContainer(layer.get_config())
            layer_name = layer_name_map.get(layer.__class__.__name__, layer.__class__.__name__)

            if html_mode and layer_name in layer_type_html_tags:
                layer_name = layer_type_html_tags[layer_name].format(layer_name)

            initializer_class = config.get_path('kernel_initializer.class_name')
            if initializer_class == 'VarianceScaling':
                init = str(config.get_path('kernel_initializer.config.distribution', '---'))

            elif initializer_class == 'RandomUniform':
                init = 'uniform'

            else:
                init = '-'

            # A new row group starts whenever the layer-name prefix (text before the first '_') changes.
            name_prefix = layer.name.split('_')[0]
            if prev_name != name_prefix:
                row_separators.append(layer_id)
                prev_name = name_prefix

            table_data['layer_type'].append(layer_name)
            table_data['output'].append(str(layer.output_shape))
            table_data['parameter_count'].append(str(layer.count_params()))
            table_data['name'].append(layer.name)
            # NOTE(review): 'connected_to' is collected but not rendered by either table mode below.
            table_data['connected_to'].append(str(connections[0]) if connections else '-')
            table_data['activation'].append(str(config.get('activation', '-')))
            table_data['initialization'].append(init)

        trainable_count = count_unique_params(keras_model.trainable_weights)
        non_trainable_count = count_unique_params(keras_model.non_trainable_weights)

        # Show row separators only if they are useful
        if len(row_separators) == len(keras_model.layers):
            row_separators = None

        if mode == 'extended':
            output += ui.table(
                cell_data=[table_data['name'], table_data['layer_type'], table_data['output'], table_data['parameter_count']],
                column_headers=['Layer name', 'Layer type', 'Output shape', 'Parameters'],
                column_types=['str30', 'str20', 'str25', 'str20'],
                column_separators=[1, 2],
                row_separators=row_separators,
                indent=4
            )

        else:
            # mode == 'extended_wide'
            output += ui.table(
                cell_data=[table_data['name'], table_data['layer_type'], table_data['output'], table_data['parameter_count'],
                           table_data['activation'], table_data['initialization']],
                column_headers=['Layer name', 'Layer type', 'Output shape', 'Parameters', 'Act.', 'Init.'],
                column_types=['str30', 'str20', 'str25', 'str20', 'str15', 'str15'],
                column_separators=[1, 2, 3],
                row_separators=row_separators,
                indent=4
            )

        if show_parameters:
            output += ui.line('') + '\n'
            output += ui.line('Parameters', indent=4) + '\n'
            output += ui.data(indent=6, field='Total', value=trainable_count + non_trainable_count) + '\n'
            output += ui.data(indent=6, field='Trainable', value=trainable_count) + '\n'
            output += ui.data(indent=6, field='Non-Trainable', value=non_trainable_count) + '\n'

    else:
        # Any other mode: reuse Keras' own summary, capturing its printed lines.
        output_buffer = []
        keras_model.summary(print_fn=output_buffer.append)
        for line in output_buffer:
            if html_mode:
                output += ui.line('<code>'+line+'</code>', indent=4) + '\n'
            else:
                output += ui.line(line, indent=4) + '\n'

    model_config = keras_model.get_config()

    if show_parameters:
        output += ui.line('') + '\n'
        output += ui.line('Input', indent=4) + '\n'
        output += ui.data(indent=6, field='Shape', value=keras_model.input_shape) + '\n'

        output += ui.line('Output', indent=4) + '\n'
        output += ui.data(indent=6, field='Shape', value=keras_model.output_shape) + '\n'

        # Functional/Sequential configs are a dict with a 'layers' list; some older
        # Sequential configs are a bare list of layer configs.
        if isinstance(model_config, dict) and 'layers' in model_config:
            output += ui.data(
                indent=6,
                field='Activation',
                value=model_config['layers'][-1]['config'].get('activation')
            ) + '\n'

        elif isinstance(model_config, list):
            output += ui.data(
                indent=6,
                field='Activation',
                value=model_config[-1].get('config', {}).get('activation')
            ) + '\n'

    if display:
        if html_mode:
            # Alias the import so it does not shadow the `display` parameter.
            from IPython.display import HTML, display as ipython_display
            ipython_display(HTML(output))

        else:
            print(output)

        return None

    return output
コード例 #5
0
ファイル: model.py プロジェクト: stachu86/dcase_util
def create_sequential_model(model_parameter_list, input_shape=None, output_shape=None, constants=None, return_functional=False):
    """Create sequential Keras model

    String values in layer configs are resolved against ``constants``. A single token that names a
    constant is replaced by its value. A whitespace-separated formula (e.g. ``'CONSTANT_A * 2'``) has
    each constant token substituted before being evaluated with SimpleMathStringEvaluator. The tokens
    must be separated by spaces.

    Example parameters::

        model_parameter_list = [
            {
                'class_name': 'Dense',
                'config': {
                    'units': 'CONSTANT_B',
                    'kernel_initializer': 'uniform',
                    'activation': 'relu'
                }
            },
            {
                'class_name': 'Dropout',
                'config': {
                    'rate': 0.2
                }
            },
            {
                'class_name': 'Dense',
                'config': {
                    'units': 'CONSTANT_A * 2',
                    'kernel_initializer': 'uniform',
                    'activation': 'relu'
                }
            },
            {
                'class_name': 'Dropout',
                'config': {
                    'rate': 0.2
                }
            },
            {
                'class_name': 'Dense',
                'config': {
                    'units': 'CLASS_COUNT',
                    'kernel_initializer': 'uniform',
                    'activation': 'softmax'
                }
            }
        ]
        constants = {
            'CONSTANT_A': 50,
            'CONSTANT_B': 100
        }

    Parameters
    ----------
    model_parameter_list : dict or DictContainer
        Model parameters. Not modified by this function.

    input_shape : int
        Size of the input layer. Also available as constants INPUT_SHAPE (if not already given) and
        FEATURE_VECTOR_LENGTH (if not already given).
        Default value None

    output_shape : int
        Size of the output layer. Also available as constants OUTPUT_SHAPE (if not already given) and
        CLASS_COUNT (if not already given).
        Default value None

    constants : dict or DictContainer
        Constants used in the model_parameter definitions. Not modified by this function.
        Default value None

    return_functional : bool
        Convert sequential model into function model.
        Default value False

    Returns
    -------
    Keras model

    Raises
    ------
    AttributeError
        If a layer class_name or wrapper is not found in tensorflow.keras.layers.

    """

    from tensorflow.keras.models import Sequential
    keras_model = Sequential()

    # Config fields whose list values Keras expects as tuples.
    tuple_fields = [
        'input_shape',
        'kernel_size',
        'pool_size',
        'dims',
        'target_shape',
        'strides'
    ]

    # Work on a copy so the injections below do not leak back into the caller's constants.
    constants = dict(constants) if constants is not None else {}

    if 'INPUT_SHAPE' not in constants and input_shape is not None:
        constants['INPUT_SHAPE'] = input_shape

    if 'OUTPUT_SHAPE' not in constants and output_shape is not None:
        constants['OUTPUT_SHAPE'] = output_shape

    if 'CLASS_COUNT' not in constants:
        constants['CLASS_COUNT'] = output_shape

    if 'FEATURE_VECTOR_LENGTH' not in constants:
        constants['FEATURE_VECTOR_LENGTH'] = input_shape

    def logger():
        logger_instance = logging.getLogger(__name__)
        if not logger_instance.handlers:
            setup_logging()
        return logger_instance

    def process_field(value, constants_dict):
        # Resolve constants in strings (recursively in lists) and evaluate simple math formulas.
        if isinstance(value, str):
            sub_fields = value.split()
            if len(sub_fields) > 1:
                # Inject constants to math formula
                value = ''.join(
                    str(constants_dict[subfield]) if subfield in constants_dict else subfield
                    for subfield in sub_fields
                )

            elif value in constants_dict:
                # Inject constants
                value = str(constants_dict[value])

            return SimpleMathStringEvaluator().eval(value)

        if isinstance(value, list):
            return [process_field(value=item, constants_dict=constants_dict) for item in value]

        return value

    # Inject constant into constants with equations
    for field in list(constants.keys()):
        constants[field] = process_field(
            value=constants[field],
            constants_dict=constants
        )

    # Setup layers
    for layer_id, layer_setup in enumerate(model_parameter_list):
        # Get layer parameters
        layer_setup = DictContainer(layer_setup)

        # DictContainer(...) copies only the top level. Copy the config too, so that constant
        # injection, tuple conversion and input_shape injection do not write into the caller's
        # model_parameter_list. Otherwise, a reused list keeps a stale input_shape.
        layer_setup['config'] = dict(layer_setup.get('config') or {})
        layer_config = layer_setup['config']

        # Get layer class
        try:
            layer_class = getattr(
                importlib.import_module('tensorflow.keras.layers'),
                layer_setup['class_name']
            )

        except AttributeError:
            message = 'Invalid Keras layer type [{type}].'.format(
                type=layer_setup['class_name']
            )
            logger().exception(message)
            raise AttributeError(message)

        # Inject constants
        for config_field in list(layer_config.keys()):
            layer_config[config_field] = process_field(
                value=layer_config[config_field],
                constants_dict=constants
            )

        # Convert lists into tuples
        for field in tuple_fields:
            if isinstance(layer_config.get(field), list):
                layer_config[field] = tuple(layer_config[field])

        # Inject input shape for Input layer if not given
        if layer_id == 0 and layer_setup.get_path('config.input_shape') is None and input_shape is not None:
            # Set input layer dimension for the first layer if not set
            layer_config['input_shape'] = (input_shape,)

        layer = layer_class(**layer_config) if layer_config else layer_class()

        if 'wrapper' in layer_setup:
            # Get layer wrapper class
            try:
                wrapper_class = getattr(
                    importlib.import_module("tensorflow.keras.layers"),
                    layer_setup['wrapper']
                )

            except AttributeError:
                message = 'Invalid Keras layer wrapper type [{type}].'.format(
                    type=layer_setup['wrapper']
                )
                logger().exception(message)
                raise AttributeError(message)

            wrapper_parameters = layer_setup.get('config_wrapper', {})
            layer = wrapper_class(layer, **dict(wrapper_parameters))

        keras_model.add(layer)

    if return_functional:
        # Re-apply the built layers on an explicit Input to obtain a functional Model.
        from tensorflow.keras.layers import Input
        from tensorflow.keras.models import Model
        input_layer = Input(batch_shape=keras_model.layers[0].input_shape)
        prev_layer = input_layer
        for layer in keras_model.layers:
            prev_layer = layer(prev_layer)

        keras_model = Model(
            inputs=[input_layer],
            outputs=[prev_layer]
        )

    return keras_model
コード例 #6
0
ファイル: model.py プロジェクト: mic2zar/dcase_util
def model_summary_string(keras_model, mode='keras'):
    """Model summary in a formatted string, similar to Keras model summary function.

    Parameters
    ----------
    keras_model : keras model
        Keras model

    mode : str
        Summary mode ['extended', 'keras']. In case 'keras', standard Keras summary is returned.
        'extended' lists layer type, output shape, parameter count, name, first inbound connection,
        activation and kernel initialization, followed by total / trainable / non-trainable counts.
        Default value keras

    Returns
    -------
    str
        Model summary, always ending with the model input and output shapes

    """

    ui = FancyStringifier()
    output = ''
    output += ui.line('Model summary') + '\n'

    if mode == 'extended':
        layer_name_map = {
            'BatchNormalization': 'BatchNorm',
        }
        import keras.backend as keras_backend

        def count_unique_params(weights):
            # Count parameters of each weight once, deduplicating shared weights by object
            # identity. TF2-backed variables are unhashable, so a set() of them raises TypeError.
            unique_weights = {id(weight): weight for weight in weights}.values()
            return int(numpy.sum([keras_backend.count_params(weight) for weight in unique_weights]))

        output += ui.row('Layer type',
                         'Output',
                         'Param',
                         'Name',
                         'Connected to',
                         'Activ.',
                         'Init',
                         widths=[15, 25, 10, 20, 25, 10, 10],
                         indent=4) + '\n'
        output += ui.row('-', '-', '-', '-', '-', '-', '-') + '\n'

        for layer in keras_model.layers:
            # Keras >= 2.1.3 keeps inbound nodes in the private `_inbound_nodes` attribute, older
            # versions in the public `inbound_nodes`. Detecting the attribute avoids a distutils
            # version comparison (distutils was removed in Python 3.12).
            inbound_nodes = getattr(layer, '_inbound_nodes', None)
            if inbound_nodes is None:
                inbound_nodes = layer.inbound_nodes

            connections = []
            for node in inbound_nodes:
                for inbound_layer, node_index, tensor_index in zip(node.inbound_layers,
                                                                    node.node_indices,
                                                                    node.tensor_indices):
                    connections.append(inbound_layer.name + '[' +
                                       str(node_index) + '][' +
                                       str(tensor_index) + ']')

            config = DictContainer(layer.get_config())
            layer_name = layer_name_map.get(layer.__class__.__name__, layer.__class__.__name__)

            initializer_class = config.get_path('kernel_initializer.class_name')
            if initializer_class == 'VarianceScaling':
                init = str(
                    config.get_path('kernel_initializer.config.distribution',
                                    '---'))

            elif initializer_class == 'RandomUniform':
                init = 'uniform'

            else:
                init = '---'

            # Only the first inbound connection is shown.
            output += ui.row(
                layer_name, str(layer.output_shape), str(layer.count_params()),
                str(layer.name),
                str(connections[0]) if connections else '---',
                str(config.get('activation', '---')), init) + '\n'

        trainable_count = count_unique_params(keras_model.trainable_weights)
        non_trainable_count = count_unique_params(keras_model.non_trainable_weights)

        output += ui.line('') + '\n'
        output += ui.line(
            'Parameters',
            indent=4,
        ) + '\n'
        output += ui.data(indent=6,
                          field='Total',
                          value=trainable_count + non_trainable_count) + '\n'
        output += ui.data(indent=6, field='Trainable',
                          value=trainable_count) + '\n'
        output += ui.data(
            indent=6, field='Non-Trainable', value=non_trainable_count) + '\n'

    else:
        # Any other mode: reuse Keras' own summary, capturing its printed lines.
        output_buffer = []
        keras_model.summary(print_fn=output_buffer.append)
        for line in output_buffer:
            output += ui.line(line, indent=4) + '\n'

    output += ui.line('') + '\n'

    output += ui.data(
        indent=4, field='Input shape', value=keras_model.input_shape) + '\n'
    output += ui.data(
        indent=4, field='Output shape', value=keras_model.output_shape) + '\n'

    return output