class _Act(Layer):
    """Apply either a PReLU or a LeakyReLU activation, chosen at construction.

    Args:
        act: Activation selector. ``'prelu'`` selects ``PReLU()``; any other
            value (including the default ``''``) selects ``LeakyReLU``.
        lrelu_alpha: ``alpha`` passed to ``LeakyReLU``. Unused when
            ``act == 'prelu'``.
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, act='', lrelu_alpha=0.1, **kwargs):
        super(_Act, self).__init__(**kwargs)
        # Remember the constructor arguments so get_config() can reproduce
        # this layer when a model is serialized and reloaded.
        self._act = act
        self._lrelu_alpha = lrelu_alpha
        if act == 'prelu':
            self.func = PReLU()
        else:
            self.func = LeakyReLU(alpha=lrelu_alpha)

    def call(self, inputs, **kwargs):
        """Apply the selected activation sub-layer to ``inputs``."""
        return self.func(inputs)

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the wrapped activation sub-layer."""
        return self.func.compute_output_shape(input_shape)

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        config = super(_Act, self).get_config()
        config.update({
            'act': self._act,
            'lrelu_alpha': self._lrelu_alpha,
        })
        return config
class Downscale(Layer):
    """Strided ``Conv2D`` followed by a ``LeakyReLU(alpha=0.1)`` activation.

    The convolution uses ``strides=2`` with ``padding='same'``. Each spatial
    dimension is therefore reduced by a factor of two, rounded up.

    Args:
        dim: Number of output filters of the convolution.
        kernel_size: Convolution kernel size. Defaults to 5, the value this
            layer previously hard-coded.
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, dim, kernel_size=5, **kwargs):
        super(Downscale, self).__init__(**kwargs)
        # Remember the constructor arguments so get_config() can reproduce
        # this layer when a model is serialized and reloaded.
        self._dim = dim
        self._kernel_size = kernel_size
        self.conv_2d = Conv2D(dim, kernel_size=kernel_size, strides=2, padding='same')
        self.act = LeakyReLU(alpha=0.1)

    def call(self, inputs, **kwargs):
        """Run the strided convolution, then the LeakyReLU activation."""
        x = self.conv_2d(inputs)
        x = self.act(x)
        return x

    def compute_output_shape(self, input_shape):
        """Chain shape inference through the convolution and the activation."""
        input_shape = self.conv_2d.compute_output_shape(input_shape)
        input_shape = self.act.compute_output_shape(input_shape)
        return input_shape

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        config = super(Downscale, self).get_config()
        config.update({
            'dim': self._dim,
            'kernel_size': self._kernel_size,
        })
        return config