class Conv1DTranspose(Layer):
    """1-D transposed convolution emulated with ``Conv2DTranspose``.

    ``build`` wraps three steps in a ``Sequential``:

    1. A ``Lambda`` inserts a singleton axis at position 1.
    2. A ``Conv2DTranspose`` runs with kernel ``(1, kernel_size)`` and
       strides ``(1, strides)``.
    3. A second ``Lambda`` drops that axis again (``x[:, 0]``).

    NOTE(review): presumably inputs are ``(batch, steps, channels)`` with the
    default channels_last data format, so only the ``steps`` axis is
    upsampled. Passing ``data_format='channels_first'`` through ``**kwargs``
    would break that assumption. Confirm with callers.

    Args:
        filters: Number of output channels of the transposed convolution.
        kernel_size: Integer length of the 1-D kernel.
        strides: Integer stride along the sequence axis.
        *args: Extra positional arguments forwarded to ``Conv2DTranspose``
            after ``filters``, ``kernel_size`` and ``strides``, i.e.
            starting at its ``padding`` parameter.
        **kwargs: Extra keyword arguments forwarded to ``Conv2DTranspose``.
    """

    def __init__(self, filters, kernel_size, strides=1, *args, **kwargs):
        # Initialise the base Layer before assigning attributes, as Keras
        # recommends for subclassed layers.
        super().__init__()
        self._filters = filters
        self._kernel_size = (1, kernel_size)
        self._strides = (1, strides)
        self._args, self._kwargs = args, kwargs

    def build(self, input_shape):
        """Create the expand -> Conv2DTranspose -> squeeze sub-model."""
        self._model = Sequential()
        # Add a singleton axis at position 1 so the 2-D layer can be used.
        self._model.add(
            Lambda(lambda x: K.expand_dims(x, axis=1),
                   batch_input_shape=input_shape))
        # kernel_size / strides are passed positionally so that any extra
        # positional args continue at the next Conv2DTranspose parameter
        # instead of colliding with ``kernel_size``.
        self._model.add(
            Conv2DTranspose(self._filters, self._kernel_size, self._strides,
                            *self._args, **self._kwargs))
        # Drop the singleton axis added above.
        self._model.add(Lambda(lambda x: x[:, 0]))
        super().build(input_shape)

    def call(self, x):
        """Apply the wrapped sub-model to ``x``."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the sub-model created in ``build``."""
        return self._model.compute_output_shape(input_shape)
class UpConvBlock(layers.Layer):
    """Upsampling head made of ``up_scale`` (1x1 conv -> transposed conv) stages.

    Each stage applies a 1x1 ``SameConv`` with ReLU, then a stride-2
    ``Conv2DTranspose`` with ``'same'`` padding. Each stage therefore doubles
    the spatial size, for ``2 ** up_scale`` overall.

    Intermediate stages use ``constant_filters`` (16) channels and
    Glorot-uniform kernels. Only the final stage emits ``filters`` channels,
    and it uses truncated-normal initialisers instead. Every kernel carries
    an L2(1e-3) penalty.

    Args:
        filters: Channel count produced by the last stage.
        up_scale: Number of upsampling stages.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, filters, up_scale, **kwargs):
        super().__init__(**kwargs)
        # Only rank-4 inputs are accepted.
        self.input_spec = layers.InputSpec(ndim=4)
        self.filters = filters
        self.up_scale = up_scale
        # Channel width of every stage except the final one.
        self.constant_filters = 16

    @shape_type_conversion
    def build(self, input_shape):
        # Every transposed conv uses a kernel as wide as the total
        # upsampling factor.
        deconv_kernel = 2 ** self.up_scale
        final_pointwise_init = initializers.TruncatedNormal()
        final_deconv_init = initializers.TruncatedNormal(stddev=0.1)
        final_stage = self.up_scale - 1

        self.features = Sequential()
        for stage in range(self.up_scale):
            if stage == final_stage:
                width = self.filters
                pointwise_init = final_pointwise_init
                deconv_init = final_deconv_init
            else:
                width = self.constant_filters
                pointwise_init = deconv_init = 'glorot_uniform'

            self.features.add(
                SameConv(
                    filters=width,
                    kernel_size=1,
                    strides=1,
                    activation='relu',
                    kernel_initializer=pointwise_init,
                    kernel_regularizer=regularizers.l2(1e-3),
                ))
            self.features.add(
                layers.Conv2DTranspose(
                    width,
                    kernel_size=deconv_kernel,
                    strides=2,
                    padding='same',
                    kernel_initializer=deconv_init,
                    kernel_regularizer=regularizers.l2(1e-3),
                ))
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        """Run ``inputs`` through the stacked upsampling stages."""
        return self.features(inputs)

    @shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the internal ``Sequential``."""
        return self.features.compute_output_shape(input_shape)

    def get_config(self):
        """Return the base config extended with ``filters`` and ``up_scale``."""
        return {
            **super().get_config(),
            'filters': self.filters,
            'up_scale': self.up_scale,
        }
class DoubleConvBlock(layers.Layer):
    """Two stacked 3x3 ``ConvNormRelu`` units.

    The first unit has ``mid_features`` filters and applies ``stride``. It
    keeps whatever activation ``ConvNormRelu`` defaults to; that helper is
    defined elsewhere (NOTE(review): confirm its default).

    The second unit has ``out_features`` filters, falling back to
    ``mid_features`` when ``out_features`` is falsy, and uses
    ``activation``. Both kernels carry an L2(1e-3) penalty.

    Args:
        mid_features: Filter count of the first unit.
        out_features: Filter count of the second unit; ``None`` means
            "same as ``mid_features``".
        stride: Stride applied by the first unit.
        activation: Activation identifier or callable for the second unit,
            resolved through ``activations.get``.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, mid_features, out_features=None, stride=1,
                 activation='relu', **kwargs):
        super().__init__(**kwargs)
        # Only rank-4 inputs are accepted.
        self.input_spec = layers.InputSpec(ndim=4)
        self.mid_features = mid_features
        # Keep the raw argument so get_config round-trips it unchanged.
        self.out_features = out_features
        # Effective width of the second unit.
        self._out_features = out_features or mid_features
        self.stride = stride
        # Resolved callable; serialized back in get_config.
        self.activation = activations.get(activation)

    @shape_type_conversion
    def build(self, input_shape):
        head = ConvNormRelu(
            self.mid_features, 3,
            strides=self.stride,
            kernel_regularizer=regularizers.l2(1e-3))
        tail = ConvNormRelu(
            self._out_features, 3,
            activation=self.activation,
            kernel_regularizer=regularizers.l2(1e-3))
        self.features = Sequential([head, tail])
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        """Run ``inputs`` through both conv units."""
        return self.features(inputs)

    @shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the internal ``Sequential``."""
        return self.features.compute_output_shape(input_shape)

    def get_config(self):
        """Return the base config extended with this block's constructor args."""
        return {
            **super().get_config(),
            'mid_features': self.mid_features,
            'out_features': self.out_features,
            'stride': self.stride,
            'activation': activations.serialize(self.activation),
        }