def test_layer_database_with_dynamic_shape(self):
    """ test layer database creation with different input shapes

    A single graph with a fully dynamic NHWC input placeholder is built once.
    A LayerDatabase is then created for two concrete input shapes. The test
    checks that both conv layers report the output shape that matches each
    input shape.
    """
    # Session config: let TF grow GPU memory on demand instead of grabbing it all.
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    graph = tf.Graph()
    with graph.as_default():
        # Batch, height and width are all None, so the same graph accepts any input size.
        input_placeholder = tf.compat.v1.placeholder(tf.float32, [None, None, None, 3], 'input')
        x = tf.keras.layers.Conv2D(8, (2, 2), padding='SAME')(input_placeholder)
        x = tf.keras.layers.BatchNormalization(momentum=.3, epsilon=.65)(x)
        x = tf.keras.layers.Conv2D(8, (1, 1), padding='SAME', activation=tf.nn.tanh)(x)
        x = tf.keras.layers.BatchNormalization(momentum=.4, epsilon=.25)(x)
        # Runs every variable's own initializer (kernels and biases alike).
        init = tf.compat.v1.global_variables_initializer()

    # (batch_size, height, width) combinations to test. The input has 3 channels.
    shapes_to_test = ((1, 224, 224), (32, 28, 28))
    for batch_size, height, width in shapes_to_test:
        # Create a fresh session for every shape. The previous LayerDatabase's
        # destroy() leaves its session unusable, so it cannot be reused.
        sess = tf.compat.v1.Session(graph=graph, config=config)
        sess.run(init)
        layer_db = LayerDatabase(model=sess, input_shape=(batch_size, height, width, 3),
                                 working_dir=None, starting_ops=['input'],
                                 ending_ops=['batch_normalization_1/cond/Merge'])
        try:
            conv1_layer = layer_db.find_layer_by_name('conv2d/Conv2D')
            conv2_layer = layer_db.find_layer_by_name('conv2d_1/Conv2D')
            # Output shapes are expected channels-first: [N, C, H, W].
            # SAME padding with stride 1 keeps H and W unchanged.
            expected_shape = [batch_size, 8, height, width]
            self.assertEqual(conv1_layer.output_shape, expected_shape)
            self.assertEqual(conv2_layer.output_shape, expected_shape)
        finally:
            # Release the session even when an assertion above fails.
            layer_db.destroy()
def test_layer_database_destroy(self):
    """ test that destroying a layer database leaves its session unusable"""
    # Session config: let TF grow GPU memory on demand instead of grabbing it all.
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    # Build the MNIST model inside the session's own graph.
    sess = tf.compat.v1.Session(graph=tf.Graph(), config=config)
    with sess.graph.as_default():
        _ = mnist_tf_model.create_model(data_format='channels_last')
        # Runs every variable's own initializer (kernels and biases alike).
        init = tf.compat.v1.global_variables_initializer()
    sess.run(init)

    try:
        layer_db = LayerDatabase(model=sess, input_shape=(1, 28, 28, 1), working_dir=None)
        layer_db.destroy()
        # After destroy(), any further use of the session must raise RuntimeError.
        with self.assertRaises(RuntimeError):
            sess.run(init)
    finally:
        # Delete the temp directory even if the assertion above failed.
        # NOTE(review): it is unclear from here whether this test or an earlier
        # one creates './temp_meta/'. ignore_errors keeps cleanup from failing
        # the test when the directory does not exist.
        shutil.rmtree('./temp_meta/', ignore_errors=True)