示例#1
0
class ProgressLoggerCallback(BaseCallback):
    """Keras callback to show metrics in logging interface. Implements Keras Callback API.

    This callback is very similar to standard ``ProgbarLogger`` Keras callback, however it adds support for logging
    interface, and external metrics (metrics calculated outside Keras training process).

    """

    # External metrics whose name ends with one of these suffixes are shown as percentages.
    _PERCENTAGE_METRIC_SUFFIXES = ('f_measure', 'f_score')

    def __init__(self,
                 manual_update=False, epochs=None, external_metric_labels=None,
                 metric=None, loss=None, manual_update_interval=1, output_type='logging', show_timing=True,
                 **kwargs):
        """Constructor

        Parameters
        ----------
        manual_update : bool
            Manually update callback, use this when injecting external metrics
            Default value False

        epochs : int
            Total amount of epochs. When None, the value is read from the Keras ``params`` in ``on_train_begin``.
            Default value None

        external_metric_labels : dict or OrderedDict
            Dictionary with {'metric_label': 'metric_name'}
            Default value None

        metric : str or callable
            Metric name, or a callable whose ``__name__`` is used as the metric name
            Default value None

        loss : str
            Loss label shown in the table header
            Default value None

        manual_update_interval : int
            Epoch interval for manual update, used to anticipate updates
            Default value 1

        output_type : str
            Output type, either 'logging', 'console', or 'notebook'
            Default value 'logging'

        show_timing : bool
            Show per epoch time and estimated time remaining
            Default value True

        """

        kwargs.update({
            'manual_update': manual_update,
            'epochs': epochs,
            'external_metric_labels': external_metric_labels,
        })

        super(ProgressLoggerCallback, self).__init__(**kwargs)

        if isinstance(metric, str):
            self.metric = metric

        elif callable(metric):
            self.metric = metric.__name__

        else:
            # Make sure the attribute always exists; the header and the value collection read it
            # unconditionally. A value already set on the instance is kept as is.
            self.metric = getattr(self, 'metric', None)

        self.loss = loss

        self.manual_update_interval = manual_update_interval

        self.output_type = output_type

        self.show_timing = show_timing

        self.timer = Timer()

        self.ui = FancyStringifier()

        if self.output_type == 'logging':
            self.output_target = FancyLogger()

        elif self.output_type == 'console':
            self.output_target = FancyPrinter()

        elif self.output_type == 'notebook':
            self.output_target = FancyHTMLPrinter()
            self.ui = FancyHTMLStringifier()

        self.seen = 0
        self.log_values = []

        self.most_recent_values = collections.OrderedDict()
        self.most_recent_values['l_tra'] = None
        self.most_recent_values['l_val'] = None
        self.most_recent_values['m_tra'] = None
        self.most_recent_values['m_val'] = None

        # Per-epoch value series, NaN marks "no value stored for this epoch".
        self.data = {}
        for series_label in ('l_tra', 'l_val', 'm_tra', 'm_val'):
            self.data[series_label] = self._empty_series()

        for metric_label in self.external_metric_labels:
            self.data[metric_label] = self._empty_series()

        self.header_shown = False
        self.last_update_epoch = 0

        self.target = None
        self.first_epoch = None
        self.total_time = 0

    def _empty_series(self):
        """Return a NaN-filled array with one slot per epoch.

        The array has zero length while the epoch count is still unknown; it is grown in ``on_train_begin``.

        """

        return numpy.full((self.epochs or 0,), numpy.nan)

    def _fit_series_to_epochs(self):
        """Pad every stored series with NaN so that it has at least ``self.epochs`` slots."""

        if not self.epochs:
            return

        for series_label, series in list(self.data.items()):
            missing = self.epochs - series.shape[0]
            if missing > 0:
                self.data[series_label] = numpy.concatenate((series, numpy.full((missing,), numpy.nan)))

    def _metric_names(self, logs):
        """Return the names to look up in ``logs``.

        ``params['metrics']`` is used when it is available, otherwise the keys of ``logs`` are used.

        """

        params = self.params or {}
        if 'metrics' in params:
            return params['metrics']

        return list(logs)

    def on_train_begin(self, logs=None):
        if self.epochs is None:
            self.epochs = self.params['epochs']

        # Series created before the epoch count was known are too short, extend them now.
        self._fit_series_to_epochs()

        if not self.header_shown:
            output = ''
            output += self.ui.line('Training') + '\n'

            if self.external_metric_labels:
                output += self.ui.row(
                    '', 'Loss', 'Metric', 'Ext. metrics', '',
                    widths=[10, 26, 26, len(self.external_metric_labels)*15, 17]
                ) + '\n'

                header2 = ['', self.loss, self.metric]
                header3 = ['Epoch', 'Train', 'Val', 'Train', 'Val']
                widths = [10, 13, 13, 13, 13]
                for metric_name in self.external_metric_labels.values():
                    header2.append('')
                    header3.append(metric_name)
                    widths.append(15)

                if self.show_timing:
                    header2.append('')
                    header3.append('Step time')
                    widths.append(20)

                    header2.append('')
                    header3.append('Remaining time')
                    widths.append(20)

                output += self.ui.row(*header2) + '\n'
                output += self.ui.row(*header3, widths=widths) + '\n'
                output += self.ui.row_sep()

            else:
                if self.show_timing:
                    output += self.ui.row('', 'Loss', 'Metric', '', '', widths=[10, 36, 36, 16, 15]) + '\n'
                    output += self.ui.row('', self.loss, self.metric, '') + '\n'
                    output += self.ui.row(
                        'Epoch', 'Train', 'Val', 'Train', 'Val', 'Step', 'Remaining',
                        widths=[10, 18, 18, 18, 18, 16, 15]
                    ) + '\n'
                else:
                    output += self.ui.row('', 'Loss', 'Metric', widths=[10, 36, 36]) + '\n'
                    output += self.ui.row('', self.loss, self.metric) + '\n'
                    output += self.ui.row(
                        'Epoch', 'Train', 'Val', 'Train', 'Val',
                        widths=[10, 18, 18, 18, 18]
                    ) + '\n'
                output += self.ui.row_sep()

            self.output_target.line(output)

            # Show header only once
            self.header_shown = True

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch + 1

        if 'steps' in self.params:
            self.target = self.params['steps']

        elif 'samples' in self.params:
            self.target = self.params['samples']

        self.seen = 0
        self.timer.start()

    def on_batch_begin(self, batch, logs=None):
        if self.target and self.seen < self.target:
            self.log_values = []

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.seen += logs.get('size', 0)

        for k in self._metric_names(logs):
            if k in logs:
                self.log_values.append((k, logs[k]))

    def on_epoch_end(self, epoch, logs=None):
        self.timer.stop()
        self.total_time += self.timer.elapsed()
        self.epoch = epoch

        if self.first_epoch is None:
            # Store first epoch number
            self.first_epoch = self.epoch

        logs = logs or {}

        # Reset values
        self.most_recent_values['l_tra'] = None
        self.most_recent_values['l_val'] = None
        self.most_recent_values['m_tra'] = None
        self.most_recent_values['m_val'] = None

        # Collect values
        for k in self._metric_names(logs):
            if k not in logs:
                continue

            value = logs[k]
            self.log_values.append((k, value))

            if k == 'loss':
                series_label = 'l_tra'

            elif k == 'val_loss':
                series_label = 'l_val'

            elif self.metric and k.endswith(self.metric):
                series_label = 'm_val' if k.startswith('val_') else 'm_tra'

            else:
                continue

            self.data[series_label][self.epoch] = value
            self.most_recent_values[series_label] = '{:4.3f}'.format(value)

        for metric_label in self.external_metric_labels:
            if metric_label in self.external_metric:
                metric_name = self.external_metric_labels[metric_label]
                value = self.external_metric[metric_label]
                if metric_name.endswith(self._PERCENTAGE_METRIC_SUFFIXES):
                    self.most_recent_values[metric_label] = '{:3.1f}'.format(value * 100)
                else:
                    self.most_recent_values[metric_label] = '{:4.3f}'.format(value)

        # Log automatically when manual updating is off. In manual mode, log only epochs that come after
        # the last manual update and for which (epoch + 1) is not a multiple of the manual update interval.
        if (not self.manual_update or
           (self.epoch - self.last_update_epoch > 0 and (self.epoch+1) % self.manual_update_interval)):

            # Update logged progress
            self.update_progress_log()

    def update(self):
        """Log the current epoch and remember it as the last manually updated epoch

        """
        self.update_progress_log()
        self.last_update_epoch = self.epoch

    def update_progress_log(self):
        """Update progress to logging interface

        Nothing is logged when the current epoch equals the last manually updated epoch.

        """

        if self.epoch - self.last_update_epoch:
            data = [
                self.epoch,
                self.data['l_tra'][self.epoch],
                # Validation cells show '-' when no validation value was collected for this epoch.
                self.data['l_val'][self.epoch] if self.most_recent_values['l_val'] else '-',
                self.data['m_tra'][self.epoch],
                self.data['m_val'][self.epoch] if self.most_recent_values['m_val'] else '-'
            ]
            types = [
                'int', 'float4', 'float4', 'float4', 'float4'
            ]

            # Every appended cell gets a matching type entry so that the column types stay aligned
            # with the data, also for the timing columns added afterwards.
            for metric_label in self.external_metric_labels:
                if metric_label not in self.external_metric:
                    data.append('')
                    types.append('str')
                    continue

                value = self.data[metric_label][self.epoch]

                if numpy.isnan(value):
                    data.append(' ' * 10)
                    types.append('str')

                elif self.external_metric_labels[metric_label].endswith(self._PERCENTAGE_METRIC_SUFFIXES):
                    data.append(float(value) * 100)
                    types.append('float2')

                else:
                    data.append(float(value))
                    types.append('float4')

            if self.show_timing:
                # Add step time, last three characters of the timedelta string are dropped
                step_time = datetime.timedelta(seconds=self.timer.elapsed())
                data.append(str(step_time)[:-3])
                types.append('str')

                # Add remaining time, estimated from the average epoch time so far
                average_time_per_epoch = self.total_time / float(self.epoch-(self.first_epoch-1))
                remaining_time_seconds = datetime.timedelta(
                    seconds=(self.epochs - 1 - self.epoch) * average_time_per_epoch
                )
                data.append('-'+str(remaining_time_seconds).split('.', 2)[0])
                types.append('str')

            output = self.ui.row(*data, types=types)

            self.output_target.line(output)

    def add_external_metric(self, metric_id):
        """Add external metric to be monitored

        Parameters
        ----------
        metric_id : str
            Metric name, used both as the metric label and as the metric name

        """

        if metric_id not in self.external_metric_labels:
            self.external_metric_labels[metric_id] = metric_id

        if metric_id not in self.data:
            self.data[metric_id] = self._empty_series()

    def set_external_metric_value(self, metric_label, metric_value):
        """Add external metric value for the current epoch

        Parameters
        ----------
        metric_label : str
            Metric label

        metric_value : numeric
            Metric value

        """

        self.external_metric[metric_label] = metric_value
        self.data[metric_label][self.epoch] = metric_value
