Example #1
0
class Inception2D(Layer):
    """Inception-style block that sums parallel convolution and pooling branches.

    A shared 1x1 ``NormConv2D`` "bottleneck" feeds four ``NormConv2D``
    branches with kernel arguments 1, 3, 5 and 7. Two stride-1, "SAME"-padded
    ``MaxPool2D`` branches (pool sizes 3 and 5) are taken directly from the
    layer input. All six branch outputs are combined with an element-wise
    ``Add``.

    Args:
        filters: Filter count passed to every ``NormConv2D`` in the block.
        transpose: Forwarded as ``transpose=`` to every ``NormConv2D``.
        **kwargs: Standard ``Layer`` keyword arguments (``name``,
            ``trainable``, ``dtype``, ...), forwarded to ``Layer.__init__``.
    """

    def __init__(self, filters, transpose=False, **kwargs):
        # Accepting **kwargs is what lets ``from_config(get_config())``
        # round-trip: the base-class config carries keys such as ``name``
        # and ``trainable`` that must be handed back to ``Layer.__init__``.
        super().__init__(**kwargs)
        self._filters = filters
        self._transpose = transpose

    def build(self, input_shape):
        """Assemble the branch graph as an inner functional ``Model``.

        Args:
            input_shape: Full input shape including the batch dimension; the
                batch dimension is dropped for the inner ``Input``.
        """
        filters = self._filters
        transpose = self._transpose
        inputs = Input(shape=input_shape[1:])
        bottleneck = NormConv2D(filters, 1, transpose=transpose)(inputs)

        # Layers are created in the same order as before so auto-generated
        # layer names (and thus saved-weight ordering) stay unchanged.
        conv1 = NormConv2D(filters, 1, transpose=transpose)(bottleneck)
        conv3 = NormConv2D(filters, 3, transpose=transpose)(bottleneck)
        conv5 = NormConv2D(filters, 5, transpose=transpose)(bottleneck)
        conv7 = NormConv2D(filters, 7, transpose=transpose)(bottleneck)
        # Pooling branches read the raw input, not the bottleneck.
        pool3 = MaxPool2D(pool_size=3, strides=1, padding="SAME")(inputs)
        pool5 = MaxPool2D(pool_size=5, strides=1, padding="SAME")(inputs)
        # NOTE(review): ``Add`` needs all six tensors to share one shape. The
        # pool branches keep the input's channel count, so this presumably
        # requires input channels == ``filters`` and NormConv2D to preserve
        # spatial size -- confirm against NormConv2D's definition.
        merged = Add()([conv1, conv3, conv5, conv7, pool3, pool5])
        self._model = Model(inputs=inputs, outputs=merged)
        super().build(input_shape)

    def call(self, x, **kwargs):
        """Run the inner branch model on ``x``."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the inner model (requires ``build``)."""
        return self._model.compute_output_shape(input_shape)

    def get_config(self):
        """Return the base ``Layer`` config extended with this layer's args."""
        config = super().get_config().copy()
        config.update({
            "filters": self._filters,
            "transpose": self._transpose,
        })
        return config
Example #2
0
class Reparameterize(Layer):
    """Sample ``mean + exp(0.5 * log_var) * epsilon`` with ``epsilon ~ N(0, 1)``.

    This is the VAE reparameterization trick. The layer is called on a pair
    ``[z_mean, z_log_var]`` and returns one tensor. ``epsilon`` is drawn with
    the shape of ``z_mean``, so the two inputs presumably share a shape
    (``Multiply``/``Add`` require compatible shapes). The computation is
    assembled in ``build`` as an inner two-input functional ``Model``.
    """

    def build(self, input_shape):
        """Build the inner model from the shapes of ``[z_mean, z_log_var]``.

        Args:
            input_shape: A pair of shapes, each including the batch
                dimension (dropped for the inner ``Input`` layers).
        """
        # One ``Input`` per incoming tensor. ``Input(shape=...)`` accepts a
        # single shape, so a list of shapes cannot produce the two separate
        # inputs that the Lambdas below index as x[0] / x[1] and that
        # ``call`` feeds as ``[z_mean, z_log_var]``.
        inputs = [Input(shape=shape[1:]) for shape in input_shape]
        # TODO: Get rid of lambda expressions inside Lambda layers
        epsilon = Lambda(lambda x: tf.keras.backend.random_normal(
            shape=tf.shape(x[0])))(inputs)
        mean = Lambda(lambda x: x[0])(inputs)
        # exp(0.5 * log_var) is the standard deviation, not the variance.
        std = Lambda(lambda x: tf.exp(x[1] * 0.5))(inputs)
        reparam = Multiply()([epsilon, std])
        reparam = Add()([reparam, mean])
        self._model = Model(inputs=inputs, outputs=reparam)
        super().build(input_shape)

    def call(self, x):
        """Return a sample for ``x = [z_mean, z_log_var]``."""
        z_mean, z_log_var = x
        return self._model([z_mean, z_log_var])

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the inner model (requires ``build``)."""
        return self._model.compute_output_shape(input_shape)