def AddModel(model, data):
    """Add a conv -> transposed conv -> fc -> softmax network to ``model``.

    ``data`` is passed through:
      * ``conv1``: 3x3 convolution, 10 filters, stride 1, pad 1;
      * ``deconv1``: 2x2 transposed convolution, 10 filters, stride 2, pad 0;
      * ``fc1``: fully connected layer with 3 outputs;
      * ``softmax``.
    Both convolutions are given ``training_mode=1``.

    NOTE(review): ``input_channels`` is not a parameter of this function; it
    is resolved from the enclosing/module scope and must be defined there.
    NOTE(review): ``fc1`` is sized for a 10 x 56 x 56 ``deconv1`` output.
    Because ``conv1`` keeps the spatial size and ``deconv1`` doubles it, this
    presumably means ``data`` is 28 x 28 spatially -- confirm with callers.

    Returns:
        The ``softmax`` output blob.
    """
    n_filters = 10

    features = brew.conv(model, data, 'conv1',
                         dim_in=input_channels, dim_out=n_filters,
                         kernel=3, stride=1, pad=1, training_mode=1)

    upsampled = brew.conv_transpose(model, features, 'deconv1',
                                    dim_in=n_filters, dim_out=n_filters,
                                    kernel=2, stride=2, pad=0,
                                    training_mode=1)

    # Flattened size of the deconv1 output (channels * height * width).
    flat_size = n_filters * 56 * 56
    scores = brew.fc(model, upsampled, 'fc1', dim_in=flat_size, dim_out=3)

    return brew.softmax(model, scores, 'softmax')
Beispiel #2
0
 def ConvTranspose(self, *args, **kwargs):
     """Add a transposed-convolution op configured from this helper.

     All positional and keyword arguments are forwarded to
     ``brew.conv_transpose`` with ``self`` as the model. The helper's
     ``use_cudnn``, ``order``, ``cudnn_exhaustive_search`` and
     ``ws_nbytes_limit`` attributes are injected as keyword arguments, so
     passing any of those names in ``kwargs`` as well raises ``TypeError``.

     Returns:
         Whatever ``brew.conv_transpose`` returns.
     """
     # Read the helper-wide settings in the same order as before.
     use_cudnn = self.use_cudnn
     order = self.order
     exhaustive = self.cudnn_exhaustive_search
     ws_limit = self.ws_nbytes_limit
     return brew.conv_transpose(self, *args,
                                use_cudnn=use_cudnn,
                                order=order,
                                cudnn_exhaustive_search=exhaustive,
                                ws_nbytes_limit=ws_limit,
                                **kwargs)
Beispiel #3
0
 def ConvTranspose(self, *args, **kwargs):
     """Add a transposed-convolution op configured from this helper.

     All positional and keyword arguments are forwarded to
     ``brew.conv_transpose`` with ``self`` as the model. The helper's
     ``use_cudnn``, ``order``, ``cudnn_exhaustive_search`` and
     ``ws_nbytes_limit`` attributes are injected as keyword arguments, so
     passing any of those names in ``kwargs`` as well raises ``TypeError``.

     Returns:
         Whatever ``brew.conv_transpose`` returns.
     """
     # Read the helper-wide settings in the same order as before.
     use_cudnn = self.use_cudnn
     order = self.order
     exhaustive = self.cudnn_exhaustive_search
     ws_limit = self.ws_nbytes_limit
     return brew.conv_transpose(self, *args,
                                use_cudnn=use_cudnn,
                                order=order,
                                cudnn_exhaustive_search=exhaustive,
                                ws_nbytes_limit=ws_limit,
                                **kwargs)
