예제 #1
0
def test_loading():
    """Exercise ``TensorFlowModel.predict_on_batch`` with str/list/dict node specs.

    The model is rebuilt from the same checkpoint several times, with
    ``input_nodes`` / ``target_nodes`` given as a plain string, a list or a
    dict. The test checks that:

    * a batch passed as a bare array is rejected with ``AssertionError`` when
      the inputs were declared as a dict or a list,
    * the prediction is returned in the container type used for
      ``target_nodes`` (array, dict or list),
    * loading with ``const_feed_dict_pkl`` still yields a (3, 3) prediction.
    """
    import tensorflow as tf

    ckpt = "examples/iris_tensorflow/model_files/model.ckpt"
    feed_pkl = "examples/iris_tensorflow/model_files/const_feed_dict.pkl"
    batch_shape = (3, 4)
    pred_shape = (3, 3)

    def build(input_nodes, target_nodes, **extra):
        # Every case shares the checkpoint; only the node specs vary.
        return TensorFlowModel(input_nodes=input_nodes,
                               target_nodes=target_nodes,
                               checkpoint_path=ckpt,
                               **extra)

    # input and output both given as a single node name (string)
    model = build("inputs", "probas")
    pred = model.predict_on_batch(np.ones(batch_shape))
    assert pred.shape == pred_shape

    # input given as a dict: a bare array must be rejected ...
    model = build({"out_name": "inputs"}, "probas")
    with pytest.raises(AssertionError):
        pred = model.predict_on_batch(np.ones(batch_shape))

    # ... while a dict keyed the same way is accepted
    pred = model.predict_on_batch({"out_name": np.ones(batch_shape)})
    assert pred.shape == pred_shape

    # input given as a list: a bare array must be rejected ...
    model = build(["inputs"], "probas")
    with pytest.raises(AssertionError):
        pred = model.predict_on_batch(np.ones(batch_shape))

    # ... while a one-element list is accepted
    pred = model.predict_on_batch([np.ones(batch_shape)])
    assert pred.shape == pred_shape

    # output given as a dict -> prediction is a dict with the same key
    model = build("inputs", {"out_name": "probas"})
    pred = model.predict_on_batch(np.ones(batch_shape))
    assert isinstance(pred, dict)
    assert list(pred.keys()) == ["out_name"]
    assert pred['out_name'].shape == pred_shape

    # output given as a list -> prediction is a one-element list
    model = build("inputs", ["probas"])
    pred = model.predict_on_batch(np.ones(batch_shape))
    assert isinstance(pred, list)
    assert len(pred) == 1
    assert pred[0].shape == pred_shape

    # additionally supply the pickled constant feed dict
    model = build("inputs", "probas", const_feed_dict_pkl=feed_pkl)
    pred = model.predict_on_batch(np.ones(batch_shape))
    assert pred.shape == pred_shape
예제 #2
0
def test_grad_tens_generation():
    """Check ``TensorFlowModel.get_grad_tens`` for a few slice / mode combinations.

    Forward values are computed for an all-ones (3, 4) batch. The first row of
    the tensor returned by ``get_grad_tens`` is then compared with the expected
    vector:

    * columns ``0:1`` with ``"min"``  -> ``[1, 0, 0]``
    * columns ``0:2`` with ``"min"``  -> ``[1, 0, 0]``
    * columns ``0:2`` with ``"max"``  -> ``[0, 1, 0]``
    * ``DummySlice()[0:2]`` with ``"max"`` must give the same first row as
      ``DummySlice()[:, 0:2]`` with ``"max"``.
    """
    import tensorflow as tf

    model = TensorFlowModel(
        input_nodes="inputs",
        target_nodes="probas",
        checkpoint_path="examples/iris_tensorflow/model_files/model.ckpt",
    )
    fwd_values = model.predict_on_batch(np.ones((3, 4)))

    # "min" over the first column only
    first_row = model.get_grad_tens(fwd_values, DummySlice()[:, 0:1], "min")[0, :]
    assert np.all(first_row == np.array([1, 0, 0]))

    # "min" over the first two columns
    first_row = model.get_grad_tens(fwd_values, DummySlice()[:, 0:2], "min")[0, :]
    assert np.all(first_row == np.array([1, 0, 0]))

    # "max" over the first two columns
    first_row = model.get_grad_tens(fwd_values, DummySlice()[:, 0:2], "max")[0, :]
    assert np.all(first_row == np.array([0, 1, 0]))

    # A slice without the leading ":" must give the same first row as the
    # explicit two-axis slice.
    short_form = model.get_grad_tens(fwd_values, DummySlice()[0:2], "max")[0, :]
    explicit_form = model.get_grad_tens(fwd_values, DummySlice()[:, 0:2], "max")[0, :]
    assert np.all(short_form == explicit_form)
