Example #1
0
 def v(depth, nPlanes):
     """Build one (recursive) U-Net level that keeps ``nPlanes`` features.

     Relies on names from the enclosing scope: ``reps``, ``res``,
     ``dropout_p``, ``dropout_width`` and ``dimension``.

     Args:
         depth: Number of resolution levels including this one. ``1`` builds
             only the residual blocks (the bottom of the U).
         nPlanes: Feature count used at every level.

     Returns:
         The assembled ``scn.Sequential``.

     Raises:
         ValueError: if ``depth`` is less than 1 (such values previously
             recursed without terminating).
     """
     if depth < 1:
         raise ValueError('depth must be >= 1, got %r' % (depth,))
     m = scn.Sequential()
     for _ in range(reps):
         res(m, nPlanes, nPlanes, dropout_p)
     if depth == 1:
         return m
     # Coarser branch: a stride-2 convolution (in place of max-pooling),
     # optional dropout, the next level down, then a deconvolution back up.
     lower = scn.Sequential()
     lower.add(scn.BatchNormReLU(nPlanes))
     lower.add(scn.Convolution(dimension, nPlanes, nPlanes, 2, 2, False))
     if dropout_width:
         lower.add(scn.Dropout(dropout_p))
     lower.add(v(depth - 1, nPlanes))
     lower.add(scn.BatchNormReLU(nPlanes))
     lower.add(scn.Deconvolution(dimension, nPlanes, nPlanes, 2, 2, False))
     # Identity and coarse branch are joined, so the first residual block
     # below takes 2 * nPlanes input features.
     m.add(scn.ConcatTable().add(scn.Identity()).add(lower))
     m.add(scn.JoinTable())
     for i in range(reps):
         res(m, 2 * nPlanes if i == 0 else nPlanes, nPlanes, dropout_p)
     return m
Example #2
0
    def __init__(self, cfg, name='yresnet_encoder'):
        """Build the sparse YResNet encoder from ``cfg[name]``.

        Config keys read (defaults in parentheses): ``reps`` (2),
        ``kernel_size`` (2), ``num_strides`` (5), ``filters`` (16) and
        ``dropout_prob`` (0.5).

        Fills two containers with one entry per stride level:
        ``encoding_block`` holds ``reps`` residual blocks per level, and
        ``encoding_conv`` holds BatchNorm+LeakyReLU, a strided convolution
        into the next level's width and dropout (empty for the last level).
        """
        super(YResNetEncoder, self).__init__(cfg, name='network_base')
        self.model_config = cfg[name]
        settings = self.model_config

        # Number of residual blocks repeated at each stride level.
        self.reps = settings.get('reps', 2)
        self.kernel_size = settings.get('kernel_size', 2)
        self.num_strides = settings.get('num_strides', 5)
        self.num_filters = settings.get('filters', 16)
        # Widths grow linearly with depth: filters, 2 * filters, ...
        self.nPlanes = [(level + 1) * self.num_filters
                        for level in range(self.num_strides)]
        # Downsampling convolution parameters: [filter size, filter stride].
        self.downsample = [self.kernel_size, 2]

        dropout_prob = settings.get('dropout_prob', 0.5)

        self.encoding_block = scn.Sequential()
        self.encoding_conv = scn.Sequential()
        last_level = self.num_strides - 1
        for level in range(self.num_strides):
            planes = self.nPlanes[level]

            blocks = scn.Sequential()
            for _ in range(self.reps):
                self._resnet_block(blocks, planes, planes)
            self.encoding_block.add(blocks)

            transition = scn.Sequential()
            if level < last_level:
                transition.add(
                    scn.BatchNormLeakyReLU(planes, leakiness=self.leakiness))
                transition.add(
                    scn.Convolution(self.dimension, planes,
                                    self.nPlanes[level + 1],
                                    self.downsample[0], self.downsample[1],
                                    self.allow_bias))
                transition.add(scn.Dropout(p=dropout_prob))
            # The last level still receives an (empty) entry, keeping both
            # containers at one element per stride level.
            self.encoding_conv.add(transition)
