Example #1
def save_lucid_model(config, params, *, model_path, metadata_path):
    config = config.copy()
    config.pop("num_envs")  # a single env is created below instead
    library = config.get("library", "baselines")
    venv = create_env(1, **config)
    arch = get_arch(**config)

    # Rebuild the policy graph in a fresh graph/session so it can be frozen.
    with tf.Graph().as_default(), tf.Session() as sess:
        observation_space = venv.observation_space
        observations_placeholder = tf.placeholder(shape=(None, ) +
                                                  observation_space.shape,
                                                  dtype=tf.float32)

        if library == "baselines":
            from baselines.common.policies import build_policy

            with tf.variable_scope("ppo2_model", reuse=tf.AUTO_REUSE):
                policy_fn = build_policy(venv, arch)
                policy = policy_fn(
                    nbatch=None,
                    nsteps=1,
                    sess=sess,
                    # the policy expects [0, 255] pixels; the placeholder holds [0, 1] images
                    observ_placeholder=(observations_placeholder * 255),
                )
                pd = policy.pd
                vf = policy.vf

        else:
            raise ValueError(f"Unsupported library: {library}")

        # Restore the trained weights, then freeze the graph to Lucid's .pb format.
        load_params(params, sess=sess)

        Model.save(
            model_path,
            input_name=observations_placeholder.op.name,
            output_names=[pd.logits.op.name, vf.op.name],
            image_shape=observation_space.shape,
            image_value_range=[0.0, 1.0],
        )

    metadata = {
        "policy_logits_name": pd.logits.op.name,
        "value_function_name": vf.op.name,
        "env_name": config.get("env_name"),
        "gae_gamma": config.get("gamma"),
        "gae_lambda": config.get("lambda"),
    }
    # Unwrap the env wrappers to find one that exposes action button combos (if any).
    env = venv
    while hasattr(env, "env") and (not hasattr(env, "combos")):
        env = env.env
    if hasattr(env, "combos"):
        metadata["action_combos"] = env.combos
    else:
        metadata["action_combos"] = None

    save_joblib(metadata, metadata_path)
    return {
        "model_bytes": read(model_path, cache=False, mode="rb"),
        **metadata
    }
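
A hedged usage sketch (not part of the original code): `train_config` and `trained_params` below are hypothetical placeholders for the output of the surrounding training code, and the file names are illustrative.

# train_config must contain at least the keys read above (num_envs, library,
# env_name, gamma, lambda) plus whatever create_env and get_arch expect;
# trained_params holds the trained policy parameters passed to load_params.
result = save_lucid_model(
    train_config,
    trained_params,
    model_path="trained_policy.model.pb",
    metadata_path="trained_policy_metadata.joblib",
)
lucid_model = get_model(result["model_bytes"])  # see Example #7 below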
Example #2
def lucid_model_factory(pb_model_path=None,
                        model_image_shape=IMAGE_SHAPE,
                        model_input_name='dense_input',
                        model_output_name='dense_4/Softmax',
                        model_image_value_range=(0, 1)):
    """Build Lucid model object."""

    if pb_model_path is None:
        _, pb_model_path = mkstemp(suffix='.pb')

    # Tip: Model.suggest_save_args() can infer these arguments from the current graph.

    # Save tf.keras model in pb format
    # https://www.tensorflow.org/guide/saved_model
    Model.save(pb_model_path,
               image_shape=model_image_shape,
               input_name=model_input_name,
               output_names=[model_output_name],
               image_value_range=model_image_value_range)

    class MyLucidModel(Model):
        model_path = pb_model_path
        # labels_path = './lucid/mnist.txt'
        # synsets_path = 'gs://modelzoo/labels/ImageNet_standard_synsets.txt'
        # dataset = 'ImageNet'
        image_shape = model_image_shape
        # is_BGR = True
        image_value_range = model_image_value_range
        input_name = model_input_name

    lucid_model = MyLucidModel()
    lucid_model.load_graphdef()

    return lucid_model
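
A hedged usage sketch (not part of the original code): it assumes the tf.keras classifier whose tensors carry the default names above is already built in the current default graph/session, so that Model.save can freeze it. The returned object is what Lucid's optvis APIs (for example render.render_vis) accept as their model argument.

# Hypothetical file name; the factory freezes whatever lives in the default graph.
lucid_model = lucid_model_factory(pb_model_path="frozen_classifier.pb",
                                  model_image_shape=(28, 28, 1))
print(lucid_model.input_name)           # 'dense_input'
print(lucid_model.image_shape)          # (28, 28, 1)
print(len(lucid_model.graph_def.node))  # number of ops in the frozen graph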
Example #3
def _create_model(self, pb_path, image_shape, image_value_range,
                  input_name):
    # Configure a bare Lucid Model instance by setting its attributes directly.
    model = Model()
    model.model_path = pb_path
    model.image_shape = image_shape
    model.image_value_range = image_value_range
    model.input_name = input_name
    model.load_graphdef()
    return model
Example #4
def test_suggest_save_args_existing_graphs(capsys, model_class):
    graph_def = model_class().graph_def

    if model_class == InceptionV1:  # has flexible input shape, can't be inferred
        with pytest.warns(UserWarning):
            inferred = Model.suggest_save_args(graph_def)
    else:
        inferred = Model.suggest_save_args(graph_def)

    assert model_class.input_name == inferred["input_name"]

    if model_class != InceptionV1:
        assert model_class.image_shape == inferred["image_shape"]

    layer_names = [layer.name for layer in model_class.layers]
    for output_name in list(inferred["output_names"]):
        assert output_name in layer_names
Example #5
def test_suggest_save_args_happy_path(capsys, minimodel):
    path = "./tests/fixtures/minigraph.pb"

    with tf.Graph().as_default() as graph, tf.Session() as sess:
        _ = minimodel()
        sess.run(tf.global_variables_initializer())

        # ask for suggested arguments
        inferred = Model.suggest_save_args()
        # they should both be printed...
        captured = capsys.readouterr().out  # captures stdout
        names = ["input_name", "image_shape", "output_names"]
        assert all(name in captured for name in names)
        # ...and returned

        # check that these inferred values work
        inferred.update(image_value_range=(0, 1))
        Model.save(path, **inferred)
        loaded_model = Model.load(path)
        assert "0.100" in repr(loaded_model.graph_def)
Example #6
def test_suggest_save_args_int_input(capsys, minimodel):
    with tf.Graph().as_default() as graph, tf.Session() as sess:
        image_t = tf.placeholder(tf.uint8, shape=(32, 32, 3), name="input")
        input_t = tf.math.divide(image_t,
                                 tf.constant(255, dtype=tf.uint8),
                                 name="divide")
        _ = minimodel(input_t)
        sess.run(tf.global_variables_initializer())

        # ask for suggested arguments
        inferred = Model.suggest_save_args()
        captured = capsys.readouterr().out  # captures stdout
        assert "DT_UINT8" in captured
        assert inferred["input_name"] == "divide"
Example #7
def get_model(model_bytes):
    """Write serialized model bytes to a temporary .pb file and load them as a Lucid Model."""
    model_fd, model_path = tempfile.mkstemp(suffix=".model.pb")
    with open(model_fd, "wb") as model_file:
        model_file.write(model_bytes)
    return Model.load(model_path)
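
A hedged round-trip sketch: the bytes passed to get_model are what Example #1 returns under its "model_bytes" key, or equivalently the contents of a .pb file written by Model.save (the path below is hypothetical).

with open("trained_policy.model.pb", "rb") as f:
    lucid_model = get_model(f.read())
print(len(lucid_model.graph_def.node))  # number of ops in the frozen graph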
Example #8
def test_Model_save(minimodel):
    with tf.Session().as_default() as sess:
        _ = minimodel()
        sess.run(tf.global_variables_initializer())
        path = "./tests/fixtures/minigraph.pb"
        Model.save(path, "input", ["output"], shape, [0, 1])
Example #9
def test_Model_load():
    path = "./tests/fixtures/minigraph.pb"
    model = Model.load(path)
    assert all(str(dim) in repr(model.graph_def) for dim in shape)