예제 #3
0
def test_tf_model():
    """Load a ``TensorFlowModel`` from ``model_files/`` and run one prediction.

    Steps, in order:

    1. reset the default TensorFlow graph,
    2. build the model from a meta graph, a checkpoint and a pickled constant
       feed dict,
    3. export the default graph's meta graph as text to
       ``model_files/model.tf-modified.meta``,
    4. encode two poly-"T" sequences with ``concise``'s ``encodeDNA`` and call
       ``predict_on_batch`` on them.

    The function contains no ``assert``; it only checks that the calls run.
    It relies on a module-level ``tf`` name (not imported inside the function).
    """
    tf.reset_default_graph()
    # NOTE(review): input_nodes, target_nodes and index are assigned but never
    # read below; the constructor call repeats the node names as literals.
    input_nodes = "inputs"
    target_nodes = "preds"
    meta_graph = "model_files/model.tf.meta"
    # meta_graph = 'model_files/model.tf-modified.meta'
    checkpoint = "model_files/model.tf"
    index = "model_files/model.tf.index"
    pkl_file = "model_files/const_feed_dict.pkl"

    from kipoi.model import TensorFlowModel

    # NOTE(review): other tests in this file construct TensorFlowModel with
    # `checkpoint_path=`; here `meta_graph=` / `checkpoint=` are used instead.
    # Confirm these keywords match the constructor signature.
    m = TensorFlowModel(input_nodes="inputs",
                        target_nodes="preds",
                        meta_graph=meta_graph,
                        checkpoint=checkpoint,
                        const_feed_dict_pkl=pkl_file)
    # Operations of the default graph, fetched after the model was built.
    ops = tf.get_default_graph().get_operations()

    # TODO: unfinished - the intent appears to be modifying the exported meta
    # graph (see the commented-out 'model.tf-modified.meta' path above).
    # Side effect: writes the text meta graph into model_files/.
    out = tf.train.export_meta_graph(
        filename='model_files/model.tf-modified.meta', as_text=True)
    # NOTE(review): item assignment on a shape object. TensorFlow's TensorShape
    # does not appear to support `shape[i] = ...`, so this line presumably
    # raises TypeError - confirm, and replace with a supported way of relaxing
    # the first dimension if that is the goal.
    ops[0].outputs[0].shape[0] = None

    # First output of every Placeholder op whose name starts with
    # "Placeholder". NOTE(review): `pops` is not used afterwards.
    pops = [
        op.outputs[0] for op in ops
        if op.type == "Placeholder" and op.name.startswith("Placeholder")
    ]

    # Bare attribute accesses; their values are discarded (useful only when
    # inspecting interactively).
    m.input_ops  # view shapes of the data
    m.target_ops

    from concise.preprocessing import encodeDNA

    # Two identical sequences of "T"; the sequence length is taken from
    # dimension 1 of the model's input op shape.
    x = encodeDNA(["T" * m.input_ops.shape[1].value] * 2).astype("float32")
    # Rebinds `out` (previously the export_meta_graph result); the prediction
    # is not checked.
    out = m.predict_on_batch(x)