class ReconstructionLoss(Layer):  # TODO: subclass `Loss` instead
    """Reconstruction loss wrapped as a Keras layer.

    Takes a pair ``[target, reconstruction]``. It applies ``tf.losses.mse``,
    which averages squared errors over the last axis, and then sums the result
    over axes 1 and 2. When ``mean`` is true, the per-sample values are reduced
    to a single scalar with ``tf.reduce_mean``.

    :param mean: If true, average the per-sample losses into one scalar.
    :param kwargs: Forwarded to :class:`Layer` (e.g. ``name``, ``dtype``,
        ``trainable``). This lets ``from_config(get_config())`` round-trip,
        since the base config contains those keys.
    """

    def __init__(self, mean=True, **kwargs):
        # Initialise the base layer first so attribute tracking is set up
        # before any state is assigned.
        super().__init__(**kwargs)
        self._mean = mean

    def _mse(self, x):
        """Compute the per-sample reconstruction error.

        :param x: A pair ``(target, reconstruction)``. Presumably rank-4
            ``(batch, height, width, channels)`` tensors — TODO confirm with
            callers. After ``tf.losses.mse`` drops the last axis, axes 1 and
            2 are summed.
        :return: The summed error, presumably shaped ``(batch,)``.
        """
        return tf.reduce_sum(tf.losses.mse(x[0], x[1]), axis=[1, 2])

    def build(self, input_shape):
        """Assemble the internal ``Sequential`` that computes the loss.

        :param input_shape: Shape(s) of the ``[target, reconstruction]`` pair.
        """
        self._model = Sequential([Lambda(self._mse)])
        if self._mean:
            # `tf.reduce_mean` with no axis collapses everything to a scalar.
            self._model.add(Lambda(tf.reduce_mean))
        super().build(input_shape)

    def call(self, x, **kwargs):
        """Return the loss for the ``[target, reconstruction]`` pair ``x``."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the internal model."""
        return self._model.compute_output_shape(input_shape)

    def get_config(self):
        """Return the serialisable config, including ``mean``."""
        config = super().get_config().copy()
        config.update({
            "mean": self._mean,
        })
        return config
class UpSample2D(Layer):
    """Learned upsampling: 1x1 transposed convolution, batch norm, then ReLU.

    The transposed convolution uses ``strides=pool_kernel_size``. It is
    presumably meant to invert a pooling step with that kernel size — confirm
    against the encoder that pairs with it.

    :param filters: Number of output channels of the transposed convolution.
    :param pool_kernel_size: Stride passed to ``Conv2DTranspose``.
    :param kwargs: Forwarded to :class:`Layer` (e.g. ``name``, ``dtype``).
        This lets ``from_config(get_config())`` round-trip.
    """

    def __init__(self, filters, pool_kernel_size, **kwargs):
        super().__init__(**kwargs)
        self._filters = filters
        self._pool_kernel_size = pool_kernel_size

    def build(self, input_shape):
        """Create the transposed-conv / batch-norm / ReLU stack.

        :param input_shape: Shape of the input tensor.
        """
        self._model = Sequential([
            Conv2DTranspose(self._filters, 1, strides=self._pool_kernel_size),
            BatchNormalization(),
            ReLU(),
        ])
        super().build(input_shape)

    def call(self, x, **kwargs):
        """Apply the upsampling stack to ``x``."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the internal model."""
        return self._model.compute_output_shape(input_shape)

    def get_config(self):
        """Return the serialisable config, including constructor arguments."""
        config = super().get_config().copy()
        config.update({
            "filters": self._filters,
            "pool_kernel_size": self._pool_kernel_size,
        })
        return config
class Conv1DTranspose(Layer):
    """1D transposed convolution implemented with ``Conv2DTranspose``.

    NOTE(review): ``Conv1DTranspose`` is defined a second time later in this
    module. That later definition replaces this one when the module is
    imported, so this class is shadowed.

    :param filters: Number of output channels.
    :param kernel_size: Integer kernel length along the sequence axis.
    :param strides: Integer stride along the sequence axis.
    :param args: Extra positional arguments for ``Conv2DTranspose``. They
        follow ``strides`` (i.e. ``padding``, ``output_padding``, ...).
    :param kwargs: Extra keyword arguments for ``Conv2DTranspose``.
    """

    def __init__(self, filters, kernel_size, strides=1, *args, **kwargs):
        # Zero-argument super() binds to *this* class. The explicit
        # `super(Conv1DTranspose, self)` would look up the module-level name,
        # which is the later redefinition.
        super().__init__()
        self._filters = filters
        # A height of 1 turns the 2D kernel/stride into a 1D one.
        self._kernel_size = (1, kernel_size)
        self._strides = (1, strides)
        self._args, self._kwargs = args, kwargs

    def build(self, input_shape):
        """Build expand-dims -> Conv2DTranspose -> squeeze.

        :param input_shape: Shape of the input tensor. Presumably
            ``(batch, steps, channels)`` — TODO confirm.
        """
        self._model = Sequential()
        # Insert a size-1 axis at position 1 so the 2D layer can consume it.
        self._model.add(
            Lambda(lambda x: K.expand_dims(x, axis=1),
                   batch_input_shape=input_shape))
        # Pass filters/kernel_size/strides positionally so that `*args`
        # continues after `strides`. Unpacking `*args` after keyword
        # arguments would bind them to kernel_size/strides and raise
        # "got multiple values for argument".
        self._model.add(
            Conv2DTranspose(self._filters, self._kernel_size, self._strides,
                            *self._args, **self._kwargs))
        # Remove the size-1 axis that was inserted above.
        self._model.add(Lambda(lambda x: x[:, 0]))
        super().build(input_shape)

    def call(self, x):
        """Apply the 1D transposed convolution to ``x``."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the internal model."""
        return self._model.compute_output_shape(input_shape)
