Exemplo n.º 1
0
    def add_channelwise_conv(self,
                             in_filters,
                             out_filters,
                             kernels,
                             strides=None,
                             pads=None,
                             suffix=''):
        """Append a channel-wise convolution to the current blob.

        Increments ``self.comp_idx``, calls ``brew.conv_nd`` on
        ``self.prev_blob`` with ``group=in_filters`` and the
        ``CHANNELWISE_3D`` engine (cuDNN disabled), and stores the result
        back in ``self.prev_blob``.

        Args:
            in_filters: input filter count forwarded to ``brew.conv_nd``;
                must equal ``out_filters``. Also used as the group count.
            out_filters: output filter count forwarded to ``brew.conv_nd``.
            kernels: forwarded as the ``kernel`` argument.
            strides: forwarded as ``strides``; ``None`` means ``[1, 1, 1]``.
            pads: forwarded as ``pads``; ``None`` means six zeros.
            suffix: text appended to the generated output blob name.

        Returns:
            The new value of ``self.prev_blob``.

        Raises:
            AssertionError: if ``in_filters != out_filters`` (raised after
                ``self.comp_idx`` has already been incremented).
        """
        # Default lists are created per call: a list literal in the signature
        # is a single object shared by every call and by whatever the callee
        # does with it.
        if strides is None:
            strides = [1, 1, 1]
        if pads is None:
            pads = [0, 0, 0] * 2

        self.comp_idx += 1

        assert in_filters == out_filters
        self.prev_blob = brew.conv_nd(
            self.model,
            self.prev_blob,
            '%scomp_%d_conv_%d%s' %
            (self.prefix, self.comp_count, self.comp_idx, suffix),
            in_filters,
            out_filters,
            weight_init=("MSRAFill", {}),
            kernel=kernels,
            strides=strides,
            pads=pads,
            # One group per input filter makes the convolution channel-wise.
            group=in_filters,
            no_bias=self.no_bias,
            use_cudnn=False,
            engine="CHANNELWISE_3D",
        )
        return self.prev_blob
Exemplo n.º 2
0
 def ConvNd(self, *args, **kwargs):
     """Call ``brew.conv_nd`` with ``self`` as its first argument.

     Positional and keyword arguments are forwarded unchanged. The
     ``use_cudnn``, ``order``, ``cudnn_exhaustive_search`` and
     ``ws_nbytes_limit`` keywords are always taken from the attributes of
     the same name on ``self``; because they are passed explicitly, also
     supplying one of them in ``kwargs`` raises ``TypeError``.

     Returns:
         Whatever ``brew.conv_nd`` returns.
     """
     return brew.conv_nd(
         self,
         *args,
         use_cudnn=self.use_cudnn,
         order=self.order,
         cudnn_exhaustive_search=self.cudnn_exhaustive_search,
         ws_nbytes_limit=self.ws_nbytes_limit,
         **kwargs)
Exemplo n.º 3
0
 def ConvNd(self, *args, **kwargs):
     """Call ``brew.conv_nd`` with ``self`` as its first argument.

     Positional and keyword arguments are forwarded unchanged. The
     ``use_cudnn``, ``order``, ``cudnn_exhaustive_search`` and
     ``ws_nbytes_limit`` keywords are always taken from the attributes of
     the same name on ``self``; because they are passed explicitly, also
     supplying one of them in ``kwargs`` raises ``TypeError``.

     Returns:
         Whatever ``brew.conv_nd`` returns.
     """
     return brew.conv_nd(
         self,
         *args,
         use_cudnn=self.use_cudnn,
         order=self.order,
         cudnn_exhaustive_search=self.cudnn_exhaustive_search,
         ws_nbytes_limit=self.ws_nbytes_limit,
         **kwargs
     )
