コード例 #1
0
 def test_first_conv_block_shapes(self):
     """Checks the input/output shapes of the first CAN conv layer ('can_g_conv0')."""
     model = vgg.build_can(input_shape=[512, 512, 3], name='can')
     first_conv = model.get_layer(name='can_g_conv0')
     # Expected shapes are taken from Section 3 of Zhang et al. The leading
     # None stands in for the batch dimension, which input_shape omits.
     expected_input_shape = [None, 512, 512, 1475]
     expected_output_shape = [None, 512, 512, 64]
     self.assertAllEqual(first_conv.input.shape, expected_input_shape)
     self.assertAllEqual(first_conv.output.shape, expected_output_shape)
コード例 #2
0
ファイル: models.py プロジェクト: guangyusong/google-research
def build_model(model_type, batch_size):
    """Returns a Keras model specified by name.

    Args:
      model_type: Name of the architecture to build; either 'unet' or 'can'.
      batch_size: Not used by this function; accepted so existing call sites
        that pass it keep working.

    Returns:
      The model built by `u_net.get_model` (for 'unet') or `vgg.build_can`
      (for 'can'), both configured for (512, 512, 3) inputs.

    Raises:
      ValueError: If `model_type` is not one of the supported names.
    """
    del batch_size  # Unused; kept only for interface compatibility.
    if model_type == 'unet':
        return u_net.get_model(input_shape=(512, 512, 3),
                               scales=4,
                               bottleneck_depth=1024,
                               bottleneck_layers=2)
    elif model_type == 'can':
        return vgg.build_can(input_shape=(512, 512, 3),
                             conv_channels=64,
                             out_channels=3)
    else:
        # Name the bad value and the valid choices so misconfigurations are
        # easy to diagnose from the traceback alone.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'unet' or 'can'.")
コード例 #3
0
 def test_contains_named_conv_blocks(self):
     """Checks that conv layers can_g_conv0..8 and can_g_conv_last all exist."""
     model = vgg.build_can(name='can')
     # Nine numbered conv blocks followed by the final output conv.
     expected_names = [f'can_g_conv{idx}' for idx in range(9)]
     expected_names.append('can_g_conv_last')
     for layer_name in expected_names:
         self.assertIsNotNone(model.get_layer(name=layer_name))
コード例 #4
0
 def test_output_shape(self):
     """Checks that the CAN model maps an image batch to a same-shaped tensor."""
     images = tf.random.uniform([2, 256, 256, 3], seed=0)
     per_example_shape = images.shape[1:]  # drop the leading batch axis
     model = vgg.build_can(input_shape=per_example_shape)
     outputs = model(images)
     self.assertAllEqual(outputs.shape, images.shape)