def test_activation_on_batch():
    """Smoke-test ``predict_activation_on_batch`` on the iris TensorFlow model.

    Loads the model from its checkpoint and requests the activation of the
    ``"logits"`` layer for a batch of three all-ones 4-feature rows. The test
    only checks that the call does not raise; the returned activation is not
    inspected.
    """
    checkpoint_path = "examples/iris_tensorflow/model_files/model.ckpt"
    # Single (string) input and target node names; no const feed dict needed.
    model = TensorFlowModel(input_nodes="inputs",
                            target_nodes="probas",
                            checkpoint_path=checkpoint_path)
    # NOTE(review): consider asserting on the returned activation once the
    # expected shape/type of the "logits" activation is confirmed.
    model.predict_activation_on_batch(np.ones((3, 4)), layer="logits")
def test_tf_model():
    """Exploratory test: load a DNA TensorFlow model from a meta graph and predict.

    Resets the default graph, constructs ``kipoi.model.TensorFlowModel`` from
    a meta graph + checkpoint + constant feed-dict pickle (paths relative to
    the current working directory), exports the graph's meta graph to
    ``model_files/model.tf-modified.meta``, then runs ``predict_on_batch`` on
    two poly-T sequences one-hot encoded with ``concise``.

    NOTE(review): this reads like interactive scratch work rather than a
    finished test — there are no assertions, and several locals are unused.
    """
    # NOTE(review): `tf` is not imported inside this function (unlike the
    # other tests here) — presumably relies on a module-level import; confirm.
    tf.reset_default_graph()
    # NOTE(review): input_nodes, target_nodes and index are never used below.
    input_nodes = "inputs"
    target_nodes = "preds"
    meta_graph = "model_files/model.tf.meta"
    # Alternative meta graph (the one this test itself exports further down):
    # meta_graph = 'model_files/model.tf-modified.meta'
    checkpoint = "model_files/model.tf"
    index = "model_files/model.tf.index"
    pkl_file = "model_files/const_feed_dict.pkl"

    from kipoi.model import TensorFlowModel

    # NOTE(review): other tests in this file construct TensorFlowModel with
    # `checkpoint_path=`; the `meta_graph=` / `checkpoint=` keywords used here
    # may not match the current constructor signature — confirm.
    m = TensorFlowModel(input_nodes="inputs",
                        target_nodes="preds",
                        meta_graph=meta_graph,
                        checkpoint=checkpoint,
                        const_feed_dict_pkl=pkl_file)
    ops = tf.get_default_graph().get_operations()

    # TODO - modify the (original note truncated; presumably: modify the
    # exported meta graph, e.g. to relax the batch dimension — confirm).
    # NOTE(review): writes a file to disk as a side effect of the test.
    out = tf.train.export_meta_graph(
        filename='model_files/model.tf-modified.meta', as_text=True)
    # NOTE(review): item assignment on a TensorShape is likely unsupported in
    # TensorFlow and may raise TypeError here — confirm this line works.
    ops[0].outputs[0].shape[0] = None

    # First output tensor of every Placeholder op whose name starts with
    # "Placeholder". NOTE(review): `pops` is never used afterwards.
    pops = [
        op.outputs[0] for op in ops
        if op.type == "Placeholder" and op.name.startswith("Placeholder")
    ]

    # View shapes of the data: bare attribute accesses, no effect beyond
    # evaluating the attributes (left over from interactive inspection).
    m.input_ops
    m.target_ops

    from concise.preprocessing import encodeDNA
    # Two poly-T sequences whose length is the input's second dimension.
    # NOTE(review): `.value` is the TF1 Dimension API — verify on TF2.
    x = encodeDNA(["T" * m.input_ops.shape[1].value] * 2).astype("float32")
    out = m.predict_on_batch(x)
def createTensorModel(modelInfo, modelArgs):
    """Build a ``TensorFlowModel`` from a model-argument mapping.

    All file paths taken from ``modelArgs`` are prefixed with ``./model/``.

    Args:
        modelInfo: Not used by this function; kept so existing callers keep
            working.
        modelArgs: Mapping with the required keys ``"input_nodes"``,
            ``"target_nodes"`` and ``"checkpoint_path"``, and the optional key
            ``"const_feed_dict_pkl"``. A missing or ``None`` value for the
            optional key means no constant feed-dict pickle is used.

    Returns:
        The constructed ``TensorFlowModel``.

    Raises:
        KeyError: If one of the required keys is missing from ``modelArgs``.
    """
    inputNodes = modelArgs["input_nodes"]
    targetNodes = modelArgs["target_nodes"]
    checkpointPath = "./model/" + modelArgs["checkpoint_path"]

    # The pickle is optional. Catch only the missing-key case: the previous
    # bare ``except:`` also swallowed unrelated errors, including
    # KeyboardInterrupt and SystemExit.
    try:
        pklName = modelArgs["const_feed_dict_pkl"]
    except KeyError:
        pklName = None
    # An explicit None value also means "no pickle" (the old code reached the
    # same result by swallowing the TypeError from "./model/" + None).
    pklPath = None if pklName is None else "./model/" + pklName

    model = TensorFlowModel(inputNodes, targetNodes, checkpointPath,
                            const_feed_dict_pkl=pklPath)
    return model