Beispiel #4
0
def create_unet_model(m, device_opts, is_test):
    """Build a U-Net style segmentation network on the model helper ``m``.

    The network reads the blob ``'data'`` and consists of:
      * a contracting path of four levels, each made of two
        conv -> relu -> spatial-BN stages followed by a 2x2/stride-2 max-pool
        (blobs ``conv_L_k``, ``nonl_L_k``, ``contr_L_k``, ``poolL``);
      * optional dropout (``drop``) after ``pool4``;
      * a bottleneck of two stages (``conv_5_k``, ``nonl_5_k``, ``encode_5_k``);
      * an expanding path of four levels, each made of a 2x2/stride-2
        transposed conv (``upscaleL-1``), concatenation with the matching
        contracting-path output (``concatL``) and two stages
        (``conv_L_k``, ``nonl_L_k``, ``expand_L_k``);
      * a 1x1 output convolution ``output_segmentation`` and a sigmoid
        ``output_sigmoid`` on top of it.

    Args:
        m: model helper the ops are added to (presumably a Caffe2
            ``ModelHelper``; it must provide ``net`` and ``Sigmoid``).
        device_opts: device option handed to ``core.DeviceScope``; every op
            is created inside that scope.
        is_test: forwarded as ``is_test`` to every ``spatial_bn`` and to the
            dropout op.

    Returns:
        The ``output_segmentation`` blob (``num_classes`` channels, before the
        sigmoid). Both ``output_segmentation`` and ``output_sigmoid`` are
        registered as external outputs of ``m.net``.

    NOTE(review): the first conv uses ``dim_in=num_classes`` (3), i.e. it
    assumes a 3-channel input that merely coincides with the class count.
    NOTE(review): four 2x pools are undone by four 2x upsamplings before each
    concat, so H and W presumably need to be divisible by 16 -- confirm.
    """
    base_n_filters = 16
    kernel_size = 3
    # Integer "same" padding for the odd kernel. Plain ``/`` gives the float
    # 1.0 under Python 3, but ``pad`` is an integer op argument.
    pad = (kernel_size - 1) // 2
    do_dropout = True
    num_classes = 3

    weight_init = ("MSRAFill", {})

    def conv_relu_bn(blob_in, tag, prefix, dim_in, dim_out):
        """Add conv_<tag> -> nonl_<tag> (relu) -> <prefix>_<tag> (spatial BN)."""
        conv = brew.conv(m,
                         blob_in,
                         'conv_' + tag,
                         dim_in=dim_in,
                         dim_out=dim_out,
                         kernel=kernel_size,
                         pad=pad,
                         weight_init=weight_init)
        nonl = brew.relu(m, conv, 'nonl_' + tag)
        return brew.spatial_bn(m,
                               nonl,
                               prefix + '_' + tag,
                               dim_in=dim_out,
                               epsilon=1e-3,
                               momentum=0.1,
                               is_test=is_test)

    def double_conv(blob_in, level, prefix, dim_in, dim_out):
        """Add stages ``<level>_1`` and ``<level>_2``; return the second one."""
        first = conv_relu_bn(blob_in, '%d_1' % level, prefix, dim_in, dim_out)
        return conv_relu_bn(first, '%d_2' % level, prefix, dim_out, dim_out)

    def expand_level(blob_in, skip, level, dim_up, dim_out):
        """Upsample 2x, concat with ``skip`` and add two ``expand`` stages.

        ``skip`` carries ``dim_out`` channels, so the concatenation has
        ``dim_up + dim_out`` channels.
        """
        upscale = brew.conv_transpose(m,
                                      blob_in,
                                      'upscale%d' % (level - 1),
                                      dim_in=dim_up,
                                      dim_out=dim_up,
                                      kernel=2,
                                      stride=2,
                                      weight_init=weight_init)
        concat = brew.concat(m, [upscale, skip], 'concat%d' % level)
        return double_conv(concat, level, 'expand', dim_up + dim_out, dim_out)

    with core.DeviceScope(device_opts):

        # Contracting path: filters double per level, pooling halves H and W.
        contr_1_2 = double_conv('data', 1, 'contr',
                                num_classes, base_n_filters)
        pool1 = brew.max_pool(m, contr_1_2, 'pool1', kernel=2, stride=2)

        contr_2_2 = double_conv(pool1, 2, 'contr',
                                base_n_filters, base_n_filters * 2)
        pool2 = brew.max_pool(m, contr_2_2, 'pool2', kernel=2, stride=2)

        contr_3_2 = double_conv(pool2, 3, 'contr',
                                base_n_filters * 2, base_n_filters * 4)
        pool3 = brew.max_pool(m, contr_3_2, 'pool3', kernel=2, stride=2)

        contr_4_2 = double_conv(pool3, 4, 'contr',
                                base_n_filters * 4, base_n_filters * 8)
        pool4 = brew.max_pool(m, contr_4_2, 'pool4', kernel=2, stride=2)

        if do_dropout:
            # Forward is_test like every BN layer does; brew.dropout requires
            # it, and without it the Dropout op is never put in test mode.
            pool4 = brew.dropout(m, pool4, 'drop', ratio=0.4, is_test=is_test)

        # Bottleneck.
        encode_5_2 = double_conv(pool4, 5, 'encode',
                                 base_n_filters * 8, base_n_filters * 16)

        # Expanding path: each level upsamples and fuses the skip connection.
        expand_6_2 = expand_level(encode_5_2, contr_4_2, 6,
                                  base_n_filters * 16, base_n_filters * 8)
        expand_7_2 = expand_level(expand_6_2, contr_3_2, 7,
                                  base_n_filters * 8, base_n_filters * 4)
        expand_8_2 = expand_level(expand_7_2, contr_2_2, 8,
                                  base_n_filters * 4, base_n_filters * 2)
        expand_9_2 = expand_level(expand_8_2, contr_1_2, 9,
                                  base_n_filters * 2, base_n_filters)

        output_segmentation = brew.conv(m,
                                        expand_9_2,
                                        'output_segmentation',
                                        dim_in=base_n_filters,
                                        dim_out=num_classes,
                                        kernel=1,
                                        pad=0,
                                        stride=1,
                                        weight_init=weight_init)
        m.net.AddExternalOutput(output_segmentation)

        output_sigmoid = m.Sigmoid(output_segmentation, 'output_sigmoid')
        m.net.AddExternalOutput(output_sigmoid)

        return output_segmentation