class Inception2D(Layer):
    """Inception-style block: parallel convolution and pooling branches summed.

    A shared 1x1 ``NormConv2D`` bottleneck feeds four ``NormConv2D`` branches
    with kernel sizes 1, 3, 5 and 7. Two stride-1 ``MaxPool2D`` branches
    (pool sizes 3 and 5, "SAME" padding) operate directly on the block input.
    All six branch outputs are combined element-wise with ``Add``.

    NOTE(review): ``Add`` needs every branch to have the same shape. The pooling
    branches keep the input's channel count while the convolution branches
    produce ``filters`` channels (assuming NormConv2D's second argument is the
    kernel size and it preserves spatial size) — so this presumably requires
    the input to already have ``filters`` channels. TODO confirm.

    Args:
        filters: Number of filters passed to every ``NormConv2D`` in the block.
        transpose: Forwarded to each ``NormConv2D`` as ``transpose=``.
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``trainable``,
            ``dtype``, ...). Accepting them lets ``from_config`` rebuild the
            layer from the dict returned by ``get_config``.
    """

    def __init__(self, filters, transpose=False, **kwargs):
        # Initialise the base Layer before assigning attributes (Keras
        # convention), and forward base-layer kwargs: get_config() includes
        # keys such as "name" and "trainable", which from_config() passes back
        # to this constructor.
        super().__init__(**kwargs)
        self._filters = filters
        self._transpose = transpose

    def build(self, input_shape):
        """Create the inner functional model for inputs of ``input_shape``."""
        filters = self._filters
        transpose = self._transpose
        # Drop the leading (batch) dimension; Input() adds it back.
        inputs = Input(shape=input_shape[1:])

        bottleneck = NormConv2D(filters, 1, transpose=transpose)(inputs)
        conv_branches = [
            NormConv2D(filters, kernel_size, transpose=transpose)(bottleneck)
            for kernel_size in (1, 3, 5, 7)
        ]
        pool_branches = [
            MaxPool2D(pool_size=pool_size, strides=1, padding="SAME")(inputs)
            for pool_size in (3, 5)
        ]
        merged = Add()(conv_branches + pool_branches)

        self._model = Model(inputs=inputs, outputs=merged)
        super().build(input_shape)

    def call(self, x, **kwargs):
        """Run the inner model built in ``build`` on ``x``."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the inner model."""
        return self._model.compute_output_shape(input_shape)

    def get_config(self):
        """Return the base Layer config extended with this layer's arguments."""
        config = super().get_config().copy()
        config.update({
            "filters": self._filters,
            "transpose": self._transpose,
        })
        return config
class Reparameterize(Layer):
    """Sampling layer implementing the reparameterization trick.

    Called on a pair ``[z_mean, z_log_var]``. It returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon``, where ``epsilon`` is drawn
    from a standard normal distribution with the same shape and dtype as
    ``z_mean``.

    The computation is written directly in ``call`` rather than through an
    inner Model of ``Lambda`` layers. Closure-based ``Lambda`` layers do not
    serialize cleanly, and the previous single-``Input`` wiring could not
    accept the two tensors this layer is called with.
    """

    def call(self, x, **kwargs):
        """Sample ``z`` from ``N(z_mean, exp(z_log_var))``.

        Args:
            x: A two-element sequence ``(z_mean, z_log_var)`` of tensors with
                matching shapes.

        Returns:
            A tensor with the same shape as ``z_mean``.
        """
        z_mean, z_log_var = x
        epsilon = tf.keras.backend.random_normal(
            shape=tf.shape(z_mean), dtype=z_mean.dtype)
        # exp(0.5 * log(var)) == sqrt(var), i.e. the standard deviation.
        std = tf.exp(z_log_var * 0.5)
        return z_mean + std * epsilon

    def compute_output_shape(self, input_shape):
        """The output has the same shape as ``z_mean`` (the first input)."""
        return input_shape[0]