def test_grad_tens_generation():
    """Pin the first-sample rows produced by ``get_grad_tens`` on the iris model.

    Runs a forward pass on three all-ones rows, then checks ``get_grad_tens``
    for several ``DummySlice`` selections and ``"min"`` / ``"max"`` modes:

    * columns ``0:1`` with ``"min"`` -> ``[1, 0, 0]``
    * columns ``0:2`` with ``"min"`` -> ``[1, 0, 0]``
    * columns ``0:2`` with ``"max"`` -> ``[0, 1, 0]``
    * a plain ``[0:2]`` slice with ``"max"`` matches ``[:, 0:2]`` with ``"max"``
    """
    # NOTE(review): `tf` is not used below; the import is kept as-is.
    import tensorflow as tf
    checkpoint_path = "examples/iris_tensorflow/model_files/model.ckpt"
    model = TensorFlowModel(input_nodes="inputs",
                            target_nodes="probas",
                            checkpoint_path=checkpoint_path)
    forward = model.predict_on_batch(np.ones((3, 4)))

    def first_row(selection, mode):
        # Gradient tensor for `selection` under `mode`, first sample only.
        return model.get_grad_tens(forward, selection, mode)[0, :]

    assert np.all(first_row(DummySlice()[:, 0:1], "min") == np.array([1, 0, 0]))
    assert np.all(first_row(DummySlice()[:, 0:2], "min") == np.array([1, 0, 0]))
    assert np.all(first_row(DummySlice()[:, 0:2], "max") == np.array([0, 1, 0]))
    assert np.all(first_row(DummySlice()[0:2], "max") ==
                  first_row(DummySlice()[:, 0:2], "max"))
def test_loading():
    """Exercise ``TensorFlowModel`` with every input/output node-name layout.

    Covers string, dict and list forms of ``input_nodes`` and
    ``target_nodes``, plus loading with a constant feed-dict pickle. Each
    prediction on three all-ones 4-feature rows must yield (3, 3) outputs,
    wrapped to match the ``target_nodes`` container type.
    """
    # NOTE(review): `tf` is not used below; the import is kept as-is.
    import tensorflow as tf
    checkpoint_path = "examples/iris_tensorflow/model_files/model.ckpt"
    # Pickled dict of extra (constant) feed values.
    const_feed_dict_pkl = "examples/iris_tensorflow/model_files/const_feed_dict.pkl"
    expected_shape = (3, 3)

    def make_model(input_nodes="inputs", target_nodes="probas", **extra):
        # Every variant loads the same checkpoint; only node names/extras vary.
        return TensorFlowModel(input_nodes=input_nodes,
                               target_nodes=target_nodes,
                               checkpoint_path=checkpoint_path,
                               **extra)

    def new_batch():
        # A fresh all-ones batch for each call.
        return np.ones((3, 4))

    # Plain string input and target names.
    model = make_model()
    result = model.predict_on_batch(new_batch())
    assert result.shape == expected_shape

    # Dict input: a bare array is rejected, a dict keyed by name is accepted.
    model = make_model(input_nodes={"out_name": "inputs"})
    with pytest.raises(AssertionError):
        result = model.predict_on_batch(new_batch())
    result = model.predict_on_batch({"out_name": new_batch()})
    assert result.shape == expected_shape

    # List input: a bare array is rejected, a one-element list is accepted.
    model = make_model(input_nodes=["inputs"])
    with pytest.raises(AssertionError):
        result = model.predict_on_batch(new_batch())
    result = model.predict_on_batch([new_batch()])
    assert result.shape == expected_shape

    # Dict output: predictions come back keyed by the requested name.
    model = make_model(target_nodes={"out_name": "probas"})
    result = model.predict_on_batch(new_batch())
    assert isinstance(result, dict)
    assert list(result.keys()) == ["out_name"]
    assert result['out_name'].shape == expected_shape

    # List output: predictions come back as a one-element list.
    model = make_model(target_nodes=["probas"])
    result = model.predict_on_batch(new_batch())
    assert isinstance(result, list)
    assert len(result) == 1
    assert result[0].shape == expected_shape

    # Extra constant inputs loaded from the pickle.
    model = make_model(const_feed_dict_pkl=const_feed_dict_pkl)
    result = model.predict_on_batch(new_batch())
    assert result.shape == expected_shape