示例#1
0
 def test_init_model(self):
     """Call the model from ``demo_init_model`` on a fresh input tensor.

     The input is built as ``Input((2, ), batch_size=2)``. The test then
     checks that the layer recorded on the first entry of the output's
     ``_anchor`` attribute is named ``'dense3'`` and that the output's
     static shape is ``[2, 1]``.
     """
     model = demo_init_model()
     result = model(Input((2, ), batch_size=2))
     anchor_node = result._anchor[0]
     self.assertEqual(anchor_node.layer.name, 'dense3')
     self.assertListEqual(result.get_shape().as_list(), [2, 1])
示例#2
0
 def test_weight(self):
     """Print every trainable weight of the demo model after one call.

     ``train()`` is invoked on the model before it is applied to an
     ``Input((2, ), batch_size=2)`` tensor. Nothing is asserted; the
     weights are only written to stdout for manual inspection.
     """
     model = demo_init_model()
     model.train()
     # The call's return value is not needed here; the model is applied
     # only so that the weights exist before they are listed.
     # NOTE(review): presumably weights are created lazily on first call
     # -- confirm against the layer implementation.
     model(Input((2, ), batch_size=2))
     for weight in model.trainable_weights:
         print(weight)
示例#3
0
def demo_graph_model():
    """Wrap the ``demo_init_model`` model in a ``Network``.

    Inside the ``'network'`` name scope an ``Input((2, ), batch_size=2)``
    tensor is created and fed through the model. The resulting input and
    output tensors define the returned ``Network``, which is named after
    the scope string yielded by ``tf.name_scope``.

    Returns:
        The ``Network`` built from those input/output tensors.
    """
    base_model = demo_init_model()
    with tf.name_scope('network') as scope_name:
        placeholder = Input((2, ), batch_size=2)
        result = base_model(placeholder)
    return Network(inputs=placeholder, outputs=result, name=scope_name)
示例#4
0
def demo_loop_graph_model():
    """Build a ``Network`` that nests the one from ``demo_graph_model``.

    Everything is created inside the ``'network_top'`` name scope: the
    inner network, an ``Input((2, ), batch_size=2)`` tensor, and a
    2-unit ``Dense`` layer named ``'dense'`` applied to that input. The
    inner network is then applied to the Dense output.

    Returns:
        A ``Network`` named after the scope string, whose ``inputs`` is
        the Dense output and whose ``outputs`` is the inner network's
        result.
    """
    with tf.name_scope('network_top') as scope_name:
        inner_network = demo_graph_model()
        placeholder = Input((2, ), batch_size=2)
        projected = Dense(2, name='dense')(placeholder)
        result = inner_network(projected)
    # The tensor passed as ``inputs`` is the Dense output, not the raw
    # placeholder created above.
    return Network(inputs=projected, outputs=result, name=scope_name)
示例#5
0
 def test_graph_model(self):
     """Apply the ``demo_graph_model`` network and check its output.

     The time taken to build the network and call it on an
     ``Input(input_shape=(2, ), batch_size=2)`` tensor is printed. The
     layer on the first entry of the output's ``_anchor`` attribute must
     be named ``'network/'`` and the output's static shape must be
     ``[2, 1]``.
     """
     began_at = time.time()
     network = demo_graph_model()
     result = network(Input(input_shape=(2, ), batch_size=2))
     elapsed = time.time() - began_at
     print(elapsed)
     anchor_node = result._anchor[0]
     self.assertEqual(anchor_node.layer.name, 'network/')
     self.assertListEqual(result.get_shape().as_list(), [2, 1])
示例#6
0
def to_graph_network():
    """Stack three ``Dense`` layers inside a ``graph_scope`` block.

    The scope is opened as ``'network_graph'`` with a single
    ``Input(input_shape=(2, ))`` value. Starting from ``handler.inputs``,
    the layers ``dense1`` (10 units), ``dense2`` (10 units) and ``dense3``
    (1 unit) are applied in order, and the final tensor is stored on
    ``handler.outputs``.

    Returns:
        The output of the last ``Dense`` layer (the same object assigned
        to ``handler.outputs``).
    """
    layer_specs = (
        (10, 'dense1'),
        (10, 'dense2'),
        (1, 'dense3'),
    )
    placeholder = Input(input_shape=(2, ))
    with graph_scope('network_graph', values=[placeholder]) as handler:
        hidden = handler.inputs
        for units, layer_name in layer_specs:
            hidden = Dense(units, name=layer_name)(hidden)
        handler.outputs = hidden
    return hidden
示例#7
0
 def test_loop_graph_model(self):
     """Apply the nested demo network, check its output, and dump the graph.

     The time taken to build the network from ``demo_loop_graph_model``
     and call it on an ``Input((2, ), batch_size=2)`` tensor is printed.
     The layer on the first entry of the output's ``_anchor`` attribute
     must be named ``'network_top/'`` and the output's static shape must
     be ``[2, 1]``.

     Finally the default graph is written with ``tf.summary.FileWriter``
     into a freshly created temporary directory, whose path is printed so
     the graph can still be inspected afterwards.
     """
     # Local import so this method does not depend on the file's import block.
     import tempfile

     start = time.time()
     net = demo_loop_graph_model()
     inputs = Input((2, ), batch_size=2)
     outputs = net(inputs)
     print(time.time() - start)
     node = getattr(outputs, '_anchor')[0]
     self.assertEqual(node.layer.name, 'network_top/')
     self.assertListEqual(outputs.get_shape().as_list(), [2, 1])
     # Write to a per-run temporary directory instead of a hard-coded,
     # machine-specific absolute path, so the test is portable across
     # machines and operating systems.
     logdir = tempfile.mkdtemp(prefix='model_graph_')
     writer = tf.summary.FileWriter(logdir, tf.get_default_graph())
     writer.close()
     print('graph written to %s' % logdir)
示例#8
0
 def test_layer(self):
     """Apply a single 10-unit ``Dense`` layer to ``Input((2, ))``.

     No batch size is passed to ``Input``; the test expects the output's
     static shape to be ``[1, 10]``.
     """
     placeholder = Input((2, ))
     layer = Dense(10, name='dense')
     result = layer(placeholder)
     expected_shape = [1, 10]
     self.assertListEqual(result.get_shape().as_list(), expected_shape)