Code Example #1
 def EBlock(self,
            last,
            channels,
            resblocks=1,
            bottleneck=True,
            kernel=[1, 3],
            stride=[1, 2],
            format=DATA_FORMAT,
            activation=ACTIVATION,
            normalizer=None,
            regularizer=None,
            collections=None):
     initializer = tf.initializers.variance_scaling(1.0, 'fan_in', 'normal',
                                                    self.random_seed,
                                                    self.dtype)
     skip = last
     in_channels = skip.get_shape()[-3 if format == 'NCHW' else -1]
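     # pre-activation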
     if activation: last = activation(last)
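     # bottleneck: 1x1 convolution to reduce the channel count first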
     if bottleneck and in_channels > channels:
         last = slim.conv2d(last,
                            channels, [1, 1], [1, 1],
                            'SAME',
                            format,
                            1,
                            activation,
                            None,
                            weights_initializer=initializer,
                            weights_regularizer=regularizer,
                            variables_collections=collections)
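     # convolution (no activation here; pre-activation was applied above)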
     last = slim.conv2d(last,
                        channels,
                        kernel,
                        stride,
                        'SAME',
                        format,
                        1,
                        None,
                        None,
                        weights_initializer=initializer,
                        weights_regularizer=regularizer,
                        variables_collections=collections)
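     # residual blocks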
     for i in range(resblocks):
         with tf.variable_scope('ResBlock_{}'.format(i)):
             last = self.ResBlock(last,
                                  channels,
                                  format=format,
                                  activation=activation,
                                  normalizer=normalizer,
                                  regularizer=regularizer,
                                  collections=collections)
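     # dense connection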
     with tf.variable_scope('DenseConnection'):
         last = layers.SEUnit(last, channels, format, collections)
         if stride != 1 and stride != [1, 1]:
             pool_stride = ([1, 1] + stride if format == 'NCHW'
                            else [1] + stride + [1])
             skip = tf.nn.avg_pool(skip, pool_stride, pool_stride, 'SAME',
                                   format)
         last = tf.concat([skip, last], -3 if format == 'NCHW' else -1)
     return last
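A minimal usage sketch, assuming TensorFlow 1.x with tf.contrib.slim and that `model` is an instance of the class defining EBlock and ResBlock above (`build_encoder`, `base_channels`, and `num_blocks` are illustrative names, not from the project):

 import tensorflow as tf

 def build_encoder(model, images, base_channels=32, num_blocks=3):
     # 'model' is assumed to be an instance of the class defining EBlock above.
     # Each EBlock downsamples along its strided dimension(s); the dense
     # connection makes the output carry the new feature maps plus the
     # pooled input channels.
     last = images
     channels = base_channels
     for i in range(num_blocks):
         with tf.variable_scope('EBlock_{}'.format(i)):
             last = model.EBlock(last, channels, resblocks=1,
                                 activation=tf.nn.relu)
         channels *= 2
     return last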
Code Example #2
 def ResBlock(self,
              last,
              channels,
              kernel=[1, 3],
              stride=[1, 1],
              biases=True,
              format=DATA_FORMAT,
              dilate=1,
              activation=ACTIVATION,
              normalizer=None,
              regularizer=None,
              collections=None):
     biases = tf.initializers.zeros(self.dtype) if biases else None
     initializer = tf.initializers.variance_scaling(1.0, 'fan_in', 'normal',
                                                    self.random_seed,
                                                    self.dtype)
     skip = last
     # pre-activation
     if normalizer: last = normalizer(last)
     if activation: last = activation(last)
     # convolution
     last = slim.conv2d(last,
                        channels,
                        kernel,
                        stride,
                        'SAME',
                        format,
                        dilate,
                        activation,
                        normalizer,
                        weights_initializer=initializer,
                        weights_regularizer=regularizer,
                        biases_initializer=biases,
                        variables_collections=collections)
     last = slim.conv2d(last,
                        channels,
                        kernel,
                        stride,
                        'SAME',
                        format,
                        dilate,
                        None,
                        None,
                        weights_initializer=initializer,
                        weights_regularizer=regularizer,
                        biases_initializer=biases,
                        variables_collections=collections)
     # residual connection
     last = layers.SEUnit(last, channels, format, collections)
     last += skip
     return last
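layers.SEUnit is used in all three excerpts but its definition is not shown. The sketch below is a generic squeeze-and-excitation gate written under that assumption; the project's actual SEUnit may differ in its reduction factor and layer choices.

 import tensorflow as tf
 import tensorflow.contrib.slim as slim

 def SEUnit_sketch(last, channels, data_format='NHWC', collections=None,
                   reduction=4):
     # Hypothetical squeeze-and-excitation gate; not the project's SEUnit.
     # Squeeze: global average pooling over the spatial axes.
     axes = [2, 3] if data_format == 'NCHW' else [1, 2]
     squeeze = tf.reduce_mean(last, axes, keepdims=True)
     # Excite: 1x1 bottleneck convolutions ending in a sigmoid gate.
     excite = slim.conv2d(squeeze, max(channels // reduction, 1), [1, 1],
                          activation_fn=tf.nn.relu, data_format=data_format,
                          variables_collections=collections)
     excite = slim.conv2d(excite, channels, [1, 1],
                          activation_fn=tf.sigmoid, data_format=data_format,
                          variables_collections=collections)
     # Rescale the feature maps channel-wise.
     return last * excite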
Code Example #3
File: network.py  Project: KotoriCANOE/StarGAN
 def EBlock(self,
            last,
            channels,
            resblocks=1,
            kernel=[4, 4],
            stride=[2, 2],
            format=DATA_FORMAT,
            activation=ACTIVATION,
            normalizer=None,
            regularizer=None,
            collections=None):
     initializer = tf.initializers.variance_scaling(1.0, 'fan_in', 'normal',
                                                    self.random_seed,
                                                    self.dtype)
     skip = last
     # pre-activation
     if activation: last = activation(last)
     # convolution
     last = slim.conv2d(last,
                        channels,
                        kernel,
                        stride,
                        'SAME',
                        format,
                        1,
                        None,
                        None,
                        weights_initializer=initializer,
                        weights_regularizer=regularizer,
                        variables_collections=collections)
     # residual blocks
     for i in range(resblocks):
         with tf.variable_scope('ResBlock_{}'.format(i)):
             last = self.ResBlock(last,
                                  channels,
                                  format=format,
                                  activation=activation,
                                  normalizer=normalizer,
                                  regularizer=regularizer,
                                  collections=collections)
     # dense connection
     with tf.variable_scope('DenseConnection'):
         last = layers.SEUnit(last, channels, format, collections)
         if stride != 1 and stride != [1, 1]:
             pool_stride = ([1, 1] + stride if format == 'NCHW'
                            else [1] + stride + [1])
             skip = tf.nn.avg_pool(skip, pool_stride, pool_stride, 'SAME',
                                   format)
         last = tf.concat([skip, last], -3 if format == 'NCHW' else -1)
     return last
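The DenseConnection step pools the skip path with the same stride as the convolution and concatenates it along the channel axis, so the block output carries in_channels + channels feature maps. A standalone sketch of just that step, using hypothetical shapes (TF 1.x, NHWC, channels=32):

 import tensorflow as tf

 skip = tf.placeholder(tf.float32, [None, 64, 64, 3])        # block input
 conv_out = tf.placeholder(tf.float32, [None, 32, 32, 32])   # strided conv output
 pool_stride = [1, 2, 2, 1]                                   # [1] + stride + [1]
 pooled = tf.nn.avg_pool(skip, pool_stride, pool_stride, 'SAME', 'NHWC')
 merged = tf.concat([pooled, conv_out], -1)
 print(merged.get_shape())  # (?, 32, 32, 35): 3 pooled skip + 32 new channels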