예제 #1
0
    def test_extract_model_metrics(self):
        # `saving_utils.extract_model_metrics` backs the V1-only export API
        # `keras.experimental.export_saved_model`, so everything below runs
        # inside an explicit graph.
        with tf.Graph().as_default():
            input_a = keras.layers.Input(shape=(3,), name="input_a")
            input_b = keras.layers.Input(shape=(3,), name="input_b")

            # A single Dense layer is applied to both inputs; only the branch
            # fed by `input_a` continues through Dropout.
            shared_dense = keras.layers.Dense(4, name="dense")
            dense_from_a = shared_dense(input_a)
            dense_from_b = shared_dense(input_b)
            dropout_output = keras.layers.Dropout(0.5, name="dropout")(
                dense_from_a
            )

            model = keras.models.Model(
                [input_a, input_b], [dense_from_b, dropout_output]
            )

            # Before compile() the helper is expected to report no metrics.
            self.assertEqual(None, saving_utils.extract_model_metrics(model))

            # The "mae" alias is reported under its short name when TF2
            # behavior is enabled and under the full function name otherwise.
            if tf.__internal__.tf2.enabled():
                mae_name = "mae"
            else:
                mae_name = "mean_absolute_error"

            # Every compiled metric is reported once per model output.
            expected_extracted = {
                f"{output}_{metric}"
                for output in ("dense", "dropout")
                for metric in (
                    "binary_accuracy",
                    "mean_squared_error",
                    mae_name,
                )
            }
            # The model additionally reports the total and per-output losses.
            expected_model_metrics = {
                "loss",
                "dense_loss",
                "dropout_loss",
            } | expected_extracted

            model.compile(
                loss="mae",
                metrics=[
                    keras.metrics.BinaryAccuracy(),
                    "mae",
                    keras.metrics.mean_squared_error,
                ],
                optimizer=tf.compat.v1.train.RMSPropOptimizer(
                    learning_rate=0.01
                ),
            )
            extracted = saving_utils.extract_model_metrics(model)
            self.assertEqual(
                expected_model_metrics, set(model.metrics_names)
            )
            self.assertEqual(expected_extracted, set(extracted.keys()))
def _create_signature_def_map(model, mode):
    """Creates a SignatureDef map from a Keras model.

    Intended for V1 graph-mode export: metric variables are registered in the
    `tf.compat.v1` `LOCAL_VARIABLES` collection of the current graph.

    Args:
        model: The Keras model to export. When it has an optimizer, its
            non-None target tensors are added to the signature inputs and its
            `total_loss` is exported as the loss.
        mode: A `mode_keys.ModeKeys` value passed to
            `model_utils.export_outputs_for_mode`. `ModeKeys.PREDICT` requests
            a serving-only signature map.

    Returns:
        The value of `model_utils.build_all_signature_defs` for the collected
        inputs and the export outputs of `mode`.
    """
    inputs_dict = dict(zip(model.input_names, model.inputs))
    if model.optimizer:
        # Target tensors are keyed by op name, i.e. the tensor name without
        # its ":<output index>" suffix.
        targets = model._targets  # pylint: disable=protected-access
        inputs_dict.update(
            (target.name.split(":")[0], target)
            for target in targets
            if target is not None
        )
    outputs_dict = dict(zip(model.output_names, model.outputs))
    metrics = saving_utils.extract_model_metrics(model)

    # Metric variables are by default not added to any collections. Add them
    # to `LOCAL_VARIABLES` here so that they get initialized.
    # Variables already in the collection are skipped. A list (rather than a
    # set) is used so the order in which variables are added is the order in
    # which the metrics expose them; set iteration order follows object
    # hashes and may differ between runs.
    seen = set(
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOCAL_VARIABLES)
    )
    vars_to_add = []
    if metrics is not None:
        # Build a new dict instead of rewriting the helper's return value.
        converted = {}
        for key, value in metrics.items():
            if isinstance(value, metrics_lib.Metric):
                for var in value.variables:
                    if var not in seen:
                        seen.add(var)
                        vars_to_add.append(var)
                # Convert Metric instances to (value_tensor, update_op) tuple.
                # NOTE(review): assumes each compiled Metric has at least one
                # entry in `updates` — TODO confirm.
                value = (value.result(), value.updates[0])
            converted[key] = value
        metrics = converted
    for var in vars_to_add:
        tf.compat.v1.add_to_collection(
            tf.compat.v1.GraphKeys.LOCAL_VARIABLES, var
        )

    export_outputs = model_utils.export_outputs_for_mode(
        mode,
        predictions=outputs_dict,
        loss=model.total_loss if model.optimizer else None,
        metrics=metrics,
    )
    return model_utils.build_all_signature_defs(
        inputs_dict,
        export_outputs=export_outputs,
        serving_only=(mode == mode_keys.ModeKeys.PREDICT),
    )