Example #3
0
    def __init__(self, cfg, name='yresnet_decoder'):
        """Build the sparse YResNet decoder from ``cfg[name]``.

        Config keys read (defaults in parentheses): ``reps`` (2),
        ``kernel_size`` (2), ``num_strides`` (5), ``filters`` (16),
        ``dropout_prob`` (0.5) and ``encoder_num_filters`` (falls back to
        ``filters`` when absent or None).

        For each level from ``num_strides - 2`` down to 0 it appends to
        ``decoding_conv`` a BatchNorm+LeakyReLU and a deconvolution (plus
        dropout on every step except the first), and to ``decoding_block``
        ``reps`` residual blocks whose first block takes
        ``nPlanes[level] + encoder_nPlanes[level]`` input features.
        """
        super(YResNetDecoder, self).__init__(cfg, name='network_base')
        self.model_config = cfg[name]
        settings = self.model_config

        # Conv block repetition factor.
        self.reps = settings.get('reps', 2)
        self.kernel_size = settings.get('kernel_size', 2)
        self.num_strides = settings.get('num_strides', 5)
        self.num_filters = settings.get('filters', 16)
        self.nPlanes = [(level + 1) * self.num_filters
                        for level in range(self.num_strides)]
        # [filter size, filter stride]
        self.downsample = [self.kernel_size, 2]
        self.concat = scn.JoinTable()
        self.add = scn.AddTable()
        dropout_prob = settings.get('dropout_prob', 0.5)

        encoder_filters = settings.get('encoder_num_filters', None)
        self.encoder_num_filters = (self.num_filters
                                    if encoder_filters is None
                                    else encoder_filters)
        self.encoder_nPlanes = [(level + 1) * self.encoder_num_filters
                                for level in range(self.num_strides)]

        # Define Sparse YResNet Decoder.
        self.decoding_block = scn.Sequential()
        self.decoding_conv = scn.Sequential()
        for step, level in enumerate(range(self.num_strides - 2, -1, -1)):
            is_first = step == 0
            # The first step's input width comes from encoder_nPlanes; later
            # steps take the previous decoder level's width.
            in_planes = (self.encoder_nPlanes[level + 1] if is_first
                         else self.nPlanes[level + 1])
            upsample = scn.Sequential()
            upsample.add(
                scn.BatchNormLeakyReLU(in_planes, leakiness=self.leakiness))
            upsample.add(
                scn.Deconvolution(self.dimension, in_planes,
                                  self.nPlanes[level], self.downsample[0],
                                  self.downsample[1], self.allow_bias))
            if not is_first:
                upsample.add(scn.Dropout(p=dropout_prob))
            self.decoding_conv.add(upsample)

            blocks = scn.Sequential()
            for rep in range(self.reps):
                # Only the first block is widened by encoder_nPlanes[level];
                # presumably the encoder features are concatenated (via
                # self.concat) in the forward pass — not visible here.
                extra = self.encoder_nPlanes[level] if rep == 0 else 0
                self._resnet_block(blocks, self.nPlanes[level] + extra,
                                   self.nPlanes[level])
            self.decoding_block.add(blocks)
Example #4
0
 def _res_dropout(m, a, b, p):
     """Append a residual block with dropout to ``m`` (modified in place).

     Main branch: BatchNormReLU(a) -> Dropout(p) -> SubmanifoldConvolution
     (a -> b, filter size 3) -> BatchNormReLU(b) -> Dropout(p) ->
     SubmanifoldConvolution (b -> b, filter size 3). Shortcut: Identity when
     ``a == b``, otherwise a bias-free NetworkInNetwork (a -> b). The two
     branches are summed with AddTable.

     Args:
         m: ``scn.Sequential`` the block is appended to.
         a: Input feature count.
         b: Output feature count.
         p: Dropout probability for both dropout layers.

     Relies on a ``dimension`` name from the enclosing/module scope.
     """
     shortcut = scn.Identity() if a == b else scn.NetworkInNetwork(a, b, False)
     branch = scn.Sequential()
     branch.add(scn.BatchNormReLU(a))
     # scn.Dropout, matching the second dropout layer of this branch.
     branch.add(scn.Dropout(p))
     branch.add(scn.SubmanifoldConvolution(dimension, a, b, 3, False))
     branch.add(scn.BatchNormReLU(b))
     branch.add(scn.Dropout(p))
     branch.add(scn.SubmanifoldConvolution(dimension, b, b, 3, False))
     m.add(scn.ConcatTable().add(shortcut).add(branch))
     m.add(scn.AddTable())