Example #1
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Conv2DTranspose, Lambda, Layer
from tensorflow.keras.models import Sequential


class Conv1DTranspose(Layer):
    """1D transposed convolution implemented by wrapping Conv2DTranspose."""

    def __init__(self, filters, kernel_size, strides=1, *args, **kwargs):
        self._filters = filters
        # Store the 1D hyperparameters as 2D ones with a dummy height of 1.
        self._kernel_size = (1, kernel_size)
        self._strides = (1, strides)
        # Extra arguments (e.g. padding) are forwarded to the inner Conv2DTranspose.
        self._args, self._kwargs = args, kwargs
        super(Conv1DTranspose, self).__init__()

    def build(self, input_shape):
        print("build", input_shape)
        # Expand the inputs to 4D, apply the 2D transposed convolution,
        # then squeeze the dummy height axis back out.
        self._model = Sequential()
        self._model.add(
            Lambda(lambda x: K.expand_dims(x, axis=1),
                   batch_input_shape=input_shape))
        self._model.add(
            Conv2DTranspose(self._filters,
                            kernel_size=self._kernel_size,
                            strides=self._strides,
                            *self._args,
                            **self._kwargs))
        # Drop the dummy height axis to return a (batch, steps, filters) tensor.
        self._model.add(Lambda(lambda x: x[:, 0]))
        self._model.summary()
        super(Conv1DTranspose, self).build(input_shape)

    def call(self, x):
        return self._model(x)

    def compute_output_shape(self, input_shape):
        return self._model.compute_output_shape(input_shape)
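A minimal usage sketch (illustrative only; it assumes the imports above and relies on padding='same' being forwarded to the inner Conv2DTranspose via **kwargs):

import numpy as np

# Toy input: batch of 2 sequences, 100 steps, 8 channels.
layer = Conv1DTranspose(filters=16, kernel_size=4, strides=2, padding='same')
x = np.zeros((2, 100, 8), dtype='float32')
y = layer(x)      # builds the inner Sequential on first call
print(y.shape)    # expected: (2, 200, 16)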
Example #2
from tensorflow.keras import initializers, layers, regularizers
from tensorflow.keras.models import Sequential
# Internal Keras utility that converts TensorShape arguments into plain tuples;
# the import path varies across TensorFlow versions.
from tensorflow.python.keras.utils.tf_utils import shape_type_conversion


class UpConvBlock(layers.Layer):
    """Upsampling block: up_scale stages of a 1x1 conv followed by a stride-2 transposed conv."""

    def __init__(self, filters, up_scale, **kwargs):
        super().__init__(**kwargs)
        self.input_spec = layers.InputSpec(ndim=4)
        self.filters = filters
        self.up_scale = up_scale  # number of x2 upsampling stages
        self.constant_filters = 16

    @shape_type_conversion
    def build(self, input_shape):
        total_up_scale = 2**self.up_scale  # overall upsampling factor, reused as the kernel size
        trunc_init0 = initializers.TruncatedNormal()
        trunc_init1 = initializers.TruncatedNormal(stddev=0.1)

        self.features = Sequential()
        for i in range(self.up_scale):
            is_last = i == self.up_scale - 1
            out_features = self.filters if is_last else self.constant_filters
            kernel_init0 = trunc_init0 if is_last else 'glorot_uniform'
            kernel_init1 = trunc_init1 if is_last else 'glorot_uniform'

            # SameConv is a project-specific 'same'-padded convolution helper;
            # it is not defined in this snippet.
            self.features.add(
                SameConv(filters=out_features,
                         kernel_size=1,
                         strides=1,
                         activation='relu',
                         kernel_initializer=kernel_init0,
                         kernel_regularizer=regularizers.l2(1e-3)))
            self.features.add(
                layers.Conv2DTranspose(
                    out_features,
                    kernel_size=total_up_scale,
                    strides=2,
                    padding='same',
                    kernel_initializer=kernel_init1,
                    kernel_regularizer=regularizers.l2(1e-3)))

        super().build(input_shape)

    def call(self, inputs, **kwargs):
        return self.features(inputs)

    @shape_type_conversion
    def compute_output_shape(self, input_shape):
        return self.features.compute_output_shape(input_shape)

    def get_config(self):
        config = super().get_config()
        config.update({'filters': self.filters, 'up_scale': self.up_scale})

        return config
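A minimal usage sketch (illustrative only; it assumes a SameConv implementation is available in scope):

import numpy as np

# Hypothetical input: one 16x16 feature map with 32 channels.
block = UpConvBlock(filters=1, up_scale=2)
x = np.zeros((1, 16, 16, 32), dtype='float32')
y = block(x)      # two stride-2 stages -> 4x spatial upsampling
print(y.shape)    # expected: (1, 64, 64, 1)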
Example #3
from tensorflow.keras import activations, layers, regularizers
from tensorflow.keras.models import Sequential
# Internal Keras utility that converts TensorShape arguments into plain tuples;
# the import path varies across TensorFlow versions.
from tensorflow.python.keras.utils.tf_utils import shape_type_conversion


class DoubleConvBlock(layers.Layer):
    """Two stacked 3x3 convolution blocks; the first may be strided."""

    def __init__(self,
                 mid_features,
                 out_features=None,
                 stride=1,
                 activation='relu',
                 **kwargs):
        super().__init__(**kwargs)
        self.input_spec = layers.InputSpec(ndim=4)
        self.mid_features = mid_features
        self.out_features = out_features
        # Default the output width to mid_features when out_features is not given.
        self._out_features = self.out_features or self.mid_features
        self.stride = stride
        self.activation = activations.get(activation)

    @shape_type_conversion
    def build(self, input_shape):
        # ConvNormRelu is a project-specific conv + normalization + activation
        # helper; it is not defined in this snippet.
        self.features = Sequential([
            ConvNormRelu(self.mid_features,
                         3,
                         strides=self.stride,
                         kernel_regularizer=regularizers.l2(1e-3)),
            ConvNormRelu(self._out_features,
                         3,
                         activation=self.activation,
                         kernel_regularizer=regularizers.l2(1e-3)),
        ])

        super().build(input_shape)

    def call(self, inputs, **kwargs):
        return self.features(inputs)

    @shape_type_conversion
    def compute_output_shape(self, input_shape):
        return self.features.compute_output_shape(input_shape)

    def get_config(self):
        config = super().get_config()
        config.update({
            'mid_features': self.mid_features,
            'out_features': self.out_features,
            'stride': self.stride,
            'activation': activations.serialize(self.activation)
        })

        return config
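A minimal usage sketch (illustrative only; it assumes ConvNormRelu is available in scope and uses 'same' padding):

import numpy as np

# Hypothetical input: one 64x64 map with 3 channels.
block = DoubleConvBlock(mid_features=32, out_features=64, stride=2)
x = np.zeros((1, 64, 64, 3), dtype='float32')
y = block(x)      # stride=2 halves the spatial resolution
print(y.shape)    # expected: (1, 32, 32, 64)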