def test_init_model(self):
    """Call the demo model on a batch-2 Input and check its output node and shape.

    The resulting tensor is expected to come from the layer named 'dense3'
    and to have static shape [2, 1].
    """
    model = demo_init_model()
    placeholder = Input((2, ), batch_size=2)
    result = model(placeholder)
    # NOTE(review): '_anchor' is a private attribute of the framework's tensors;
    # its first entry is assumed to be the node exposing the producing `.layer`.
    first_node = getattr(result, '_anchor')[0]
    self.assertEqual(first_node.layer.name, 'dense3')
    expected_shape = [2, 1]
    self.assertListEqual(result.get_shape().as_list(), expected_shape)
def test_weight(self):
    """Build the demo model after `net.train()` and check it owns trainable weights.

    The forward call is kept before inspecting ``trainable_weights``;
    presumably the layers create their variables on first call (TODO confirm).
    Previously this test computed ``outputs`` without using it and asserted
    nothing, so it could never fail on a regression.
    """
    net = demo_init_model()
    # NOTE(review): presumably switches the model to training mode — confirm.
    net.train()
    inputs = Input((2, ), batch_size=2)
    outputs = net(inputs)
    # Same model and same Input as the init-model test, so the same [2, 1]
    # output shape is expected here.
    self.assertListEqual(outputs.get_shape().as_list(), [2, 1])
    # Materialize once in case the framework returns a lazy iterable.
    weights = list(net.trainable_weights)
    # The demo model's head is a layer named 'dense3'; a Dense stack should
    # expose at least one trainable variable.
    self.assertTrue(weights, 'expected the demo model to have trainable weights')
    for w in weights:
        print(w)
def demo_graph_model():
    """Wrap the demo model in a Network built under the 'network' name scope.

    Returns:
        A ``Network`` whose name is the string yielded by
        ``tf.name_scope('network')`` (the graph-model test expects this to be
        'network/').
    """
    inner = demo_init_model()
    with tf.name_scope('network') as scope_name:
        graph_in = Input((2, ), batch_size=2)
        graph_out = inner(graph_in)
        wrapped = Network(inputs=graph_in, outputs=graph_out, name=scope_name)
    return wrapped
def demo_loop_graph_model():
    """Build a nested graph model: a 'dense' layer feeding the demo graph model.

    All layers are created under the 'network_top' name scope. They are
    wrapped in a ``Network`` named after that scope, whose input is the
    ``Input`` placeholder and whose output is the inner graph model's result.

    Returns:
        The outer ``Network``.
    """
    with tf.name_scope('network_top') as scope:
        inner = demo_graph_model()
        inputs = Input((2, ), batch_size=2)
        # Keep `inputs` bound to the Input placeholder. Rebinding it to the
        # Dense output would make the Network start *after* the 'dense' layer,
        # leaving that layer outside the wrapped graph.
        # Dense(2) keeps the width at 2, matching the inner model's Input((2,)).
        hidden = Dense(2, name='dense')(inputs)
        outputs = inner(hidden)
        net = Network(inputs=inputs, outputs=outputs, name=scope)
    return net
def test_graph_model(self):
    """Call the 'network'-scoped graph model and check its output node and shape.

    Also prints the wall-clock seconds taken to build and call the model.
    """
    t0 = time.time()
    model = demo_graph_model()
    placeholder = Input(input_shape=(2, ), batch_size=2)
    result = model(placeholder)
    elapsed = time.time() - t0
    print(elapsed)
    # The whole wrapped Network is expected to be the producing layer,
    # named after its name scope.
    first_node = getattr(result, '_anchor')[0]
    self.assertEqual(first_node.layer.name, 'network/')
    self.assertListEqual(result.get_shape().as_list(), [2, 1])
def to_graph_network():
    """Stack three Dense layers (10, 10, 1 units) inside a 'network_graph' graph_scope.

    The scope is seeded with one ``Input(input_shape=(2,))``. The final
    layer's output is assigned to ``handler.outputs``.

    Returns:
        The output of the last layer ('dense3'). NOTE(review): the function
        name suggests a network object, but the code returns the last layer's
        output — confirm this is intended.
    """
    seed_values = [Input(input_shape=(2, ))]
    with graph_scope('network_graph', values=seed_values) as handler:
        x = handler.inputs
        for units, layer_name in ((10, 'dense1'), (10, 'dense2'), (1, 'dense3')):
            x = Dense(units, name=layer_name)(x)
        handler.outputs = x
    return x
def test_loop_graph_model(self):
    """Check the nested 'network_top' model, then write the graph for TensorBoard.

    The default graph is written with ``tf.summary.FileWriter``. It goes to the
    directory named by the ``MODEL_GRAPH_LOGDIR`` environment variable when
    that is set (for manual inspection). Otherwise it goes to a temporary
    directory that is removed afterwards. Previously a hard-coded,
    developer-specific absolute path was used, which breaks or pollutes the
    filesystem on other machines.
    """
    import contextlib
    import os
    import tempfile

    start = time.time()
    net = demo_loop_graph_model()
    inputs = Input((2, ), batch_size=2)
    outputs = net(inputs)
    print(time.time() - start)
    node = getattr(outputs, '_anchor')[0]
    self.assertEqual(node.layer.name, 'network_top/')
    self.assertListEqual(outputs.get_shape().as_list(), [2, 1])

    logdir = os.environ.get('MODEL_GRAPH_LOGDIR')
    with contextlib.ExitStack() as stack:
        if not logdir:
            # Removed when the stack unwinds, i.e. after the writer is closed.
            logdir = stack.enter_context(
                tempfile.TemporaryDirectory(prefix='model_graph_'))
        writer = tf.summary.FileWriter(logdir, tf.get_default_graph())
        writer.close()
def test_layer(self):
    """Apply a standalone Dense(10) to an Input and check the output shape.

    No batch size is passed to ``Input``. The expected [1, 10] shape implies
    the framework defaults to a batch size of 1 (TODO confirm).
    """
    placeholder = Input((2, ))
    layer = Dense(10, name='dense')
    result = layer(placeholder)
    expected_shape = [1, 10]
    self.assertListEqual(result.get_shape().as_list(), expected_shape)