示例#2
0
def _as_list(value):
    """Return lists and tuples as a list, wrap any other object into a one-item list."""

    if isinstance(value, (list, tuple)):
        return list(value)

    return [value]


def _layer_connections(layer):
    """Describe the inbound connections of a layer as 'name[node_index][tensor_index]' strings."""

    # Attribute detection instead of a Keras version comparison: use the private attribute when the
    # layer has it, otherwise the public one.
    inbound_nodes = getattr(layer, '_inbound_nodes', None)
    if inbound_nodes is None:
        inbound_nodes = layer.inbound_nodes

    connections = []
    for node in inbound_nodes:
        # NOTE(review): some tf.keras releases appear to return a single object here instead of a
        # list when a node has one input; the values are normalised to lists to cover both. Confirm.
        inbound = zip(
            _as_list(node.inbound_layers),
            _as_list(node.node_indices),
            _as_list(node.tensor_indices)
        )
        for inbound_layer, node_index, tensor_index in inbound:
            connections.append(
                inbound_layer.name + '[' + str(node_index) + '][' + str(tensor_index) + ']'
            )

    return connections


def model_summary_string(keras_model, mode='keras', show_parameters=True, display=False):
    """Model summary in a formatted string, similar to Keras model summary function.

    Parameters
    ----------
    keras_model : keras model
        Keras model

    mode : str
        Summary mode ['extended', 'extended_wide', 'keras']. In case 'extended', a layer table with name, type,
        output shape and parameter count is produced; 'extended_wide' adds activation and initialization columns.
        Any other value gives the standard Keras summary.
        Default value keras

    show_parameters : bool
        Show model parameter count and input / output shapes
        Default value True

    display : bool
        Display summary immediately, otherwise return string
        Default value False

    Returns
    -------
    str or None
        Model summary, None when the summary was displayed

    """

    html_mode = is_jupyter()
    if html_mode:
        ui = FancyHTMLStringifier()
    else:
        ui = FancyStringifier()

    output = ''
    output += ui.line('Model summary') + '\n'

    if mode == 'extended' or mode == 'extended_wide':
        layer_name_map = {
            'BatchNormalization': 'BatchNorm',
        }

        layer_type_html_tags = {
            'InputLayer': '<span class="label label-default">{0:s}</span>',
            'Dense': '<span class="label label-primary">{0:s}</span>',
            'TimeDistributed': '<span class="label label-primary">{0:s}</span>',

            'BatchNorm': '<span class="label label-default">{0:s}</span>',
            'Activation': '<span class="label label-default">{0:s}</span>',
            'Dropout': '<span class="label label-default">{0:s}</span>',

            'Flatten': '<span class="label label-success">{0:s}</span>',
            'Reshape': '<span class="label label-success">{0:s}</span>',
            'Permute': '<span class="label label-success">{0:s}</span>',

            'Conv1D': '<span class="label label-warning">{0:s}</span>',
            'Conv2D': '<span class="label label-warning">{0:s}</span>',

            'MaxPooling1D': '<span class="label label-success">{0:s}</span>',
            'MaxPooling2D': '<span class="label label-success">{0:s}</span>',
            'MaxPooling3D': '<span class="label label-success">{0:s}</span>',
            'AveragePooling1D': '<span class="label label-success">{0:s}</span>',
            'AveragePooling2D': '<span class="label label-success">{0:s}</span>',
            'AveragePooling3D': '<span class="label label-success">{0:s}</span>',
            'GlobalMaxPooling1D': '<span class="label label-success">{0:s}</span>',
            'GlobalMaxPooling2D': '<span class="label label-success">{0:s}</span>',
            'GlobalMaxPooling3D': '<span class="label label-success">{0:s}</span>',
            'GlobalAveragePooling1D': '<span class="label label-success">{0:s}</span>',
            'GlobalAveragePooling2D': '<span class="label label-success">{0:s}</span>',
            'GlobalAveragePooling3D': '<span class="label label-success">{0:s}</span>',

            'RNN': '<span class="label label-danger">{0:s}</span>',
            'SimpleRNN': '<span class="label label-danger">{0:s}</span>',
            'GRU': '<span class="label label-danger">{0:s}</span>',
            'CuDNNGRU': '<span class="label label-danger">{0:s}</span>',
            'LSTM': '<span class="label label-danger">{0:s}</span>',
            'CuDNNLSTM': '<span class="label label-danger">{0:s}</span>',
            'Bidirectional': '<span class="label label-danger">{0:s}</span>'
        }

        # Imported here so that the plain 'keras' mode does not need this module.
        import tensorflow.keras.backend as keras_backend

        table_data = {
            'layer_type': [],
            'output': [],
            'parameter_count': [],
            'name': [],
            'connected_to': [],
            'activation': [],
            'initialization': []
        }

        row_separators = []
        prev_name = None
        for layer_id, layer in enumerate(keras_model.layers):
            connections = _layer_connections(layer)

            config = DictContainer(layer.get_config())
            layer_name = layer.__class__.__name__
            if layer_name in layer_name_map:
                layer_name = layer_name_map[layer_name]

            if html_mode and layer_name in layer_type_html_tags:
                layer_name = layer_type_html_tags[layer_name].format(layer_name)

            if config.get_path('kernel_initializer.class_name') == 'VarianceScaling':
                init = str(config.get_path('kernel_initializer.config.distribution', '---'))

            elif config.get_path('kernel_initializer.class_name') == 'RandomUniform':
                init = 'uniform'

            else:
                init = '-'

            # A row separator is placed whenever the part of the layer name before the first '_' changes.
            name_parts = layer.name.split('_')
            if prev_name != name_parts[0]:
                row_separators.append(layer_id)
                prev_name = name_parts[0]

            table_data['layer_type'].append(layer_name)
            table_data['output'].append(str(layer.output_shape))
            table_data['parameter_count'].append(str(layer.count_params()))
            table_data['name'].append(layer.name)
            table_data['connected_to'].append(str(connections[0]) if len(connections) > 0 else '-')
            table_data['activation'].append(str(config.get('activation', '-')))
            table_data['initialization'].append(init)

        # Weights are de-duplicated by identity, so the weight objects are not required to be hashable.
        trainable_weights = {id(w): w for w in keras_model.trainable_weights}.values()
        non_trainable_weights = {id(w): w for w in keras_model.non_trainable_weights}.values()

        trainable_count = int(sum(keras_backend.count_params(w) for w in trainable_weights))
        non_trainable_count = int(sum(keras_backend.count_params(w) for w in non_trainable_weights))

        # Show row separators only if they are useful
        if len(row_separators) == len(keras_model.layers):
            row_separators = None

        cell_data = [
            table_data['name'], table_data['layer_type'], table_data['output'], table_data['parameter_count']
        ]
        column_headers = ['Layer name', 'Layer type', 'Output shape', 'Parameters']
        column_types = ['str30', 'str20', 'str25', 'str20']
        column_separators = [1, 2]

        if mode == 'extended_wide':
            cell_data += [table_data['activation'], table_data['initialization']]
            column_headers += ['Act.', 'Init.']
            column_types += ['str15', 'str15']
            column_separators = [1, 2, 3]

        output += ui.table(
            cell_data=cell_data,
            column_headers=column_headers,
            column_types=column_types,
            column_separators=column_separators,
            row_separators=row_separators,
            indent=4
        )

        if show_parameters:
            output += ui.line('') + '\n'
            output += ui.line('Parameters', indent=4) + '\n'
            output += ui.data(indent=6, field='Total', value=trainable_count + non_trainable_count) + '\n'
            output += ui.data(indent=6, field='Trainable', value=trainable_count) + '\n'
            output += ui.data(indent=6, field='Non-Trainable', value=non_trainable_count) + '\n'

    else:
        # Standard Keras summary, captured line by line through print_fn.
        output_buffer = []
        keras_model.summary(print_fn=output_buffer.append)
        for line in output_buffer:
            if html_mode:
                output += ui.line('<code>'+line+'</code>', indent=4) + '\n'
            else:
                output += ui.line(line, indent=4) + '\n'

    if show_parameters:
        # The model configuration is only needed for the output activation shown below.
        model_config = keras_model.get_config()

        output += ui.line('') + '\n'
        output += ui.line('Input', indent=4) + '\n'
        output += ui.data(indent=6, field='Shape', value=keras_model.input_shape) + '\n'

        output += ui.line('Output', indent=4) + '\n'
        output += ui.data(indent=6, field='Shape', value=keras_model.output_shape) + '\n'

        if isinstance(model_config, dict) and 'layers' in model_config:
            output += ui.data(
                indent=6,
                field='Activation',
                value=model_config['layers'][-1]['config'].get('activation')
            ) + '\n'

        elif isinstance(model_config, list):
            output += ui.data(
                indent=6,
                field='Activation',
                value=model_config[-1].get('config', {}).get('activation')
            ) + '\n'

    if display:
        if html_mode:
            # Imported under an alias so that the ``display`` parameter is not rebound.
            from IPython.display import HTML, display as ipython_display
            ipython_display(HTML(output))

        else:
            print(output)

    else:
        return output