Exemplo n.º 1
0
    def __init__(self,
                 manual_update=False, epochs=None, external_metric_labels=None,
                 metric=None, loss=None, manual_update_interval=1, output_type='logging', show_timing=True,
                 **kwargs):
        """Constructor

        Parameters
        ----------
        epochs : int
            Total amount of epochs; used as the length of the per-epoch value arrays.
            Default value None

        metric : str or callable
            Metric name. When a callable is given, its ``__name__`` is used.
            Default value None

        loss : str
            Loss name, stored as given.
            Default value None

        manual_update : bool
            Manually update callback, use this when injecting external metrics
            Default value False

        manual_update_interval : int
            Epoch interval for manual update, used to anticipate updates
            Default value 1

        output_type : str
            Output type, either 'logging', 'console', or 'notebook'
            Default value 'logging'

        show_timing : bool
            Show per epoch time and estimated time remaining
            Default value True

        external_metric_labels : dict or OrderedDict
            Dictionary with {'metric_label': 'metric_name'}
            Default value None

        """

        # These three options are forwarded to the base class through kwargs.
        kwargs.update({
            'manual_update': manual_update,
            'epochs': epochs,
            'external_metric_labels': external_metric_labels,
        })

        super(ProgressLoggerCallback, self).__init__(**kwargs)

        if isinstance(metric, str):
            self.metric = metric

        elif callable(metric):
            self.metric = metric.__name__

        else:
            # Make sure the attribute always exists, also for the default
            # metric=None, so later reads cannot raise AttributeError.
            self.metric = None

        # NOTE(review): the base class is expected to expose
        # `external_metric_labels`; guard against it being stored as None
        # (the documented default) because it is iterated below.
        if self.external_metric_labels is None:
            self.external_metric_labels = collections.OrderedDict()

        self.loss = loss
        self.manual_update_interval = manual_update_interval
        self.output_type = output_type
        self.show_timing = show_timing

        self.timer = Timer()

        # Default text formatter; replaced below for notebook output.
        self.ui = FancyStringifier()

        if self.output_type == 'logging':
            self.output_target = FancyLogger()

        elif self.output_type == 'console':
            self.output_target = FancyPrinter()

        elif self.output_type == 'notebook':
            self.output_target = FancyHTMLPrinter()
            self.ui = FancyHTMLStringifier()

        self.seen = 0
        self.log_values = []

        # Keys: l_ = loss, m_ = metric, _tra = training, _val = validation.
        series_labels = ('l_tra', 'l_val', 'm_tra', 'm_val')

        # Most recent formatted value per series, None until a value is seen.
        self.most_recent_values = collections.OrderedDict()
        for label in series_labels:
            self.most_recent_values[label] = None

        # One NaN-initialised slot per epoch for every built-in series and
        # every external metric; NaN marks "no value recorded".
        self.data = {}
        for label in series_labels:
            self.data[label] = numpy.full((self.epochs,), numpy.nan)

        for metric_label in self.external_metric_labels:
            self.data[metric_label] = numpy.full((self.epochs,), numpy.nan)

        self.header_shown = False
        self.last_update_epoch = 0

        self.target = None
        self.first_epoch = None
        self.total_time = 0