class Conv1DTranspose(Layer):
    """
    A 1D transposed convolutional layer, implemented by inserting a size-1
    axis, applying ``Conv2DTranspose`` with a ``(1, k)`` kernel, and removing
    the axis again.
    """

    def __init__(self, filters, kernel_size, strides=1, *args, **kwargs):
        """
        :param filters: Number of output channels.
        :param kernel_size: Integer kernel length along the sequence axis.
        :param strides: Integer stride along the sequence axis.
        :param args: Extra positional arguments for ``Conv2DTranspose``,
            continuing after ``strides`` (``padding``, ``output_padding``, ...).
        :param kwargs: Extra keyword arguments for ``Conv2DTranspose``.
        """
        super().__init__()
        self._filters = filters
        # A height of 1 turns the 2D kernel/stride into a 1D one.
        self._kernel_size = (1, kernel_size)
        self._strides = (1, strides)
        self._args, self._kwargs = args, kwargs
        self._model = Sequential()

    def build(self, input_shape):
        """
        Builds the layer.
        :param input_shape: The input tensor shape. Presumably
            ``(batch, steps, channels)`` — TODO confirm.
        """
        # Insert a size-1 axis at position 1 so the 2D layer can consume it.
        self._model.add(
            Lambda(lambda x: backend.expand_dims(x, axis=1),
                   batch_input_shape=input_shape))
        # Pass filters/kernel_size/strides positionally so that `*args`
        # continues after `strides`. Unpacking `*args` after keyword
        # arguments would bind them to kernel_size/strides and raise
        # "got multiple values for argument".
        self._model.add(
            Conv2DTranspose(self._filters, self._kernel_size, self._strides,
                            *self._args, **self._kwargs))
        # Remove the size-1 axis that was inserted above.
        self._model.add(Lambda(lambda x: x[:, 0]))
        super().build(input_shape)

    def call(self, x, training=False, mask=None):
        """
        The forward pass of the layer.
        :param x: The input tensor.
        :param training: A boolean specifying if the layer should be in training mode.
        :param mask: A mask for the input tensor.
        :return: The output tensor of the layer.
        """
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """
        The output shape of the layer.
        :param input_shape: The input tensor shape.
        :return: The shape produced by the internal model for ``input_shape``.
        """
        return self._model.compute_output_shape(input_shape)
class NormConv2D(Layer):
    """Batch-normalized, ReLU-activated convolution or transpose convolution.

    :param filters: Number of output channels of the convolution.
    :param kernel_size: Kernel size passed to ``Conv2D``/``Conv2DTranspose``.
    :param transpose: If true, use ``Conv2DTranspose``; otherwise ``Conv2D``.
    :param kwargs: Forwarded to :class:`Layer` (e.g. ``name``, ``dtype``).
        This lets ``from_config(get_config())`` round-trip.
    """

    def __init__(self, filters, kernel_size, transpose=False, **kwargs):
        super().__init__(**kwargs)
        self._filters = filters
        self._kernel_size = kernel_size
        self._transpose = transpose

    def build(self, input_shape):
        """Create the conv / batch-norm / ReLU stack.

        :param input_shape: Shape of the input tensor.
        """
        conv_layer = Conv2DTranspose if self._transpose else Conv2D
        # TODO: Try with VALID padding
        self._model = Sequential([
            conv_layer(
                self._filters,
                self._kernel_size,
                padding="SAME",
            ),
            BatchNormalization(),
            ReLU(),
        ])
        super().build(input_shape)

    def call(self, x, **kwargs):
        """Apply the conv / batch-norm / ReLU stack to ``x``."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the internal model."""
        return self._model.compute_output_shape(input_shape)

    def get_config(self):
        """Return the serialisable config, including constructor arguments."""
        config = super().get_config().copy()
        config.update({
            "filters": self._filters,
            "kernel_size": self._kernel_size,
            "transpose": self._transpose,
        })
        return config
