コード例 #1
0
    def base(self, images, is_training, *args, **kwargs):
        """Build the pool/unpool encoder-decoder and return the last block's output.

        Five encoder stages each run an lmnet block (widths 32, 64, 128, 256,
        256 -- presumably output channels) followed by a max pool with
        ``ksize``/``strides`` of (1, 2, 2, 1) and 'SAME' padding that also
        yields argmax indices. Five decoder stages then unpool using those
        indices in reverse order (last pool's indices first), each followed by
        an lmnet block (widths 256, 128, 64, 32, 32). A final block ``conv11``
        produces ``self.num_classes`` channels.

        Args:
            images: Input tensor; also stored on ``self.images``. Its layout
                is presumably described by ``self.data_format`` -- confirm
                against callers.
            is_training: Forwarded to ``self._get_lmnet_block``.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The output of the ``conv11`` block.
        """
        if self.data_format == 'NHWC':
            layer_format = 'channels_last'
        else:
            layer_format = 'channels_first'
        block = self._get_lmnet_block(is_training, layer_format)

        self.images = images

        # One window is shared by every pool and by its matching unpool.
        pool_window = (1, 2, 2, 1)

        # Encoder: conv{k} -> pool{k}; keep each pool's argmax indices.
        saved_masks = []
        out = images
        for stage, width in enumerate((32, 64, 128, 256, 256), start=1):
            out = block('conv{}'.format(stage), out, width, 3)
            out, argmax = self._max_pool_with_argmax(
                name='pool{}'.format(stage),
                inputs=out,
                ksize=pool_window,
                strides=pool_window,
                padding='SAME',
            )
            saved_masks.append(argmax)

        # Decoder: unpool{k} -> conv{k}, consuming the indices deepest-first.
        # Names are derived from the stage number; assumed to key variable
        # scopes/checkpoints, so they must stay stable.
        decoder = zip((256, 128, 64, 32, 32), reversed(saved_masks))
        for stage, (width, argmax) in enumerate(decoder, start=6):
            out = self._unpool_with_argmax(
                name='unpool{}'.format(stage),
                inputs=out,
                mask=argmax,
                ksize=pool_window,
            )
            out = block('conv{}'.format(stage), out, width, 3)

        return block('conv11', out, self.num_classes, 3)
コード例 #2
0
ファイル: lm_segnet_v1.py プロジェクト: smilejx/blueoil
    def base(self, images, is_training, *args, **kwargs):
        """Build the space-to-depth encoder/decoder and return the last block's output.

        Encoder: three rounds of an lmnet block (widths 32, 64, 128 --
        presumably output channels) each followed by ``self._space_to_depth``.
        Bottleneck: four lmnet blocks of width 256 (``conv4``-``conv7``).
        Decoder: three rounds of ``self._depth_to_space`` each followed by an
        lmnet block (widths 64, 32, 32). A final block ``conv11`` produces
        ``self.num_classes`` channels.

        Args:
            images: Input tensor; also stored on ``self.images``. Its layout
                is presumably described by ``self.data_format`` -- confirm
                against callers.
            is_training: Forwarded to ``self._get_lmnet_block``.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The output of the ``conv11`` block.
        """
        layer_format = 'channels_first'
        if self.data_format == 'NHWC':
            layer_format = 'channels_last'
        conv = self._get_lmnet_block(is_training, layer_format)

        self.images = images

        out = images

        # Encoder: conv{k} followed by space2depth{k}, for k = 1..3.
        for idx, width in enumerate((32, 64, 128), start=1):
            out = conv('conv{}'.format(idx), out, width, 3)
            out = self._space_to_depth(name='space2depth{}'.format(idx), inputs=out)

        # Bottleneck: conv4..conv7, all width 256, no reshuffling in between.
        for idx in range(4, 8):
            out = conv('conv{}'.format(idx), out, 256, 3)

        # Decoder: depth2space{k} followed by conv{k + 7}, for k = 1..3.
        for idx, width in enumerate((64, 32, 32), start=1):
            out = self._depth_to_space(name='depth2space{}'.format(idx), inputs=out)
            out = conv('conv{}'.format(idx + 7), out, width, 3)

        return conv('conv11', out, self.num_classes, 3)