Exemplo n.º 2
0
class ProgressLoggerCallback(BaseCallback):
    """Keras callback to show metrics in logging interface. Implements Keras Callback API.

    This callback is very similar to standard ``ProgbarLogger`` Keras callback, however it adds support for logging
    interface, and external metrics (metrics calculated outside Keras training process).

    """

    # Keys of the built-in per-epoch series:
    # l_ = loss, m_ = metric, _tra = training, _val = validation.
    _BASE_SERIES = ('l_tra', 'l_val', 'm_tra', 'm_val')

    # External metrics whose name ends with one of these are shown multiplied by 100.
    _PERCENTAGE_SUFFIXES = ('f_measure', 'f_score')

    def __init__(self,
                 manual_update=False,
                 epochs=None,
                 external_metric_labels=None,
                 metric=None,
                 loss=None,
                 manual_update_interval=1,
                 output_type='logging',
                 **kwargs):
        """Constructor

        Parameters
        ----------
        epochs : int
            Total amount of epochs. When None, the value is taken from the
            Keras training parameters when training begins.
            Default value None

        metric : str or callable
            Metric name. When a callable is given, its ``__name__`` is used.
            Default value None

        loss : str
            Loss name shown in the progress header.
            Default value None

        manual_update : bool
            Manually update callback, use this when injecting external metrics
            Default value False

        manual_update_interval : int
            Epoch interval for manual update, used to anticipate updates
            Default value 1

        output_type : str
            Output type, either 'logging' or 'console'
            Default value 'logging'

        external_metric_labels : dict or OrderedDict
            Dictionary with {'metric_label': 'metric_name'}
            Default value None

        """

        # These three options are forwarded to the base class through kwargs.
        kwargs.update({
            'manual_update': manual_update,
            'epochs': epochs,
            'external_metric_labels': external_metric_labels,
        })

        super(ProgressLoggerCallback, self).__init__(**kwargs)

        if isinstance(metric, str):
            self.metric = metric

        elif callable(metric):
            self.metric = metric.__name__

        else:
            # The header and the per-epoch collection read self.metric
            # unconditionally, so the attribute must always exist.
            self.metric = None

        # NOTE(review): the base class is expected to expose
        # `external_metric_labels`; guard against it being stored as None
        # (the documented default) because it is iterated and mutated below.
        if self.external_metric_labels is None:
            self.external_metric_labels = collections.OrderedDict()

        self.loss = loss
        self.manual_update_interval = manual_update_interval
        self.output_type = output_type

        self.timer = Timer()
        self.ui = FancyStringifier()

        if self.output_type == 'logging':
            self.output_target = FancyLogger()

        elif self.output_type == 'console':
            self.output_target = FancyPrinter()

        self.seen = 0
        self.log_values = []

        # Most recent formatted value per series, None until a value is seen.
        self.most_recent_values = collections.OrderedDict()
        for label in self._BASE_SERIES:
            self.most_recent_values[label] = None

        # Per-epoch value arrays. They can only be sized once the epoch count
        # is known; with epochs=None allocation is deferred to on_train_begin.
        self.data = {}
        if self.epochs is not None:
            self._allocate_data()

        self.header_shown = False
        self.last_update_epoch = 0

        self.target = None

    def _allocate_data(self):
        """Create a NaN-filled per-epoch array for every series that lacks one.

        Covers the built-in loss/metric series and all currently registered
        external metric labels. Existing arrays are left untouched, so the
        method is safe to call repeatedly. Requires ``self.epochs`` to be set.

        """

        labels = list(self._BASE_SERIES) + list(self.external_metric_labels)
        for label in labels:
            if label not in self.data:
                # NaN marks "no value recorded for this epoch".
                self.data[label] = numpy.full((self.epochs,), numpy.nan)

    def on_train_begin(self, logs=None):
        """Resolve the epoch count, allocate value arrays and print the table header once.

        Parameters
        ----------
        logs : dict
            Unused.
            Default value None

        """

        if self.epochs is None:
            self.epochs = self.params['epochs']

        # Arrays cannot be sized in the constructor when epochs was left as None.
        self._allocate_data()

        if self.header_shown:
            return

        output = ''
        output += self.ui.line('Training') + '\n'

        if self.external_metric_labels:
            output += self.ui.row(
                '',
                'Loss',
                'Metric',
                'Ext. metrics',
                '',
                widths=[
                    10, 26, 26,
                    len(self.external_metric_labels) * 15, 17
                ]) + '\n'

            header2 = ['', self.loss, self.metric]
            header3 = ['Epoch', 'Train', 'Val', 'Train', 'Val']
            widths = [10, 13, 13, 13, 13]
            sep = ['-', '-', '-', '-', '-']

            # One extra column per external metric.
            for metric_name in self.external_metric_labels.values():
                header2.append('')
                header3.append(metric_name)
                widths.append(15)
                sep.append('-')

            # Trailing time column.
            header2.append('')
            header3.append('time')
            widths.append(17)
            sep.append('-')

            output += self.ui.row(*header2) + '\n'
            output += self.ui.row(*header3, widths=widths) + '\n'
            output += self.ui.row(*sep)

        else:
            output += self.ui.row(
                '', 'Loss', 'Metric', '', widths=[10, 26, 26, 17]) + '\n'
            output += self.ui.row('', self.loss, self.metric, '') + '\n'
            output += self.ui.row('Epoch',
                                  'Train',
                                  'Val',
                                  'Train',
                                  'Val',
                                  'Time',
                                  widths=[10, 13, 13, 13, 13, 17]) + '\n'
            output += self.ui.row('-', '-', '-', '-', '-', '-')

        self.output_target.line(output)

        # Show header only once
        self.header_shown = True

    def on_epoch_begin(self, epoch, logs=None):
        """Reset per-epoch counters and start the epoch timer.

        Parameters
        ----------
        epoch : int
            Epoch index as given by Keras.

        logs : dict
            Unused.
            Default value None

        """

        # Stored as epoch + 1 while the epoch runs; on_epoch_end overwrites
        # it with the plain index, which is what the data arrays are indexed by.
        self.epoch = epoch + 1

        if 'steps' in self.params:
            self.target = self.params['steps']

        elif 'samples' in self.params:
            self.target = self.params['samples']

        self.seen = 0
        self.timer.start()

    def on_batch_begin(self, batch, logs=None):
        """Clear collected batch values while the epoch target has not been reached.

        Parameters
        ----------
        batch : int
            Unused.

        logs : dict
            Unused.
            Default value None

        """

        if self.target and self.seen < self.target:
            self.log_values = []

    def on_batch_end(self, batch, logs=None):
        """Accumulate the seen-sample count and collect reported batch metrics.

        Parameters
        ----------
        batch : int
            Unused.

        logs : dict
            Batch logs; ``size`` and the entries named in ``params['metrics']`` are read.
            Default value None

        """

        logs = logs or {}
        batch_size = logs.get('size', 0)
        self.seen += batch_size

        for k in self.params['metrics']:
            if k in logs:
                self.log_values.append((k, logs[k]))

    def on_epoch_end(self, epoch, logs=None):
        """Store the epoch's loss/metric values and log a progress row when due.

        Parameters
        ----------
        epoch : int
            Epoch index as given by Keras; used to index the value arrays.

        logs : dict
            Epoch logs; the entries named in ``params['metrics']`` are read.
            Default value None

        """

        self.timer.stop()
        self.epoch = epoch

        logs = logs or {}

        # Reset values
        for label in self._BASE_SERIES:
            self.most_recent_values[label] = None

        # Collect values
        for k in self.params['metrics']:
            if k not in logs:
                continue

            self.log_values.append((k, logs[k]))

            # Map the Keras log key onto one of the built-in series.
            if k == 'loss':
                series = 'l_tra'

            elif k == 'val_loss':
                series = 'l_val'

            elif self.metric and k.endswith(self.metric):
                series = 'm_val' if k.startswith('val_') else 'm_tra'

            else:
                continue

            self.data[series][self.epoch] = logs[k]
            self.most_recent_values[series] = '{:4.3f}'.format(logs[k])

        # Format the latest values of the external metrics that have been set.
        for metric_label in self.external_metric_labels:
            if metric_label in self.external_metric:
                metric_name = self.external_metric_labels[metric_label]
                value = self.external_metric[metric_label]
                if metric_name.endswith(self._PERCENTAGE_SUFFIXES):
                    self.most_recent_values[metric_label] = '{:3.1f}'.format(
                        value * 100)
                else:
                    self.most_recent_values[metric_label] = '{:4.3f}'.format(
                        value)

        # In manual mode a row is logged here only for epochs after the last
        # manual update whose 1-based number is not a multiple of
        # manual_update_interval; the remaining rows come from update().
        if (not self.manual_update
                or (self.epoch - self.last_update_epoch > 0 and
                    (self.epoch + 1) % self.manual_update_interval)):

            # Update logged progress
            self.update_progress_log()

    def update(self):
        """Log the current progress row and mark the current epoch as updated.

        """
        self.update_progress_log()
        self.last_update_epoch = self.epoch

    def update_progress_log(self):
        """Update progress to logging interface

        Builds one table row for the current epoch (losses, metrics, external
        metrics and the timer string) and sends it to the output target.
        Nothing is emitted when the current epoch equals the last updated epoch.

        """

        # NOTE(review): last_update_epoch starts at 0, so the row for epoch
        # index 0 is never printed - confirm this is intended.
        if not (self.epoch - self.last_update_epoch):
            return

        data = [
            self.epoch,
            self.data['l_tra'][self.epoch],
            # Validation columns show '-' when no value was recorded this epoch.
            self.data['l_val'][self.epoch]
            if self.most_recent_values['l_val'] else '-',
            self.data['m_tra'][self.epoch],
            self.data['m_val'][self.epoch]
            if self.most_recent_values['m_val'] else '-',
        ]
        types = ['int', 'float4', 'float4', 'float4', 'float4']

        # Every appended cell gets a matching type so the two lists stay aligned.
        for metric_label in self.external_metric_labels:
            if metric_label not in self.external_metric:
                data.append('')
                types.append('str')
                continue

            value = self.data[metric_label][self.epoch]

            if numpy.isnan(value):
                # No value recorded for this epoch: blank placeholder cell.
                data.append(' ' * 10)
                types.append('str')

            elif self.external_metric_labels[metric_label].endswith(
                    self._PERCENTAGE_SUFFIXES):
                data.append(float(value) * 100)
                types.append('float2')

            else:
                data.append(float(value))
                types.append('float4')

        data.append(self.timer.get_string())
        types.append('str')

        output = self.ui.row(*data, types=types)

        self.output_target.line(output)

    def add_external_metric(self, metric_id):
        """Add external metric to be monitored

        Parameters
        ----------
        metric_id : str
            Metric name, used both as the label and as the displayed name.

        """

        if metric_id not in self.external_metric_labels:
            self.external_metric_labels[metric_id] = metric_id

        # When the epoch count is not known yet, the array is created later
        # in on_train_begin together with the other series.
        if metric_id not in self.data and self.epochs is not None:
            self.data[metric_id] = numpy.full((self.epochs, ), numpy.nan)

    def set_external_metric_value(self, metric_label, metric_value):
        """Add external metric value

        Stores the value as the latest one for the label and writes it into
        the label's per-epoch array at the current epoch.

        Parameters
        ----------
        metric_label : str
            Metric label

        metric_value : numeric
            Metric value

        Raises
        ------
        KeyError
            If no value array exists for ``metric_label``.

        """

        self.external_metric[metric_label] = metric_value
        self.data[metric_label][self.epoch] = metric_value