class ConvBlock2D(Layer):
    """A sequential stack of ``repeat`` identical convolutional sub-blocks.

    Each sub-block is either ``Inception2D`` (defined elsewhere in the
    project) or a 3x3 ``NormConv2D``.

    :param filters: Filter count passed to every sub-block.
    :param repeat: Number of sub-blocks to stack.
    :param use_inception: If true, stack ``Inception2D``; otherwise
        ``NormConv2D`` with kernel size 3.
    :param transpose: Forwarded to each sub-block's ``transpose`` argument.
    :param kwargs: Forwarded to :class:`Layer` (e.g. ``name``, ``dtype``).
        This lets ``from_config(get_config())`` round-trip.
    """

    def __init__(self, filters, repeat, use_inception=True, transpose=False,
                 **kwargs):
        super().__init__(**kwargs)
        self._filters = filters
        self._repeat = repeat
        self._use_inception = use_inception
        self._transpose = transpose

    def build(self, input_shape):
        """Instantiate the ``repeat`` sub-blocks inside a ``Sequential``.

        :param input_shape: Shape of the input tensor.
        """
        if self._use_inception:
            blocks = [
                Inception2D(self._filters, transpose=self._transpose)
                for _ in range(self._repeat)
            ]
        else:
            blocks = [
                NormConv2D(self._filters, 3, transpose=self._transpose)
                for _ in range(self._repeat)
            ]
        self._model = Sequential(blocks)
        super().build(input_shape)

    def call(self, x, **kwargs):
        """Apply the stacked sub-blocks to ``x``."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the internal model."""
        return self._model.compute_output_shape(input_shape)

    def get_config(self):
        """Return the serialisable config, including constructor arguments."""
        config = super().get_config().copy()
        config.update({
            "filters": self._filters,
            "repeat": self._repeat,
            "use_inception": self._use_inception,
            "transpose": self._transpose,
        })
        return config
class KLLoss(Layer):  # TODO: subclass `Loss` instead
    """KL-divergence term computed from a latent sample, wrapped as a layer.

    Given ``[z, z_mean, z_log_var]``, computes
    ``log N(z; z_mean, exp(z_log_var)) - log N(z; 0, 1)``, summed over axis 1.
    When ``z`` is sampled from the first Gaussian, this is a single-sample
    estimate of ``KL(q || p)`` with a standard-normal ``p``. When ``mean`` is
    true, the per-sample values are averaged into a single scalar.

    :param mean: If true, average the per-sample values into one scalar.
    :param kwargs: Forwarded to :class:`Layer` (e.g. ``name``, ``dtype``).
        This lets ``from_config(get_config())`` round-trip.
    """

    def __init__(self, mean=True, **kwargs):
        super().__init__(**kwargs)
        self._mean = mean

    def _log_normal_pdf(self, sample, mean, logvar, raxis=1):
        """Log-density of a diagonal Gaussian, summed over ``raxis``.

        :param sample: Points at which to evaluate the density.
        :param mean: Gaussian mean (tensor or scalar).
        :param logvar: Log of the Gaussian variance (tensor or scalar).
        :param raxis: Axis summed over; presumably the latent dimension of a
            ``(batch, latent)`` tensor — TODO confirm.
        :return: ``sum(-0.5 * ((x - mu)^2 / var + log var + log 2pi))``.
        """
        log2pi = tf.math.log(2.0 * np.pi)
        return tf.reduce_sum(
            -0.5 * ((sample - mean)**2.0 * tf.exp(-logvar) + logvar + log2pi),
            axis=raxis,
        )

    def _kld(self, x):
        """Return ``log q(z) - log p(z)`` for ``x = [z, z_mean, z_log_var]``.

        Here ``p`` is the standard normal (mean 0, log-variance 0).
        """
        return self._log_normal_pdf(x[0], x[1], x[2]) - self._log_normal_pdf(
            x[0], 0.0, 0.0)

    def build(self, input_shape):
        """Assemble the internal ``Sequential`` that computes the KL term.

        :param input_shape: Shapes of ``[z, z_mean, z_log_var]``.
        """
        self._model = Sequential([Lambda(self._kld)])
        if self._mean:
            # `tf.reduce_mean` with no axis collapses everything to a scalar.
            self._model.add(Lambda(tf.reduce_mean))
        super().build(input_shape)

    def call(self, x, **kwargs):
        """Return the KL term for ``x = [z, z_mean, z_log_var]``."""
        z, z_mean, z_log_var = x
        return self._model([z, z_mean, z_log_var])

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the internal model."""
        return self._model.compute_output_shape(input_shape)

    def get_config(self):
        """Return the serialisable config, including ``mean``."""
        config = super().get_config().copy()
        config.update({
            "mean": self._mean,
        })
        return config