Example #1
    def _update_metrics_names(self):
        """This is a small hack to fix the metric names."""

        # metrics_names[0] holds the combined 'loss', so the per-output
        # loss names start at index 1 (hence the enumerate offset).
        i = 1
        for i, name in enumerate(self.output_names, i):
            self.metrics_names[i] = name + '_loss'

        # keras_training is assumed to be the keras.engine.training module.
        nested_metrics = keras_training.collect_metrics(
            self.metrics, self.output_names)

        # Prefix each remaining metric entry with its output's name.
        for name, output_metrics in zip(self.output_names, nested_metrics):
            for metric in output_metrics:
                i += 1

                if metric == 'accuracy' or metric == 'acc':
                    self.metrics_names[i] = name + '_acc'
                else:
                    metric_fn = keras.metrics.get(metric)
                    self.metrics_names[i] = name + '_' + metric_fn.__name__
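A minimal, self-contained sketch of the same index arithmetic, using hypothetical output names 'main'/'aux' and a starting list shaped like the default names Keras assigns before the fix:

# Standalone illustration of the renaming hack above (hypothetical names).
output_names = ['main', 'aux']
metrics_names = ['loss', 'loss_1', 'loss_2', 'acc_1', 'acc_2']

i = 1
for i, name in enumerate(output_names, i):
    metrics_names[i] = name + '_loss'      # fills indices 1..2

# What collect_metrics would return for metrics=['acc'].
nested_metrics = [['acc'], ['acc']]
for name, output_metrics in zip(output_names, nested_metrics):
    for metric in output_metrics:
        i += 1
        metrics_names[i] = name + '_acc'   # fills indices 3..4

print(metrics_names)
# ['loss', 'main_loss', 'aux_loss', 'main_acc', 'aux_acc']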
Example #2
# Assumed imports for a Keras 1.x-era code base; the module paths are a
# best guess from the identifiers used below.
import six

from keras import backend as K
from keras import metrics as metrics_module
from keras import objectives
from keras import optimizers
from keras.engine.training import collect_metrics, weighted_objective


def compile_tfrecord(train_model,
                     optimizer,
                     loss,
                     out_tensor_lst,
                     metrics=None,
                     loss_weights=None):
    """Compile train_model against symbolic target tensors (e.g. read from
    TFRecords), mirroring Keras 1.x Model.compile."""
    if metrics is None:  # avoid a mutable default argument
        metrics = []

    # Keras 1.x Layer.build() requires an input_shape argument; the model
    # itself is passed as a dummy value (the base implementation ignores it).
    train_model.build(train_model)

    train_model.optimizer = optimizers.get(optimizer)
    train_model.loss = loss
    train_model.loss_weights = loss_weights

    # prepare loss weights
    if loss_weights is None:
        loss_weights_list = [1. for _ in range(len(train_model.outputs))]
    elif isinstance(loss_weights, dict):
        for name in loss_weights:
            if name not in train_model.output_names:
                raise ValueError('Unknown entry in loss_weights '
                                 'dictionary: "' + name + '". '
                                 'Only expected the following keys: ' +
                                 str(train_model.output_names))
        loss_weights_list = []
        for name in train_model.output_names:
            loss_weights_list.append(loss_weights.get(name, 1.))
    elif isinstance(loss_weights, list):
        if len(loss_weights) != len(train_model.outputs):
            raise ValueError('When passing a list as loss_weights, '
                             'it should have one entry per model output. '
                             'The model has ' + str(len(train_model.outputs)) +
                             ' outputs, but you passed loss_weights=' +
                             str(loss_weights))
        loss_weights_list = loss_weights
    else:
        raise TypeError('Could not interpret loss_weights argument: ' +
                        str(loss_weights) + ' - expected a list or dict.')

    # prepare loss functions
    if isinstance(loss, dict):
        for name in loss:
            if name not in train_model.output_names:
                raise ValueError('Unknown entry in loss '
                                 'dictionary: "' + name + '". '
                                 'Only expected the following keys: ' +
                                 str(train_model.output_names))
        loss_functions = []
        for name in train_model.output_names:
            if name not in loss:
                raise ValueError('Output "' + name +
                                 '" missing from loss dictionary.')
            loss_functions.append(objectives.get(loss[name]))
    elif isinstance(loss, list):
        if len(loss) != len(train_model.outputs):
            raise ValueError('When passing a list as loss, '
                             'it should have one entry per model output. '
                             'The model has ' + str(len(train_model.outputs)) +
                             ' outputs, but you passed loss=' + str(loss))
        loss_functions = [objectives.get(l) for l in loss]
    else:
        loss_function = objectives.get(loss)
        loss_functions = [
            loss_function for _ in range(len(train_model.outputs))
        ]
    train_model.loss_functions = loss_functions
    # Wrapped (weighted/masked) objectives are built for parity with
    # Model.compile; the loop below currently applies the raw losses.
    weighted_losses = [weighted_objective(fn) for fn in loss_functions]

    # prepare metrics
    train_model.metrics = metrics
    train_model.metrics_names = ['loss']
    train_model.metrics_tensors = []

    # compute total loss
    total_loss = None
    for i in range(len(train_model.outputs)):
        y_true = out_tensor_lst[i]
        y_pred = train_model.outputs[i]
        _loss = loss_functions[i]
        loss_weight = loss_weights_list[i]
        # The wrapped objective could be used here instead:
        # output_loss = weighted_losses[i](y_true, y_pred, None, None)
        output_loss = K.mean(_loss(y_true, y_pred))
        if len(train_model.outputs) > 1:
            train_model.metrics_tensors.append(output_loss)
            train_model.metrics_names.append(train_model.output_names[i] +
                                             '_loss')
        if total_loss is None:
            total_loss = loss_weight * output_loss
        else:
            total_loss += loss_weight * output_loss

    # add regularization penalties
    # and other layer-specific losses
    for loss_tensor in train_model.losses:
        total_loss += loss_tensor

    # List of the same size as output_names; each entry is the list of
    # metrics requested for that output.
    nested_metrics = collect_metrics(metrics, train_model.output_names)

    def append_metric(layer_num, metric_name, metric_tensor):
        """Record a per-output metric name and tensor on the model."""
        if len(train_model.output_names) > 1:
            metric_name = train_model.output_layers[
                layer_num].name + '_' + metric_name

        train_model.metrics_names.append(metric_name)
        train_model.metrics_tensors.append(metric_tensor)

    for i in range(len(train_model.outputs)):
        y_true = out_tensor_lst[i]
        y_pred = train_model.outputs[i]
        output_metrics = nested_metrics[i]

        for metric in output_metrics:
            if metric == 'accuracy' or metric == 'acc':
                # custom handling of accuracy
                # (because of class mode duality)
                output_shape = train_model.internal_output_shapes[i]
                acc_fn = None
                if output_shape[-1] == 1 or train_model.loss_functions[
                        i] == objectives.binary_crossentropy:
                    # case: binary accuracy
                    acc_fn = metrics_module.binary_accuracy
                elif train_model.loss_functions[
                        i] == objectives.sparse_categorical_crossentropy:
                    # case: categorical accuracy with sparse targets
                    acc_fn = metrics_module.sparse_categorical_accuracy
                else:
                    acc_fn = metrics_module.categorical_accuracy

                append_metric(i, 'acc', acc_fn(y_true, y_pred))
            else:
                metric_fn = metrics_module.get(metric)
                metric_result = metric_fn(y_true, y_pred)

                if not isinstance(metric_result, dict):
                    metric_result = {metric_fn.__name__: metric_result}

                for name, tensor in six.iteritems(metric_result):
                    append_metric(i, name, tensor)

    # prepare gradient updates and state updates
    # (the optimizer was already set near the top of this function)
    train_model.total_loss = total_loss

    train_model.train_function = None
    train_model.test_function = None
    train_model.predict_function = None

    # Collect trainable weights and sort them by name so the ordering
    # is deterministic across runs.
    trainable_weights = train_model.trainable_weights
    trainable_weights.sort(key=lambda x: x.name)
    train_model._collected_trainable_weights = trainable_weights
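A hypothetical usage sketch, assuming a Keras 1.x environment where Input accepts a tensor argument; the constant batches below stand in for tensors a real TFRecord reader pipeline would produce:

import tensorflow as tf
from keras.layers import Dense, Input
from keras.models import Model

# Constants stand in for TFRecord reader output (hypothetical data).
x_batch = tf.constant([[0.0, 1.0], [1.0, 0.0]], dtype='float32')
y_batch = tf.constant([[1.0], [0.0]], dtype='float32')

# Build the model directly on the input tensor (Keras 1.x functional API).
inp = Input(tensor=x_batch)
out = Dense(1, activation='sigmoid')(inp)
model = Model(input=inp, output=out)

compile_tfrecord(model,
                 optimizer='adam',
                 loss='binary_crossentropy',
                 out_tensor_lst=[y_batch],
                 metrics=['accuracy'])
print(model.metrics_names)  # ['loss', 'acc']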