Exemplo n.º 4
0
 def add_conv(
     self,
     in_filters,
     out_filters,
     kernels,
     strides=None,
     pads=None,
     # handled values: '3d', '2.5d', '0.3d', '0.3d+relu', '3d-sep', '2.5d-sep'
     block_type='3d',
     group=1,
 ):
     """Append one convolution "component" of the given block type.

     Increments ``self.comp_idx``, logs the filter counts, then extends the
     network starting from ``self.prev_blob`` according to ``block_type``:

     * ``'3d'``: a single ``ConvNd`` with ``kernels``/``strides``/``pads``.
     * ``'2.5d'``: a ``1 x k1 x k2`` conv to an intermediate width, spatial
       BN and ReLU, then a ``k0 x 1 x 1`` conv to ``out_filters``.
     * ``'0.3d'`` / ``'0.3d+relu'``: a ``1 x 1 x 1`` conv to
       ``out_filters``, spatial BN (followed by ReLU only for the
       ``'+relu'`` variant), then a channel-wise conv with the full
       ``kernels`` using the ``CHANNELWISE_3D`` engine.
     * ``'3d-sep'``: delegates to ``add_channelwise_conv``.
     * ``'2.5d-sep'``: channel-wise ``1 x k1 x k2`` conv, BN, ReLU,
       ``1 x 1 x 1`` conv, BN, ReLU, channel-wise ``k0 x 1 x 1`` conv;
       requires ``in_filters == out_filters``.
     * anything else: logs ``'Unknown block type!'`` and adds nothing.

     In the factorized branches only ``strides[0..2]`` and ``pads[0..2]``
     are read; the derived three-element pad list is repeated twice.

     Args:
         in_filters: filter count of the incoming blob.
         out_filters: filter count of the produced blob.
         kernels: three kernel sizes, indexed ``[0]``, ``[1]``, ``[2]``.
         strides: three strides; ``None`` means ``[1, 1, 1]``.
         pads: pads list; ``None`` means six zeros.
         block_type: one of the values listed above.
         group: only checked, never forwarded to a convolution here.

     Returns:
         The value of ``self.prev_blob`` after the block was added (the
         unchanged blob when ``block_type`` is not recognised).

     Raises:
         AssertionError: if ``group > 1`` and ``block_type`` is not
             ``'3d-group'``, or if ``'2.5d-sep'`` is requested with
             ``in_filters != out_filters``.
     """
     # Default lists are created per call: a list literal in the signature
     # is one object shared by every call and by every callee it is handed
     # to.
     if strides is None:
         strides = [1, 1, 1]
     if pads is None:
         pads = [0, 0, 0] * 2

     self.comp_idx += 1
     if group > 1:
         # NOTE(review): '3d-group' has no branch below, so a call that
         # passes this assert ends in the 'Unknown block type!' path --
         # confirm whether a grouped-conv branch is missing.
         assert block_type == '3d-group'
     log.info('in: %d out: %d' % (in_filters, out_filters))
     if block_type == '2.5d':
         # Width of the intermediate layer. NOTE(review): the constant 3
         # presumably stands for a temporal kernel extent of 3 -- confirm.
         # Floor division keeps the arithmetic in integers, so the result
         # does not depend on how the interpreter treats `/`.
         i = 3 * in_filters * out_filters * kernels[1] * kernels[2]
         i //= in_filters * kernels[1] * kernels[2] + 3 * out_filters
         middle_filters = int(i)
         # Convolution over the last two kernel axes only.
         self.prev_blob = self.model.ConvNd(
             self.prev_blob,
             '%scomp_%d_conv_%d_middle' %
             (self.prefix, self.comp_count, self.comp_idx),
             in_filters,
             middle_filters,
             [1, kernels[1], kernels[2]],
             weight_init=("MSRAFill", {}),
             strides=[1, strides[1], strides[2]] * 1,
             pads=[0, pads[1], pads[2]] * 2,
             no_bias=self.no_bias,
         )
         self.add_spatial_bn(middle_filters, suffix='_middle')
         self.add_relu()
         # Convolution over the first kernel axis only.
         self.prev_blob = self.model.ConvNd(
             self.prev_blob,
             '%scomp_%d_conv_%d' %
             (self.prefix, self.comp_count, self.comp_idx),
             middle_filters,
             out_filters,
             [kernels[0], 1, 1],
             weight_init=("MSRAFill", {}),
             strides=[strides[0], 1, 1] * 1,
             pads=[pads[0], 0, 0] * 2,
             no_bias=self.no_bias,
         )
     elif block_type == '0.3d' or block_type == '0.3d+relu':
         # 1x1x1 convolution changes the filter count first ...
         self.prev_blob = self.model.ConvNd(
             self.prev_blob,
             '%scomp_%d_conv_%d_middle' %
             (self.prefix, self.comp_count, self.comp_idx),
             in_filters,
             out_filters,
             [1, 1, 1],
             weight_init=("MSRAFill", {}),
             strides=[1, 1, 1] * 1,
             pads=[0, 0, 0] * 2,
             no_bias=self.no_bias,
         )
         self.add_spatial_bn(out_filters, suffix='_middle')
         if block_type == '0.3d+relu':
             self.add_relu()
         # ... then a channel-wise convolution (one group per filter)
         # applies the requested kernel, strides and pads.
         self.prev_blob = brew.conv_nd(
             self.model,
             self.prev_blob,
             '%scomp_%d_conv_%d' %
             (self.prefix, self.comp_count, self.comp_idx),
             out_filters,
             out_filters,
             weight_init=("MSRAFill", {}),
             kernel=kernels,
             strides=strides,
             pads=pads,
             group=out_filters,
             no_bias=self.no_bias,
             use_cudnn=False,
             engine="CHANNELWISE_3D",
         )
     elif block_type == '3d':
         self.prev_blob = self.model.ConvNd(
             self.prev_blob,
             '%scomp_%d_conv_%d' %
             (self.prefix, self.comp_count, self.comp_idx),
             in_filters,
             out_filters,
             kernels,
             weight_init=("MSRAFill", {}),
             strides=strides,
             pads=pads,
             no_bias=self.no_bias,
         )
     elif block_type == '3d-sep':  # channel_wise 3d conv block
         self.add_channelwise_conv(
             in_filters,
             out_filters,
             kernels,
             strides=strides,
             pads=pads,
         )
     elif block_type == '2.5d-sep':
         assert in_filters == out_filters
         middle_filters = out_filters
         self.add_channelwise_conv(
             in_filters,
             middle_filters,
             [1, kernels[1], kernels[2]],
             strides=[1, strides[1], strides[2]] * 1,
             pads=[0, pads[1], pads[2]] * 2,
             suffix='_middle',
         )
         self.add_spatial_bn(middle_filters, suffix='_middle')
         self.add_relu()
         # The blob name is formatted here, after the helper call above, so
         # it uses whatever value self.comp_idx has at this point.
         self.prev_blob = self.model.ConvNd(
             self.prev_blob,
             '%scomp_%d_conv_%d_1x' %
             (self.prefix, self.comp_count, self.comp_idx),
             middle_filters,
             middle_filters,
             [1, 1, 1],
             weight_init=("MSRAFill", {}),
             strides=[1, 1, 1] * 1,
             pads=[0, 0, 0] * 2,
             no_bias=self.no_bias,
         )
         self.add_spatial_bn(middle_filters, suffix='_1x')
         self.add_relu()
         self.add_channelwise_conv(
             middle_filters,
             out_filters,
             [kernels[0], 1, 1],
             strides=[strides[0], 1, 1] * 1,
             pads=[pads[0], 0, 0] * 2,
         )
     else:
         # Unrecognised types are only logged; no layer is added and the
         # previous blob is returned as-is.
         log.info('Unknown block type!')
     return self.prev_blob