Beispiel #1
0
    def encoder(self):
        """ Build the encoder network.

        An initial convolution feeds a chain of residual blocks (8 when the ``lowmem`` config
        option is truthy, otherwise 16). The chain output is summed with a skip connection
        taken from that initial convolution, passed through further convolution and
        pixel-shuffle stages, squeezed through a dense layer of ``self.encoder_dim`` units and
        finally expanded, reshaped and upscaled.

        Returns
        -------
        KerasModel
            Model named ``"encoder"`` taking an input of ``self.input_shape``.
        """
        conv_args = {"kernel_initializer": self.kernel_initializer}
        inputs = Input(shape=self.input_shape)
        side = self.input_shape[0]  # presumably the input edge length - TODO confirm
        # Filter count follows the first input dimension up to 128, then grows at a quarter rate
        stem_filters = side if side <= 128 else 128 + (side - 128) // 4
        bottleneck_side = side // 16

        stem = Conv2DBlock(stem_filters, activation=None, **conv_args)(inputs)

        net = LeakyReLU(alpha=0.2)(stem)
        low_memory = self.config.get("lowmem", False)
        for _ in range(8 if low_memory else 16):
            net = ResidualBlock(stem_filters, **conv_args)(net)
        # consider adding scale before this layer to scale the residual chain
        skip = LeakyReLU(alpha=0.1)(stem)
        net = add([net, skip])

        # Two convolution + pixel-shuffle stages, followed by the remaining convolution stack
        for _ in range(2):
            net = Conv2DBlock(128, activation="leakyrelu", **conv_args)(net)
            net = PixelShuffler()(net)
        net = Conv2DBlock(128, activation="leakyrelu", **conv_args)(net)
        net = SeparableConv2DBlock(256, **conv_args)(net)
        net = Conv2DBlock(512, activation="leakyrelu", **conv_args)(net)
        if not low_memory:
            # The widest separable convolution is skipped in low memory mode
            net = SeparableConv2DBlock(1024, **conv_args)(net)

        # Dense bottleneck, then expand back to a square feature map with 1024 channels
        net = Dense(self.encoder_dim, **conv_args)(Flatten()(net))
        net = Dense(bottleneck_side * bottleneck_side * 1024, **conv_args)(net)
        net = Reshape((bottleneck_side, bottleneck_side, 1024))(net)
        net = UpscaleBlock(512, activation="leakyrelu", **conv_args)(net)
        return KerasModel(inputs, net, name="encoder")
Beispiel #2
0
    def encoder(self):
        """ Build the encoder network.

        An initial convolution feeds a chain of residual blocks (8 when the ``lowmem`` config
        option is truthy, otherwise 16). The chain output is summed with a skip connection
        taken from that initial convolution, passed through further convolution and
        pixel-shuffle stages, squeezed through a dense layer of ``self.encoder_dim`` units and
        finally expanded, reshaped and upscaled.

        Returns
        -------
        KerasModel
            Model taking an input of ``self.input_shape`` and returning the upscaled tensor.
        """
        kwargs = dict(kernel_initializer=self.kernel_initializer)
        input_ = Input(shape=self.input_shape)
        # Filter count follows the first input dimension up to 128, then grows at a quarter rate
        in_conv_filters = self.input_shape[0]
        if self.input_shape[0] > 128:
            in_conv_filters = 128 + (self.input_shape[0] - 128) // 4
        dense_shape = self.input_shape[0] // 16

        var_x = self.blocks.conv(input_, in_conv_filters, res_block_follows=True, **kwargs)
        tmp_x = var_x  # kept for the skip connection around the residual chain
        low_memory = self.config.get("lowmem", False)
        res_cycles = 8 if low_memory else 16
        for _ in range(res_cycles):
            # Use the same filter count as the initial convolution rather than a fixed 128, so
            # the chain output can still be added to ``tmp_x`` when the first input dimension
            # is not 128.
            # NOTE(review): assumes res_block emits ``filters`` channels - confirm in self.blocks
            var_x = self.blocks.res_block(var_x, in_conv_filters, **kwargs)
        # consider adding scale before this layer to scale the residual chain
        var_x = add([var_x, tmp_x])
        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = PixelShuffler()(var_x)
        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = PixelShuffler()(var_x)
        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = self.blocks.conv_sep(var_x, 256, **kwargs)
        var_x = self.blocks.conv(var_x, 512, **kwargs)
        if not low_memory:
            # The widest separable convolution is skipped in low memory mode
            var_x = self.blocks.conv_sep(var_x, 1024, **kwargs)

        # Dense bottleneck, then expand back to a square feature map with 1024 channels
        var_x = Dense(self.encoder_dim, **kwargs)(Flatten()(var_x))
        var_x = Dense(dense_shape * dense_shape * 1024, **kwargs)(var_x)
        var_x = Reshape((dense_shape, dense_shape, 1024))(var_x)
        var_x = self.blocks.upscale(var_x, 512, **kwargs)
        return KerasModel(input_, var_x)